Compare commits
2 Commits
feat/devex
...
optional-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a477118337 | ||
|
|
99d9ea1464 |
@@ -1,18 +0,0 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.{yml,yaml,json,toml}]
|
||||
indent_size = 2
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
||||
|
||||
[Makefile]
|
||||
indent_style = tab
|
||||
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -46,7 +46,7 @@ Fixes #
|
||||
- [ ] My commit messages follow [Conventional Commits](https://www.conventionalcommits.org/) (`fix(scope):`, `feat(scope):`, etc.)
|
||||
- [ ] I searched for [existing PRs](https://github.com/NousResearch/hermes-agent/pulls) to make sure this isn't a duplicate
|
||||
- [ ] My PR contains **only** changes related to this fix/feature (no unrelated commits)
|
||||
- [ ] I've run `make check` (lint + test) and all checks pass
|
||||
- [ ] I've run `pytest tests/ -q` and all tests pass
|
||||
- [ ] I've added tests for my changes (required for bug fixes, strongly encouraged for features)
|
||||
- [ ] I've tested on my platform: <!-- e.g. Ubuntu 24.04, macOS 15.2, Windows 11 -->
|
||||
|
||||
|
||||
41
.github/workflows/tests.yml
vendored
41
.github/workflows/tests.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: CI
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -6,42 +6,37 @@ on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
# Cancel in-progress runs for the same PR/branch
|
||||
concurrency:
|
||||
group: ci-${{ github.ref }}
|
||||
group: tests-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
SRC: >-
|
||||
run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py
|
||||
tools/ hermes_cli/ gateway/ agent/ cron/
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 3
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
- run: uvx ruff check $SRC
|
||||
- run: uvx ruff format --check $SRC
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
- run: uv python install 3.11
|
||||
- run: |
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
|
||||
- name: Set up Python 3.11
|
||||
run: uv python install 3.11
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
- run: |
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python -m pytest tests/ -q --ignore=tests/integration --tb=short
|
||||
env:
|
||||
# Ensure tests don't accidentally call real APIs
|
||||
OPENROUTER_API_KEY: ""
|
||||
OPENAI_API_KEY: ""
|
||||
NOUS_API_KEY: ""
|
||||
|
||||
78
.gitignore
vendored
78
.gitignore
vendored
@@ -1,53 +1,51 @@
|
||||
# Python
|
||||
/venv/
|
||||
/_pycache/
|
||||
*.pyc*
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Environments
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# Tools
|
||||
.ruff_cache/
|
||||
.mypy_cache/
|
||||
.pytest_cache/
|
||||
|
||||
# Editors
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Secrets & config
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
*.pem
|
||||
*.ppk
|
||||
|
||||
# Node
|
||||
node_modules/
|
||||
|
||||
# Project-specific
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
.env.development
|
||||
.env.test
|
||||
export*
|
||||
__pycache__/model_tools.cpython-310.pyc
|
||||
__pycache__/web_tools.cpython-310.pyc
|
||||
logs/
|
||||
data/
|
||||
.pytest_cache/
|
||||
tmp/
|
||||
wandb/
|
||||
images/
|
||||
browser-use/
|
||||
agent-browser/
|
||||
source-data/
|
||||
testlogs/
|
||||
ignored/
|
||||
.worktrees/
|
||||
temp_vision_images/
|
||||
cli-config.yaml
|
||||
skills/.hub/
|
||||
hermes-*/*
|
||||
examples/
|
||||
export*
|
||||
privvy*
|
||||
run_datagen_*.sh
|
||||
tests/quick_test_dataset.jsonl
|
||||
tests/sample_dataset.jsonl
|
||||
run_datagen_kimik2-thinking.sh
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
run_datagen_sonnet.sh
|
||||
source-data/*
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
data/*
|
||||
node_modules/
|
||||
browser-use/
|
||||
agent-browser/
|
||||
# Private keys
|
||||
*.ppk
|
||||
*.pem
|
||||
privvy*
|
||||
images/
|
||||
__pycache__/
|
||||
hermes_agent.egg-info/
|
||||
wandb/
|
||||
testlogs
|
||||
|
||||
# CLI config (may contain sensitive SSH paths)
|
||||
cli-config.yaml
|
||||
|
||||
# Skills Hub state (lives in ~/.hermes/skills/.hub/ at runtime, but just in case)
|
||||
skills/.hub/
|
||||
ignored/
|
||||
.worktrees/
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.15.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-merge-conflict
|
||||
- id: check-yaml
|
||||
args: [--allow-multiple-documents]
|
||||
- id: check-added-large-files
|
||||
args: [--maxkb=500]
|
||||
23
AGENTS.md
23
AGENTS.md
@@ -5,8 +5,7 @@ Instructions for AI coding assistants and developers working on the hermes-agent
|
||||
## Development Environment
|
||||
|
||||
```bash
|
||||
make setup # First time: creates .venv, installs deps, sets up pre-commit
|
||||
source .venv/bin/activate
|
||||
source .venv/bin/activate # ALWAYS activate before running Python
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
@@ -229,27 +228,15 @@ The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HER
|
||||
|
||||
---
|
||||
|
||||
## Development Commands
|
||||
|
||||
```bash
|
||||
make setup # First time: .venv + deps + pre-commit hooks
|
||||
make check # Lint + test (mirrors CI — run before pushing)
|
||||
make lint # Ruff check
|
||||
make fmt # Ruff format + auto-fix
|
||||
make test # Full test suite (~2500 tests, ~2 min)
|
||||
make test-fast # Tests with fail-fast (-x)
|
||||
make test-watch # Rerun tests on file changes
|
||||
make dev-cli # Auto-restart CLI on file changes
|
||||
make dev-gateway # Auto-restart gateway on file changes
|
||||
```
|
||||
|
||||
For targeted testing, use `pytest` directly:
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
python -m pytest tests/ -q # Full suite (~2500 tests, ~2 min)
|
||||
python -m pytest tests/test_model_tools.py -q # Toolset resolution
|
||||
python -m pytest tests/test_cli_init.py -q # CLI config loading
|
||||
python -m pytest tests/gateway/ -q # Gateway tests
|
||||
python -m pytest tests/tools/ -q # Tool-level tests
|
||||
```
|
||||
|
||||
Formatting is enforced by **ruff** (config in `pyproject.toml`). Pre-commit hooks run on every commit.
|
||||
Always run the full suite before pushing changes.
|
||||
|
||||
@@ -65,7 +65,18 @@ If your skill is specialized, community-contributed, or niche, it's better suite
|
||||
```bash
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
make setup # creates .venv, installs all deps
|
||||
|
||||
# Create venv with Python 3.11
|
||||
uv venv venv --python 3.11
|
||||
export VIRTUAL_ENV="$(pwd)/venv"
|
||||
|
||||
# Install with all extras (messaging, cron, CLI menus, dev tools)
|
||||
uv pip install -e ".[all,dev]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
uv pip install -e "./tinker-atropos"
|
||||
|
||||
# Optional: browser tools
|
||||
npm install
|
||||
```
|
||||
|
||||
### Configure for development
|
||||
@@ -79,16 +90,22 @@ touch ~/.hermes/.env
|
||||
echo 'OPENROUTER_API_KEY=sk-or-v1-your-key' >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
### Common commands
|
||||
### Run
|
||||
|
||||
```bash
|
||||
make test # run unit tests
|
||||
make lint # ruff check
|
||||
make fmt # ruff format + fix
|
||||
make check # lint + test (same as CI)
|
||||
make dev-cli # auto-restart hermes CLI on file changes
|
||||
make dev-gateway # auto-restart gateway on file changes
|
||||
make test-watch # rerun tests on file changes
|
||||
# Symlink for global access
|
||||
mkdir -p ~/.local/bin
|
||||
ln -sf "$(pwd)/venv/bin/hermes" ~/.local/bin/hermes
|
||||
|
||||
# Verify
|
||||
hermes doctor
|
||||
hermes chat -q "Hello"
|
||||
```
|
||||
|
||||
### Run tests
|
||||
|
||||
```bash
|
||||
pytest tests/ -v
|
||||
```
|
||||
|
||||
---
|
||||
@@ -210,7 +227,7 @@ User message → AIAgent._run_agent_loop()
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Formatting**: Enforced by **ruff** (config in `pyproject.toml`). Run `make fmt` to auto-fix, `make lint` to check. Pre-commit hooks handle this automatically.
|
||||
- **PEP 8** with practical exceptions (we don't enforce strict line length)
|
||||
- **Comments**: Only when explaining non-obvious intent, trade-offs, or API quirks. Don't narrate what the code does — `# increment counter` adds nothing
|
||||
- **Error handling**: Catch specific exceptions. Log with `logger.warning()`/`logger.error()` — use `exc_info=True` for unexpected errors so stack traces appear in logs
|
||||
- **Cross-platform**: Never assume Unix. See [Cross-Platform Compatibility](#cross-platform-compatibility)
|
||||
@@ -440,7 +457,7 @@ refactor/description # Code restructuring
|
||||
|
||||
### Before submitting
|
||||
|
||||
1. **Run checks**: `make check` (lint + test — same as CI)
|
||||
1. **Run tests**: `pytest tests/ -v`
|
||||
2. **Test manually**: Run `hermes` and exercise the code path you changed
|
||||
3. **Check cross-platform impact**: If you touch file I/O, process management, or terminal handling, consider Windows and macOS
|
||||
4. **Keep PRs focused**: One logical change per PR. Don't mix a bug fix with a refactor with a new feature.
|
||||
|
||||
69
Makefile
69
Makefile
@@ -1,69 +0,0 @@
|
||||
.DEFAULT_GOAL := help
|
||||
SHELL := /bin/bash
|
||||
VENV := .venv
|
||||
UV := uv
|
||||
|
||||
SRC := run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py \
|
||||
tools/ hermes_cli/ gateway/ agent/ cron/
|
||||
|
||||
# ─── Setup ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: setup sync clean
|
||||
|
||||
setup: ## Full dev setup (venv + deps + pre-commit)
|
||||
$(UV) venv $(VENV) --python 3.11
|
||||
. $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]"
|
||||
. $(VENV)/bin/activate && $(UV) pip install -e "./mini-swe-agent"
|
||||
. $(VENV)/bin/activate && pre-commit install
|
||||
@echo "\n✅ Setup complete. Run: source $(VENV)/bin/activate"
|
||||
|
||||
sync: ## Reinstall deps into existing venv
|
||||
. $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]"
|
||||
|
||||
clean: ## Remove build artifacts and caches
|
||||
rm -rf .ruff_cache .mypy_cache .pytest_cache dist build *.egg-info
|
||||
find . -type d -name __pycache__ -not -path "./.venv/*" -exec rm -rf {} +
|
||||
|
||||
# ─── Quality ────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: lint fmt check
|
||||
|
||||
lint: ## Check lint + formatting (no changes)
|
||||
. $(VENV)/bin/activate && ruff check $(SRC)
|
||||
. $(VENV)/bin/activate && ruff format --check $(SRC)
|
||||
|
||||
fmt: ## Auto-fix lint + format
|
||||
. $(VENV)/bin/activate && ruff format $(SRC)
|
||||
. $(VENV)/bin/activate && ruff check --fix $(SRC)
|
||||
|
||||
check: lint test ## Lint + test (mirrors CI)
|
||||
|
||||
# ─── Test ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: test test-fast test-watch
|
||||
|
||||
test: ## Run full test suite
|
||||
. $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short
|
||||
|
||||
test-fast: ## Run tests with fail-fast
|
||||
. $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short -x
|
||||
|
||||
test-watch: ## Rerun tests on file changes
|
||||
. $(VENV)/bin/activate && python -m watchfiles "python -m pytest tests/ -q --ignore=tests/integration --tb=short -x" $(SRC) tests/
|
||||
|
||||
# ─── Dev Servers ────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: dev-cli dev-gateway
|
||||
|
||||
dev-cli: ## Auto-restart CLI on file changes
|
||||
. $(VENV)/bin/activate && python -m watchfiles "python -m hermes_cli.main" $(SRC)
|
||||
|
||||
dev-gateway: ## Auto-restart gateway on file changes
|
||||
. $(VENV)/bin/activate && python -m watchfiles "python -m gateway.run" $(SRC)
|
||||
|
||||
# ─── Misc ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: help
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}'
|
||||
@@ -95,8 +95,12 @@ Quick start for contributors:
|
||||
```bash
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
make setup # creates .venv, installs everything
|
||||
make check # lint + test (same as CI)
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
python -m pytest tests/ -q
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -34,7 +34,7 @@ import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -43,7 +43,7 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: dict[str, str] = {
|
||||
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
"zai": "glm-4.5-flash",
|
||||
"kimi-coding": "kimi-k2-turbo-preview",
|
||||
"minimax": "MiniMax-M2.5-highspeed",
|
||||
@@ -102,7 +102,7 @@ def _convert_content_for_responses(content: Any) -> Any:
|
||||
if not isinstance(content, list):
|
||||
return str(content) if content else ""
|
||||
|
||||
converted: list[dict[str, Any]] = []
|
||||
converted: List[Dict[str, Any]] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
@@ -113,7 +113,7 @@ def _convert_content_for_responses(content: Any) -> Any:
|
||||
# chat.completions nests the URL: {"image_url": {"url": "..."}}
|
||||
image_data = part.get("image_url", {})
|
||||
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
|
||||
entry: dict[str, Any] = {"type": "input_image", "image_url": url}
|
||||
entry: Dict[str, Any] = {"type": "input_image", "image_url": url}
|
||||
# Preserve detail if specified
|
||||
detail = image_data.get("detail") if isinstance(image_data, dict) else None
|
||||
if detail:
|
||||
@@ -148,21 +148,19 @@ class _CodexCompletionsAdapter:
|
||||
# Convert chat.completions multimodal content blocks to Responses
|
||||
# API format (input_text / input_image instead of text / image_url).
|
||||
instructions = "You are a helpful assistant."
|
||||
input_msgs: list[dict[str, Any]] = []
|
||||
input_msgs: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content") or ""
|
||||
if role == "system":
|
||||
instructions = content if isinstance(content, str) else str(content)
|
||||
else:
|
||||
input_msgs.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": _convert_content_for_responses(content),
|
||||
}
|
||||
)
|
||||
input_msgs.append({
|
||||
"role": role,
|
||||
"content": _convert_content_for_responses(content),
|
||||
})
|
||||
|
||||
resp_kwargs: dict[str, Any] = {
|
||||
resp_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"instructions": instructions,
|
||||
"input": input_msgs or [{"role": "user", "content": ""}],
|
||||
@@ -181,20 +179,18 @@ class _CodexCompletionsAdapter:
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
converted.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
converted.append({
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
})
|
||||
if converted:
|
||||
resp_kwargs["tools"] = converted
|
||||
|
||||
# Stream and collect the response
|
||||
text_parts: list[str] = []
|
||||
tool_calls_raw: list[Any] = []
|
||||
text_parts: List[str] = []
|
||||
tool_calls_raw: List[Any] = []
|
||||
usage = None
|
||||
|
||||
try:
|
||||
@@ -212,16 +208,14 @@ class _CodexCompletionsAdapter:
|
||||
if ptype in ("output_text", "text"):
|
||||
text_parts.append(getattr(part, "text", ""))
|
||||
elif item_type == "function_call":
|
||||
tool_calls_raw.append(
|
||||
SimpleNamespace(
|
||||
id=getattr(item, "call_id", ""),
|
||||
type="function",
|
||||
function=SimpleNamespace(
|
||||
name=getattr(item, "name", ""),
|
||||
arguments=getattr(item, "arguments", "{}"),
|
||||
),
|
||||
)
|
||||
)
|
||||
tool_calls_raw.append(SimpleNamespace(
|
||||
id=getattr(item, "call_id", ""),
|
||||
type="function",
|
||||
function=SimpleNamespace(
|
||||
name=getattr(item, "name", ""),
|
||||
arguments=getattr(item, "arguments", "{}"),
|
||||
),
|
||||
))
|
||||
|
||||
resp_usage = getattr(final, "usage", None)
|
||||
if resp_usage:
|
||||
@@ -291,7 +285,6 @@ class _AsyncCodexCompletionsAdapter:
|
||||
|
||||
async def create(self, **kwargs) -> Any:
|
||||
import asyncio
|
||||
|
||||
return await asyncio.to_thread(self._sync.create, **kwargs)
|
||||
|
||||
|
||||
@@ -311,7 +304,7 @@ class AsyncCodexAuxiliaryClient:
|
||||
self.base_url = sync_wrapper.base_url
|
||||
|
||||
|
||||
def _read_nous_auth() -> dict | None:
|
||||
def _read_nous_auth() -> Optional[dict]:
|
||||
"""Read and validate ~/.hermes/auth.json for an active Nous provider.
|
||||
|
||||
Returns the provider state dict if Nous is active with tokens,
|
||||
@@ -343,11 +336,10 @@ def _nous_base_url() -> str:
|
||||
return os.getenv("NOUS_INFERENCE_BASE_URL", _NOUS_DEFAULT_BASE_URL)
|
||||
|
||||
|
||||
def _read_codex_access_token() -> str | None:
|
||||
def _read_codex_access_token() -> Optional[str]:
|
||||
"""Read a valid Codex OAuth access token from Hermes auth store (~/.hermes/auth.json)."""
|
||||
try:
|
||||
from hermes_cli.auth import _read_codex_tokens
|
||||
|
||||
data = _read_codex_tokens()
|
||||
tokens = data.get("tokens", {})
|
||||
access_token = tokens.get("access_token")
|
||||
@@ -359,7 +351,7 @@ def _read_codex_access_token() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_api_key_provider() -> tuple[OpenAI | None, str | None]:
|
||||
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Try each API-key provider in PROVIDER_REGISTRY order.
|
||||
|
||||
Returns (client, model) for the first provider whose env var is set,
|
||||
@@ -406,7 +398,6 @@ def _resolve_api_key_provider() -> tuple[OpenAI | None, str | None]:
|
||||
|
||||
# ── Provider resolution helpers ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_auxiliary_provider(task: str = "") -> str:
|
||||
"""Read the provider override for a specific auxiliary task.
|
||||
|
||||
@@ -422,15 +413,16 @@ def _get_auxiliary_provider(task: str = "") -> str:
|
||||
return "auto"
|
||||
|
||||
|
||||
def _try_openrouter() -> tuple[OpenAI | None, str | None]:
|
||||
def _try_openrouter() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if not or_key:
|
||||
return None, None
|
||||
logger.debug("Auxiliary client: OpenRouter")
|
||||
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL, default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
|
||||
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
|
||||
|
||||
def _try_nous() -> tuple[OpenAI | None, str | None]:
|
||||
def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
nous = _read_nous_auth()
|
||||
if not nous:
|
||||
return None, None
|
||||
@@ -443,7 +435,7 @@ def _try_nous() -> tuple[OpenAI | None, str | None]:
|
||||
)
|
||||
|
||||
|
||||
def _try_custom_endpoint() -> tuple[OpenAI | None, str | None]:
|
||||
def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
custom_base = os.getenv("OPENAI_BASE_URL")
|
||||
custom_key = os.getenv("OPENAI_API_KEY")
|
||||
if not custom_base or not custom_key:
|
||||
@@ -453,7 +445,7 @@ def _try_custom_endpoint() -> tuple[OpenAI | None, str | None]:
|
||||
return OpenAI(api_key=custom_key, base_url=custom_base), model
|
||||
|
||||
|
||||
def _try_codex() -> tuple[Any | None, str | None]:
|
||||
def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
return None, None
|
||||
@@ -462,7 +454,7 @@ def _try_codex() -> tuple[Any | None, str | None]:
|
||||
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> tuple[OpenAI | None, str | None]:
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
|
||||
if forced == "openrouter":
|
||||
client, model = _try_openrouter()
|
||||
@@ -496,9 +488,10 @@ def _resolve_forced_provider(forced: str) -> tuple[OpenAI | None, str | None]:
|
||||
return None, None
|
||||
|
||||
|
||||
def _resolve_auto() -> tuple[OpenAI | None, str | None]:
|
||||
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint, _try_codex, _resolve_api_key_provider):
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint,
|
||||
_try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
@@ -508,8 +501,7 @@ def _resolve_auto() -> tuple[OpenAI | None, str | None]:
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_text_auxiliary_client(task: str = "") -> tuple[OpenAI | None, str | None]:
|
||||
def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Return (client, default_model_slug) for text-only auxiliary tasks.
|
||||
|
||||
Args:
|
||||
@@ -552,7 +544,7 @@ def get_async_text_auxiliary_client(task: str = ""):
|
||||
return AsyncOpenAI(**async_kwargs), model
|
||||
|
||||
|
||||
def get_vision_auxiliary_client() -> tuple[OpenAI | None, str | None]:
|
||||
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks.
|
||||
|
||||
Checks AUXILIARY_VISION_PROVIDER for a forced provider, otherwise
|
||||
@@ -568,21 +560,18 @@ def get_vision_auxiliary_client() -> tuple[OpenAI | None, str | None]:
|
||||
forced = _get_auxiliary_provider("vision")
|
||||
if forced != "auto":
|
||||
return _resolve_forced_provider(forced)
|
||||
# Auto: try providers known to support multimodal first, then fall
|
||||
# back to the user's custom endpoint. Many local models (Qwen-VL,
|
||||
# LLaVA, Pixtral, etc.) support vision — skipping them entirely
|
||||
# caused silent failures for local-only users.
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_codex, _try_custom_endpoint):
|
||||
# Auto: only multimodal-capable providers
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_codex):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.debug("Auxiliary vision client: none available")
|
||||
logger.debug("Auxiliary vision client: none available (auto only tries OpenRouter/Nous/Codex)")
|
||||
return None, None
|
||||
|
||||
|
||||
def get_auxiliary_extra_body() -> dict:
|
||||
"""Return extra_body kwargs for auxiliary API calls.
|
||||
|
||||
|
||||
Includes Nous Portal product tags when the auxiliary client is backed
|
||||
by Nous Portal. Returns empty dict otherwise.
|
||||
"""
|
||||
@@ -591,7 +580,7 @@ def get_auxiliary_extra_body() -> dict:
|
||||
|
||||
def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
"""Return the correct max tokens kwarg for the auxiliary client's provider.
|
||||
|
||||
|
||||
OpenRouter and local models use 'max_tokens'. Direct OpenAI with newer
|
||||
models (gpt-4o, o-series, gpt-5+) requires 'max_completion_tokens'.
|
||||
The Codex adapter translates max_tokens internally, so we use max_tokens
|
||||
@@ -600,6 +589,8 @@ def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
custom_base = os.getenv("OPENAI_BASE_URL", "")
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
# Only use max_completion_tokens for direct OpenAI custom endpoints
|
||||
if not or_key and _read_nous_auth() is None and "api.openai.com" in custom_base.lower():
|
||||
if (not or_key
|
||||
and _read_nous_auth() is None
|
||||
and "api.openai.com" in custom_base.lower()):
|
||||
return {"max_completion_tokens": value}
|
||||
return {"max_tokens": value}
|
||||
|
||||
@@ -7,12 +7,12 @@ protecting head and tail context.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
from agent.model_metadata import (
|
||||
estimate_messages_tokens_rough,
|
||||
get_model_context_length,
|
||||
estimate_messages_tokens_rough,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -56,7 +56,7 @@ class ContextCompressor:
|
||||
self.client, default_model = get_text_auxiliary_client("compression")
|
||||
self.summary_model = summary_model_override or default_model
|
||||
|
||||
def update_from_response(self, usage: dict[str, Any]):
|
||||
def update_from_response(self, usage: Dict[str, Any]):
|
||||
"""Update tracked token usage from API response."""
|
||||
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
self.last_completion_tokens = usage.get("completion_tokens", 0)
|
||||
@@ -67,12 +67,12 @@ class ContextCompressor:
|
||||
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
|
||||
return tokens >= self.threshold_tokens
|
||||
|
||||
def should_compress_preflight(self, messages: list[dict[str, Any]]) -> bool:
|
||||
def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool:
|
||||
"""Quick pre-flight check using rough estimate (before API call)."""
|
||||
rough_estimate = estimate_messages_tokens_rough(messages)
|
||||
return rough_estimate >= self.threshold_tokens
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get current compression status for display/logging."""
|
||||
return {
|
||||
"last_prompt_tokens": self.last_prompt_tokens,
|
||||
@@ -82,7 +82,7 @@ class ContextCompressor:
|
||||
"compression_count": self.compression_count,
|
||||
}
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: list[dict[str, Any]]) -> str | None:
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""Generate a concise summary of conversation turns.
|
||||
|
||||
Tries the auxiliary model first, then falls back to the user's main
|
||||
@@ -140,9 +140,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
logging.warning(f"Main model summary also failed: {fallback_err}")
|
||||
|
||||
# 3. All models failed — return None so the caller drops turns without a summary
|
||||
logging.warning(
|
||||
"Context compression: no model available for summary. Middle turns will be dropped without summary."
|
||||
)
|
||||
logging.warning("Context compression: no model available for summary. Middle turns will be dropped without summary.")
|
||||
return None
|
||||
|
||||
def _call_summary_model(self, client, model: str, prompt: str) -> str:
|
||||
@@ -188,14 +186,12 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
|
||||
# Don't fallback to the same provider that just failed
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
if custom_base.rstrip("/") == OPENROUTER_BASE_URL.rstrip("/"):
|
||||
return None, None
|
||||
|
||||
model = os.getenv("LLM_MODEL") or os.getenv("OPENAI_MODEL") or self.model
|
||||
try:
|
||||
from openai import OpenAI as _OpenAI
|
||||
|
||||
client = _OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
logger.debug("Built fallback auxiliary client: %s via %s", model, custom_base)
|
||||
return client, model
|
||||
@@ -214,7 +210,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
return tc.get("id", "")
|
||||
return getattr(tc, "id", "") or ""
|
||||
|
||||
def _sanitize_tool_pairs(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Fix orphaned tool_call / tool_result pairs after compression.
|
||||
|
||||
Two failure modes:
|
||||
@@ -247,7 +243,8 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
orphaned_results = result_call_ids - surviving_call_ids
|
||||
if orphaned_results:
|
||||
messages = [
|
||||
m for m in messages if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
|
||||
m for m in messages
|
||||
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
|
||||
]
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))
|
||||
@@ -255,27 +252,25 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
# 2. Add stub results for assistant tool_calls whose results were dropped
|
||||
missing_results = surviving_call_ids - result_call_ids
|
||||
if missing_results:
|
||||
patched: list[dict[str, Any]] = []
|
||||
patched: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
patched.append(msg)
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = self._get_tool_call_id(tc)
|
||||
if cid in missing_results:
|
||||
patched.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "[Result from earlier conversation — see context summary above]",
|
||||
"tool_call_id": cid,
|
||||
}
|
||||
)
|
||||
patched.append({
|
||||
"role": "tool",
|
||||
"content": "[Result from earlier conversation — see context summary above]",
|
||||
"tool_call_id": cid,
|
||||
})
|
||||
messages = patched
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))
|
||||
|
||||
return messages
|
||||
|
||||
def _align_boundary_forward(self, messages: list[dict[str, Any]], idx: int) -> int:
|
||||
def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int:
|
||||
"""Push a compress-start boundary forward past any orphan tool results.
|
||||
|
||||
If ``messages[idx]`` is a tool result, slide forward until we hit a
|
||||
@@ -285,7 +280,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
idx += 1
|
||||
return idx
|
||||
|
||||
def _align_boundary_backward(self, messages: list[dict[str, Any]], idx: int) -> int:
|
||||
def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
|
||||
"""Pull a compress-end boundary backward to avoid splitting a
|
||||
tool_call / result group.
|
||||
|
||||
@@ -303,7 +298,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
idx -= 1
|
||||
return idx
|
||||
|
||||
def compress(self, messages: list[dict[str, Any]], current_tokens: int = None) -> list[dict[str, Any]]:
|
||||
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
|
||||
"""Compress conversation messages by summarizing middle turns.
|
||||
|
||||
Keeps first N + last N turns, summarizes everything in between.
|
||||
@@ -313,9 +308,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
n_messages = len(messages)
|
||||
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
|
||||
if not self.quiet_mode:
|
||||
print(
|
||||
f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})"
|
||||
)
|
||||
print(f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})")
|
||||
return messages
|
||||
|
||||
compress_start = self.protect_first_n
|
||||
@@ -330,20 +323,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
display_tokens = (
|
||||
current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
)
|
||||
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(
|
||||
f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)"
|
||||
)
|
||||
print(
|
||||
f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent * 100:.0f}% = {self.threshold_tokens:,})"
|
||||
)
|
||||
print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)")
|
||||
print(f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f" 🗜️ Summarizing turns {compress_start + 1}-{compress_end} ({len(turns_to_summarize)} turns)")
|
||||
print(f" 🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")
|
||||
|
||||
summary = self._generate_summary(turns_to_summarize)
|
||||
|
||||
@@ -351,9 +338,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
for i in range(compress_start):
|
||||
msg = messages[i].copy()
|
||||
if i == 0 and msg.get("role") == "system" and self.compression_count == 0:
|
||||
msg["content"] = (
|
||||
msg.get("content") or ""
|
||||
) + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
|
||||
msg["content"] = (msg.get("content") or "") + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
|
||||
compressed.append(msg)
|
||||
|
||||
if summary:
|
||||
|
||||
282
agent/display.py
282
agent/display.py
@@ -6,6 +6,7 @@ Used by AIAgent._execute_tool_calls for CLI feedback.
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -19,31 +20,19 @@ _RESET = "\033[0m"
|
||||
# Tool preview (one-line summary of a tool call's primary argument)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
"""Build a short preview of a tool call's primary argument for display."""
|
||||
primary_args = {
|
||||
"terminal": "command",
|
||||
"web_search": "query",
|
||||
"web_extract": "urls",
|
||||
"read_file": "path",
|
||||
"write_file": "path",
|
||||
"patch": "path",
|
||||
"search_files": "pattern",
|
||||
"browser_navigate": "url",
|
||||
"browser_click": "ref",
|
||||
"browser_type": "text",
|
||||
"image_generate": "prompt",
|
||||
"text_to_speech": "text",
|
||||
"vision_analyze": "question",
|
||||
"mixture_of_agents": "user_prompt",
|
||||
"skill_view": "name",
|
||||
"skills_list": "category",
|
||||
"terminal": "command", "web_search": "query", "web_extract": "urls",
|
||||
"read_file": "path", "write_file": "path", "patch": "path",
|
||||
"search_files": "pattern", "browser_navigate": "url",
|
||||
"browser_click": "ref", "browser_type": "text",
|
||||
"image_generate": "prompt", "text_to_speech": "text",
|
||||
"vision_analyze": "question", "mixture_of_agents": "user_prompt",
|
||||
"skill_view": "name", "skills_list": "category",
|
||||
"schedule_cronjob": "name",
|
||||
"execute_code": "code",
|
||||
"delegate_task": "goal",
|
||||
"clarify": "question",
|
||||
"skill_manage": "name",
|
||||
"execute_code": "code", "delegate_task": "goal",
|
||||
"clarify": "question", "skill_manage": "name",
|
||||
}
|
||||
|
||||
if tool_name == "process":
|
||||
@@ -72,18 +61,18 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
|
||||
if tool_name == "session_search":
|
||||
query = args.get("query", "")
|
||||
return f'recall: "{query[:25]}{"..." if len(query) > 25 else ""}"'
|
||||
return f"recall: \"{query[:25]}{'...' if len(query) > 25 else ''}\""
|
||||
|
||||
if tool_name == "memory":
|
||||
action = args.get("action", "")
|
||||
target = args.get("target", "")
|
||||
if action == "add":
|
||||
content = args.get("content", "")
|
||||
return f'+{target}: "{content[:25]}{"..." if len(content) > 25 else ""}"'
|
||||
return f"+{target}: \"{content[:25]}{'...' if len(content) > 25 else ''}\""
|
||||
elif action == "replace":
|
||||
return f'~{target}: "{args.get("old_text", "")[:20]}"'
|
||||
return f"~{target}: \"{args.get('old_text', '')[:20]}\""
|
||||
elif action == "remove":
|
||||
return f'-{target}: "{args.get("old_text", "")[:20]}"'
|
||||
return f"-{target}: \"{args.get('old_text', '')[:20]}\""
|
||||
return action
|
||||
|
||||
if tool_name == "send_message":
|
||||
@@ -91,7 +80,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
msg = args.get("message", "")
|
||||
if len(msg) > 20:
|
||||
msg = msg[:17] + "..."
|
||||
return f'to {target}: "{msg}"'
|
||||
return f"to {target}: \"{msg}\""
|
||||
|
||||
if tool_name.startswith("rl_"):
|
||||
rl_previews = {
|
||||
@@ -126,7 +115,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
if not preview:
|
||||
return None
|
||||
if len(preview) > max_len:
|
||||
preview = preview[: max_len - 3] + "..."
|
||||
preview = preview[:max_len - 3] + "..."
|
||||
return preview
|
||||
|
||||
|
||||
@@ -134,74 +123,41 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
# KawaiiSpinner
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class KawaiiSpinner:
|
||||
"""Animated spinner with kawaii faces for CLI feedback during tool execution."""
|
||||
|
||||
SPINNERS = {
|
||||
"dots": ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
|
||||
"bounce": ["⠁", "⠂", "⠄", "⡀", "⢀", "⠠", "⠐", "⠈"],
|
||||
"grow": ["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█", "▇", "▆", "▅", "▄", "▃", "▂"],
|
||||
"arrows": ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"],
|
||||
"star": ["✶", "✷", "✸", "✹", "✺", "✹", "✸", "✷"],
|
||||
"moon": ["🌑", "🌒", "🌓", "🌔", "🌕", "🌖", "🌗", "🌘"],
|
||||
"pulse": ["◜", "◠", "◝", "◞", "◡", "◟"],
|
||||
"brain": ["🧠", "💭", "💡", "✨", "💫", "🌟", "💡", "💭"],
|
||||
"sparkle": ["⁺", "˚", "*", "✧", "✦", "✧", "*", "˚"],
|
||||
'dots': ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'],
|
||||
'bounce': ['⠁', '⠂', '⠄', '⡀', '⢀', '⠠', '⠐', '⠈'],
|
||||
'grow': ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█', '▇', '▆', '▅', '▄', '▃', '▂'],
|
||||
'arrows': ['←', '↖', '↑', '↗', '→', '↘', '↓', '↙'],
|
||||
'star': ['✶', '✷', '✸', '✹', '✺', '✹', '✸', '✷'],
|
||||
'moon': ['🌑', '🌒', '🌓', '🌔', '🌕', '🌖', '🌗', '🌘'],
|
||||
'pulse': ['◜', '◠', '◝', '◞', '◡', '◟'],
|
||||
'brain': ['🧠', '💭', '💡', '✨', '💫', '🌟', '💡', '💭'],
|
||||
'sparkle': ['⁺', '˚', '*', '✧', '✦', '✧', '*', '˚'],
|
||||
}
|
||||
|
||||
KAWAII_WAITING = [
|
||||
"(。◕‿◕。)",
|
||||
"(◕‿◕✿)",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿◠‿◠)",
|
||||
"( ˘▽˘)っ",
|
||||
"♪(´ε` )",
|
||||
"(◕ᴗ◕✿)",
|
||||
"ヾ(^∇^)",
|
||||
"(≧◡≦)",
|
||||
"(★ω★)",
|
||||
"(。◕‿◕。)", "(◕‿◕✿)", "٩(◕‿◕。)۶", "(✿◠‿◠)", "( ˘▽˘)っ",
|
||||
"♪(´ε` )", "(◕ᴗ◕✿)", "ヾ(^∇^)", "(≧◡≦)", "(★ω★)",
|
||||
]
|
||||
|
||||
KAWAII_THINKING = [
|
||||
"(。•́︿•̀。)",
|
||||
"(◔_◔)",
|
||||
"(¬‿¬)",
|
||||
"( •_•)>⌐■-■",
|
||||
"(⌐■_■)",
|
||||
"(´・_・`)",
|
||||
"◉_◉",
|
||||
"(°ロ°)",
|
||||
"( ˘⌣˘)♡",
|
||||
"ヽ(>∀<☆)☆",
|
||||
"٩(๑❛ᴗ❛๑)۶",
|
||||
"(⊙_⊙)",
|
||||
"(¬_¬)",
|
||||
"( ͡° ͜ʖ ͡°)",
|
||||
"ಠ_ಠ",
|
||||
"(。•́︿•̀。)", "(◔_◔)", "(¬‿¬)", "( •_•)>⌐■-■", "(⌐■_■)",
|
||||
"(´・_・`)", "◉_◉", "(°ロ°)", "( ˘⌣˘)♡", "ヽ(>∀<☆)☆",
|
||||
"٩(๑❛ᴗ❛๑)۶", "(⊙_⊙)", "(¬_¬)", "( ͡° ͜ʖ ͡°)", "ಠ_ಠ",
|
||||
]
|
||||
|
||||
THINKING_VERBS = [
|
||||
"pondering",
|
||||
"contemplating",
|
||||
"musing",
|
||||
"cogitating",
|
||||
"ruminating",
|
||||
"deliberating",
|
||||
"mulling",
|
||||
"reflecting",
|
||||
"processing",
|
||||
"reasoning",
|
||||
"analyzing",
|
||||
"computing",
|
||||
"synthesizing",
|
||||
"formulating",
|
||||
"brainstorming",
|
||||
"pondering", "contemplating", "musing", "cogitating", "ruminating",
|
||||
"deliberating", "mulling", "reflecting", "processing", "reasoning",
|
||||
"analyzing", "computing", "synthesizing", "formulating", "brainstorming",
|
||||
]
|
||||
|
||||
def __init__(self, message: str = "", spinner_type: str = "dots"):
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots'):
|
||||
self.message = message
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS["dots"])
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
|
||||
self.running = False
|
||||
self.thread = None
|
||||
self.frame_idx = 0
|
||||
@@ -211,7 +167,7 @@ class KawaiiSpinner:
|
||||
# child agents can replace sys.stdout with a black hole.
|
||||
self._out = sys.stdout
|
||||
|
||||
def _write(self, text: str, end: str = "\n", flush: bool = False):
|
||||
def _write(self, text: str, end: str = '\n', flush: bool = False):
|
||||
"""Write to the stdout captured at spinner creation time."""
|
||||
try:
|
||||
self._out.write(text + end)
|
||||
@@ -229,7 +185,7 @@ class KawaiiSpinner:
|
||||
elapsed = time.time() - self.start_time
|
||||
line = f" {frame} {self.message} ({elapsed:.1f}s)"
|
||||
pad = max(self.last_line_len - len(line), 0)
|
||||
self._write(f"\r{line}{' ' * pad}", end="", flush=True)
|
||||
self._write(f"\r{line}{' ' * pad}", end='', flush=True)
|
||||
self.last_line_len = len(line)
|
||||
self.frame_idx += 1
|
||||
time.sleep(0.12)
|
||||
@@ -260,7 +216,7 @@ class KawaiiSpinner:
|
||||
# Clear spinner line with spaces (not \033[K) to avoid garbled escape
|
||||
# codes when prompt_toolkit's patch_stdout is active — same approach
|
||||
# as stop(). Then print text; spinner redraws on next tick.
|
||||
blanks = " " * max(self.last_line_len + 5, 40)
|
||||
blanks = ' ' * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r {text}", flush=True)
|
||||
|
||||
def stop(self, final_message: str = None):
|
||||
@@ -269,8 +225,8 @@ class KawaiiSpinner:
|
||||
self.thread.join(timeout=0.5)
|
||||
# Clear the spinner line with spaces instead of \033[K to avoid
|
||||
# garbled escape codes when prompt_toolkit's patch_stdout is active.
|
||||
blanks = " " * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r", end="", flush=True)
|
||||
blanks = ' ' * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r", end='', flush=True)
|
||||
if final_message:
|
||||
self._write(f" {final_message}", flush=True)
|
||||
|
||||
@@ -288,110 +244,38 @@ class KawaiiSpinner:
|
||||
# =========================================================================
|
||||
|
||||
KAWAII_SEARCH = [
|
||||
"♪(´ε` )",
|
||||
"(。◕‿◕。)",
|
||||
"ヾ(^∇^)",
|
||||
"(◕ᴗ◕✿)",
|
||||
"( ˘▽˘)っ",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿◠‿◠)",
|
||||
"♪~(´ε` )",
|
||||
"(ノ´ヮ`)ノ*:・゚✧",
|
||||
"\(◎o◎)/",
|
||||
"♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ",
|
||||
"٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)ノ*:・゚✧", "\(◎o◎)/",
|
||||
]
|
||||
KAWAII_READ = [
|
||||
"φ(゜▽゜*)♪",
|
||||
"( ˘▽˘)っ",
|
||||
"(⌐■_■)",
|
||||
"٩(。•́‿•̀。)۶",
|
||||
"(◕‿◕✿)",
|
||||
"ヾ(@⌒ー⌒@)ノ",
|
||||
"(✧ω✧)",
|
||||
"♪(๑ᴖ◡ᴖ๑)♪",
|
||||
"(≧◡≦)",
|
||||
"( ´ ▽ ` )ノ",
|
||||
"φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)",
|
||||
"ヾ(@⌒ー⌒@)ノ", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )ノ",
|
||||
]
|
||||
KAWAII_TERMINAL = [
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"(ノ°∀°)ノ",
|
||||
"٩(^ᴗ^)۶",
|
||||
"ヾ(⌐■_■)ノ♪",
|
||||
"(•̀ᴗ•́)و",
|
||||
"┗(^0^)┓",
|
||||
"(`・ω・´)",
|
||||
"\( ̄▽ ̄)/",
|
||||
"(ง •̀_•́)ง",
|
||||
"ヽ(´▽`)/",
|
||||
"ヽ(>∀<☆)ノ", "(ノ°∀°)ノ", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و",
|
||||
"┗(^0^)┓", "(`・ω・´)", "\( ̄▽ ̄)/", "(ง •̀_•́)ง", "ヽ(´▽`)/",
|
||||
]
|
||||
KAWAII_BROWSER = [
|
||||
"(ノ°∀°)ノ",
|
||||
"(☞゚ヮ゚)☞",
|
||||
"( ͡° ͜ʖ ͡°)",
|
||||
"┌( ಠ_ಠ)┘",
|
||||
"(⊙_⊙)?",
|
||||
"ヾ(•ω•`)o",
|
||||
"( ̄ω ̄)",
|
||||
"( ˇωˇ )",
|
||||
"(ᵔᴥᵔ)",
|
||||
"\(◎o◎)/",
|
||||
"(ノ°∀°)ノ", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)?",
|
||||
"ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "\(◎o◎)/",
|
||||
]
|
||||
KAWAII_CREATE = [
|
||||
"✧*。٩(ˊᗜˋ*)و✧",
|
||||
"(ノ◕ヮ◕)ノ*:・゚✧",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"٩(♡ε♡)۶",
|
||||
"(◕‿◕)♡",
|
||||
"✿◕ ‿ ◕✿",
|
||||
"(*≧▽≦)",
|
||||
"ヾ(^-^)ノ",
|
||||
"(☆▽☆)",
|
||||
"°˖✧◝(⁰▿⁰)◜✧˖°",
|
||||
"✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "٩(♡ε♡)۶", "(◕‿◕)♡",
|
||||
"✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(^-^)ノ", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°",
|
||||
]
|
||||
KAWAII_SKILL = [
|
||||
"ヾ(@⌒ー⌒@)ノ",
|
||||
"(๑˃ᴗ˂)ﻭ",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿╹◡╹)",
|
||||
"ヽ(・∀・)ノ",
|
||||
"(ノ´ヮ`)ノ*:・゚✧",
|
||||
"♪(๑ᴖ◡ᴖ๑)♪",
|
||||
"(◠‿◠)",
|
||||
"٩(ˊᗜˋ*)و",
|
||||
"(^▽^)",
|
||||
"ヾ(^∇^)",
|
||||
"(★ω★)/",
|
||||
"٩(。•́‿•̀。)۶",
|
||||
"(◕ᴗ◕✿)",
|
||||
"\(◎o◎)/",
|
||||
"(✧ω✧)",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"( ˘▽˘)っ",
|
||||
"(≧◡≦) ♡",
|
||||
"ヾ( ̄▽ ̄)",
|
||||
"ヾ(@⌒ー⌒@)ノ", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)ノ",
|
||||
"(ノ´ヮ`)ノ*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)",
|
||||
"ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "\(◎o◎)/",
|
||||
"(✧ω✧)", "ヽ(>∀<☆)ノ", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)",
|
||||
]
|
||||
KAWAII_THINK = [
|
||||
"(っ°Д°;)っ",
|
||||
"(;′⌒`)",
|
||||
"(・_・ヾ",
|
||||
"( ´_ゝ`)",
|
||||
"( ̄ヘ ̄)",
|
||||
"(。-`ω´-)",
|
||||
"( ˘︹˘ )",
|
||||
"(¬_¬)",
|
||||
"ヽ(ー_ー )ノ",
|
||||
"(;一_一)",
|
||||
"(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)",
|
||||
"(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )ノ", "(;一_一)",
|
||||
]
|
||||
KAWAII_GENERIC = [
|
||||
"♪(´ε` )",
|
||||
"(◕‿◕✿)",
|
||||
"ヾ(^∇^)",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿◠‿◠)",
|
||||
"(ノ´ヮ`)ノ*:・゚✧",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"(☆▽☆)",
|
||||
"( ˘▽˘)っ",
|
||||
"(≧◡≦)",
|
||||
"♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)",
|
||||
"(ノ´ヮ`)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)",
|
||||
]
|
||||
|
||||
|
||||
@@ -399,7 +283,6 @@ KAWAII_GENERIC = [
|
||||
# Cute tool message (completion line that replaces the spinner)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]:
|
||||
"""Inspect a tool result string for signs of failure.
|
||||
|
||||
@@ -438,10 +321,7 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
|
||||
|
||||
|
||||
def get_cute_tool_message(
|
||||
tool_name: str,
|
||||
args: dict,
|
||||
duration: float,
|
||||
result: str | None = None,
|
||||
tool_name: str, args: dict, duration: float, result: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a formatted tool completion line for CLI quiet mode.
|
||||
|
||||
@@ -455,11 +335,11 @@ def get_cute_tool_message(
|
||||
|
||||
def _trunc(s, n=40):
|
||||
s = str(s)
|
||||
return (s[: n - 3] + "...") if len(s) > n else s
|
||||
return (s[:n-3] + "...") if len(s) > n else s
|
||||
|
||||
def _path(p, n=35):
|
||||
p = str(p)
|
||||
return ("..." + p[-(n - 3) :]) if len(p) > n else p
|
||||
return ("..." + p[-(n-3):]) if len(p) > n else p
|
||||
|
||||
def _wrap(line: str) -> str:
|
||||
"""Append failure suffix when the tool failed."""
|
||||
@@ -474,7 +354,7 @@ def get_cute_tool_message(
|
||||
if urls:
|
||||
url = urls[0] if isinstance(urls, list) else str(urls)
|
||||
domain = url.replace("https://", "").replace("http://", "").split("/")[0]
|
||||
extra = f" +{len(urls) - 1}" if len(urls) > 1 else ""
|
||||
extra = f" +{len(urls)-1}" if len(urls) > 1 else ""
|
||||
return _wrap(f"┊ 📄 fetch {_trunc(domain, 35)}{extra} {dur}")
|
||||
return _wrap(f"┊ 📄 fetch pages {dur}")
|
||||
if tool_name == "web_crawl":
|
||||
@@ -486,15 +366,8 @@ def get_cute_tool_message(
|
||||
if tool_name == "process":
|
||||
action = args.get("action", "?")
|
||||
sid = args.get("session_id", "")[:12]
|
||||
labels = {
|
||||
"list": "ls processes",
|
||||
"poll": f"poll {sid}",
|
||||
"log": f"log {sid}",
|
||||
"wait": f"wait {sid}",
|
||||
"kill": f"kill {sid}",
|
||||
"write": f"write {sid}",
|
||||
"submit": f"submit {sid}",
|
||||
}
|
||||
labels = {"list": "ls processes", "poll": f"poll {sid}", "log": f"log {sid}",
|
||||
"wait": f"wait {sid}", "kill": f"kill {sid}", "write": f"write {sid}", "submit": f"submit {sid}"}
|
||||
return _wrap(f"┊ ⚙️ proc {labels.get(action, f'{action} {sid}')} {dur}")
|
||||
if tool_name == "read_file":
|
||||
return _wrap(f"┊ 📖 read {_path(args.get('path', ''))} {dur}")
|
||||
@@ -517,7 +390,7 @@ def get_cute_tool_message(
|
||||
if tool_name == "browser_click":
|
||||
return _wrap(f"┊ 👆 click {args.get('ref', '?')} {dur}")
|
||||
if tool_name == "browser_type":
|
||||
return _wrap(f'┊ ⌨️ type "{_trunc(args.get("text", ""), 30)}" {dur}')
|
||||
return _wrap(f"┊ ⌨️ type \"{_trunc(args.get('text', ''), 30)}\" {dur}")
|
||||
if tool_name == "browser_scroll":
|
||||
d = args.get("direction", "down")
|
||||
arrow = {"down": "↓", "up": "↑", "right": "→", "left": "←"}.get(d, "↓")
|
||||
@@ -542,16 +415,16 @@ def get_cute_tool_message(
|
||||
else:
|
||||
return _wrap(f"┊ 📋 plan {len(todos_arg)} task(s) {dur}")
|
||||
if tool_name == "session_search":
|
||||
return _wrap(f'┊ 🔍 recall "{_trunc(args.get("query", ""), 35)}" {dur}')
|
||||
return _wrap(f"┊ 🔍 recall \"{_trunc(args.get('query', ''), 35)}\" {dur}")
|
||||
if tool_name == "memory":
|
||||
action = args.get("action", "?")
|
||||
target = args.get("target", "")
|
||||
if action == "add":
|
||||
return _wrap(f'┊ 🧠 memory +{target}: "{_trunc(args.get("content", ""), 30)}" {dur}')
|
||||
return _wrap(f"┊ 🧠 memory +{target}: \"{_trunc(args.get('content', ''), 30)}\" {dur}")
|
||||
elif action == "replace":
|
||||
return _wrap(f'┊ 🧠 memory ~{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}')
|
||||
return _wrap(f"┊ 🧠 memory ~{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}")
|
||||
elif action == "remove":
|
||||
return _wrap(f'┊ 🧠 memory -{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}')
|
||||
return _wrap(f"┊ 🧠 memory -{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}")
|
||||
return _wrap(f"┊ 🧠 memory {action} {dur}")
|
||||
if tool_name == "skills_list":
|
||||
return _wrap(f"┊ 📚 skills list {args.get('category', 'all')} {dur}")
|
||||
@@ -566,7 +439,7 @@ def get_cute_tool_message(
|
||||
if tool_name == "mixture_of_agents":
|
||||
return _wrap(f"┊ 🧠 reason {_trunc(args.get('user_prompt', ''), 30)} {dur}")
|
||||
if tool_name == "send_message":
|
||||
return _wrap(f'┊ 📨 send {args.get("target", "?")}: "{_trunc(args.get("message", ""), 25)}" {dur}')
|
||||
return _wrap(f"┊ 📨 send {args.get('target', '?')}: \"{_trunc(args.get('message', ''), 25)}\" {dur}")
|
||||
if tool_name == "schedule_cronjob":
|
||||
return _wrap(f"┊ ⏰ schedule {_trunc(args.get('name', args.get('prompt', 'task')), 30)} {dur}")
|
||||
if tool_name == "list_cronjobs":
|
||||
@@ -575,16 +448,11 @@ def get_cute_tool_message(
|
||||
return _wrap(f"┊ ⏰ remove job {args.get('job_id', '?')} {dur}")
|
||||
if tool_name.startswith("rl_"):
|
||||
rl = {
|
||||
"rl_list_environments": "list envs",
|
||||
"rl_select_environment": f"select {args.get('name', '')}",
|
||||
"rl_get_current_config": "get config",
|
||||
"rl_edit_config": f"set {args.get('field', '?')}",
|
||||
"rl_start_training": "start training",
|
||||
"rl_check_status": f"status {args.get('run_id', '?')[:12]}",
|
||||
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}",
|
||||
"rl_get_results": f"results {args.get('run_id', '?')[:12]}",
|
||||
"rl_list_runs": "list runs",
|
||||
"rl_test_inference": "test inference",
|
||||
"rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}",
|
||||
"rl_get_current_config": "get config", "rl_edit_config": f"set {args.get('field', '?')}",
|
||||
"rl_start_training": "start training", "rl_check_status": f"status {args.get('run_id', '?')[:12]}",
|
||||
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}", "rl_get_results": f"results {args.get('run_id', '?')[:12]}",
|
||||
"rl_list_runs": "list runs", "rl_test_inference": "test inference",
|
||||
}
|
||||
return _wrap(f"┊ 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}")
|
||||
if tool_name == "execute_code":
|
||||
|
||||
@@ -20,7 +20,7 @@ import json
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# =========================================================================
|
||||
# Model pricing (USD per million tokens) — approximate as of early 2026
|
||||
@@ -81,7 +81,7 @@ def _has_known_pricing(model_name: str) -> bool:
|
||||
return _get_pricing(model_name) is not _DEFAULT_PRICING
|
||||
|
||||
|
||||
def _get_pricing(model_name: str) -> dict[str, float]:
|
||||
def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
"""Look up pricing for a model. Uses fuzzy matching on model name.
|
||||
|
||||
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
|
||||
@@ -150,7 +150,7 @@ def _format_duration(seconds: float) -> str:
|
||||
return f"{days:.1f}d"
|
||||
|
||||
|
||||
def _bar_chart(values: list[int], max_width: int = 20) -> list[str]:
|
||||
def _bar_chart(values: List[int], max_width: int = 20) -> List[str]:
|
||||
"""Create simple horizontal bar chart strings from values."""
|
||||
peak = max(values) if values else 1
|
||||
if peak == 0:
|
||||
@@ -176,7 +176,7 @@ class InsightsEngine:
|
||||
self.db = db
|
||||
self._conn = db._conn
|
||||
|
||||
def generate(self, days: int = 30, source: str = None) -> dict[str, Any]:
|
||||
def generate(self, days: int = 30, source: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a complete insights report.
|
||||
|
||||
@@ -233,11 +233,10 @@ class InsightsEngine:
|
||||
# =========================================================================
|
||||
|
||||
# Columns we actually need (skip system_prompt, model_config blobs)
|
||||
_SESSION_COLS = (
|
||||
"id, source, model, started_at, ended_at, message_count, tool_call_count, input_tokens, output_tokens"
|
||||
)
|
||||
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
|
||||
"message_count, tool_call_count, input_tokens, output_tokens")
|
||||
|
||||
def _get_sessions(self, cutoff: float, source: str = None) -> list[dict]:
|
||||
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
|
||||
"""Fetch sessions within the time window."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
@@ -255,7 +254,7 @@ class InsightsEngine:
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def _get_tool_usage(self, cutoff: float, source: str = None) -> list[dict]:
|
||||
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
|
||||
"""Get tool call counts from messages.
|
||||
|
||||
Uses two sources:
|
||||
@@ -342,9 +341,12 @@ class InsightsEngine:
|
||||
tool_counts = merged
|
||||
|
||||
# Convert to the expected format
|
||||
return [{"tool_name": name, "count": count} for name, count in tool_counts.most_common()]
|
||||
return [
|
||||
{"tool_name": name, "count": count}
|
||||
for name, count in tool_counts.most_common()
|
||||
]
|
||||
|
||||
def _get_message_stats(self, cutoff: float, source: str = None) -> dict:
|
||||
def _get_message_stats(self, cutoff: float, source: str = None) -> Dict:
|
||||
"""Get aggregate message statistics."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
@@ -371,22 +373,16 @@ class InsightsEngine:
|
||||
(cutoff,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return (
|
||||
dict(row)
|
||||
if row
|
||||
else {
|
||||
"total_messages": 0,
|
||||
"user_messages": 0,
|
||||
"assistant_messages": 0,
|
||||
"tool_messages": 0,
|
||||
}
|
||||
)
|
||||
return dict(row) if row else {
|
||||
"total_messages": 0, "user_messages": 0,
|
||||
"assistant_messages": 0, "tool_messages": 0,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Computation
|
||||
# =========================================================================
|
||||
|
||||
def _compute_overview(self, sessions: list[dict], message_stats: dict) -> dict:
|
||||
def _compute_overview(self, sessions: List[Dict], message_stats: Dict) -> Dict:
|
||||
"""Compute high-level overview statistics."""
|
||||
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
|
||||
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
|
||||
@@ -446,18 +442,12 @@ class InsightsEngine:
|
||||
"models_without_pricing": sorted(models_without_pricing),
|
||||
}
|
||||
|
||||
def _compute_model_breakdown(self, sessions: list[dict]) -> list[dict]:
|
||||
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
|
||||
"""Break down usage by model."""
|
||||
model_data = defaultdict(
|
||||
lambda: {
|
||||
"sessions": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"tool_calls": 0,
|
||||
"cost": 0.0,
|
||||
}
|
||||
)
|
||||
model_data = defaultdict(lambda: {
|
||||
"sessions": 0, "input_tokens": 0, "output_tokens": 0,
|
||||
"total_tokens": 0, "tool_calls": 0, "cost": 0.0,
|
||||
})
|
||||
|
||||
for s in sessions:
|
||||
model = s.get("model") or "unknown"
|
||||
@@ -474,23 +464,20 @@ class InsightsEngine:
|
||||
d["cost"] += _estimate_cost(model, inp, out)
|
||||
d["has_pricing"] = _has_known_pricing(model)
|
||||
|
||||
result = [{"model": model, **data} for model, data in model_data.items()]
|
||||
result = [
|
||||
{"model": model, **data}
|
||||
for model, data in model_data.items()
|
||||
]
|
||||
# Sort by tokens first, fall back to session count when tokens are 0
|
||||
result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
|
||||
return result
|
||||
|
||||
def _compute_platform_breakdown(self, sessions: list[dict]) -> list[dict]:
|
||||
def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]:
|
||||
"""Break down usage by platform/source."""
|
||||
platform_data = defaultdict(
|
||||
lambda: {
|
||||
"sessions": 0,
|
||||
"messages": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"tool_calls": 0,
|
||||
}
|
||||
)
|
||||
platform_data = defaultdict(lambda: {
|
||||
"sessions": 0, "messages": 0, "input_tokens": 0,
|
||||
"output_tokens": 0, "total_tokens": 0, "tool_calls": 0,
|
||||
})
|
||||
|
||||
for s in sessions:
|
||||
source = s.get("source") or "unknown"
|
||||
@@ -504,26 +491,27 @@ class InsightsEngine:
|
||||
d["total_tokens"] += inp + out
|
||||
d["tool_calls"] += s.get("tool_call_count") or 0
|
||||
|
||||
result = [{"platform": platform, **data} for platform, data in platform_data.items()]
|
||||
result = [
|
||||
{"platform": platform, **data}
|
||||
for platform, data in platform_data.items()
|
||||
]
|
||||
result.sort(key=lambda x: x["sessions"], reverse=True)
|
||||
return result
|
||||
|
||||
def _compute_tool_breakdown(self, tool_usage: list[dict]) -> list[dict]:
|
||||
def _compute_tool_breakdown(self, tool_usage: List[Dict]) -> List[Dict]:
|
||||
"""Process tool usage data into a ranked list with percentages."""
|
||||
total_calls = sum(t["count"] for t in tool_usage) if tool_usage else 0
|
||||
result = []
|
||||
for t in tool_usage:
|
||||
pct = (t["count"] / total_calls * 100) if total_calls else 0
|
||||
result.append(
|
||||
{
|
||||
"tool": t["tool_name"],
|
||||
"count": t["count"],
|
||||
"percentage": pct,
|
||||
}
|
||||
)
|
||||
result.append({
|
||||
"tool": t["tool_name"],
|
||||
"count": t["count"],
|
||||
"percentage": pct,
|
||||
})
|
||||
return result
|
||||
|
||||
def _compute_activity_patterns(self, sessions: list[dict]) -> dict:
|
||||
def _compute_activity_patterns(self, sessions: List[Dict]) -> Dict:
|
||||
"""Analyze activity patterns by day of week and hour."""
|
||||
day_counts = Counter() # 0=Monday ... 6=Sunday
|
||||
hour_counts = Counter()
|
||||
@@ -539,9 +527,15 @@ class InsightsEngine:
|
||||
daily_counts[dt.strftime("%Y-%m-%d")] += 1
|
||||
|
||||
day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
|
||||
day_breakdown = [{"day": day_names[i], "count": day_counts.get(i, 0)} for i in range(7)]
|
||||
day_breakdown = [
|
||||
{"day": day_names[i], "count": day_counts.get(i, 0)}
|
||||
for i in range(7)
|
||||
]
|
||||
|
||||
hour_breakdown = [{"hour": i, "count": hour_counts.get(i, 0)} for i in range(24)]
|
||||
hour_breakdown = [
|
||||
{"hour": i, "count": hour_counts.get(i, 0)}
|
||||
for i in range(24)
|
||||
]
|
||||
|
||||
# Busiest day and hour
|
||||
busiest_day = max(day_breakdown, key=lambda x: x["count"]) if day_breakdown else None
|
||||
@@ -575,40 +569,37 @@ class InsightsEngine:
|
||||
"max_streak": max_streak,
|
||||
}
|
||||
|
||||
def _compute_top_sessions(self, sessions: list[dict]) -> list[dict]:
|
||||
def _compute_top_sessions(self, sessions: List[Dict]) -> List[Dict]:
|
||||
"""Find notable sessions (longest, most messages, most tokens)."""
|
||||
top = []
|
||||
|
||||
# Longest by duration
|
||||
sessions_with_duration = [s for s in sessions if s.get("started_at") and s.get("ended_at")]
|
||||
sessions_with_duration = [
|
||||
s for s in sessions
|
||||
if s.get("started_at") and s.get("ended_at")
|
||||
]
|
||||
if sessions_with_duration:
|
||||
longest = max(
|
||||
sessions_with_duration,
|
||||
key=lambda s: s["ended_at"] - s["started_at"],
|
||||
key=lambda s: (s["ended_at"] - s["started_at"]),
|
||||
)
|
||||
dur = longest["ended_at"] - longest["started_at"]
|
||||
top.append(
|
||||
{
|
||||
"label": "Longest session",
|
||||
"session_id": longest["id"][:16],
|
||||
"value": _format_duration(dur),
|
||||
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
|
||||
}
|
||||
)
|
||||
top.append({
|
||||
"label": "Longest session",
|
||||
"session_id": longest["id"][:16],
|
||||
"value": _format_duration(dur),
|
||||
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
|
||||
})
|
||||
|
||||
# Most messages
|
||||
most_msgs = max(sessions, key=lambda s: s.get("message_count") or 0)
|
||||
if (most_msgs.get("message_count") or 0) > 0:
|
||||
top.append(
|
||||
{
|
||||
"label": "Most messages",
|
||||
"session_id": most_msgs["id"][:16],
|
||||
"value": f"{most_msgs['message_count']} msgs",
|
||||
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d")
|
||||
if most_msgs.get("started_at")
|
||||
else "?",
|
||||
}
|
||||
)
|
||||
top.append({
|
||||
"label": "Most messages",
|
||||
"session_id": most_msgs["id"][:16],
|
||||
"value": f"{most_msgs['message_count']} msgs",
|
||||
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d") if most_msgs.get("started_at") else "?",
|
||||
})
|
||||
|
||||
# Most tokens
|
||||
most_tokens = max(
|
||||
@@ -617,30 +608,22 @@ class InsightsEngine:
|
||||
)
|
||||
token_total = (most_tokens.get("input_tokens") or 0) + (most_tokens.get("output_tokens") or 0)
|
||||
if token_total > 0:
|
||||
top.append(
|
||||
{
|
||||
"label": "Most tokens",
|
||||
"session_id": most_tokens["id"][:16],
|
||||
"value": f"{token_total:,} tokens",
|
||||
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d")
|
||||
if most_tokens.get("started_at")
|
||||
else "?",
|
||||
}
|
||||
)
|
||||
top.append({
|
||||
"label": "Most tokens",
|
||||
"session_id": most_tokens["id"][:16],
|
||||
"value": f"{token_total:,} tokens",
|
||||
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d") if most_tokens.get("started_at") else "?",
|
||||
})
|
||||
|
||||
# Most tool calls
|
||||
most_tools = max(sessions, key=lambda s: s.get("tool_call_count") or 0)
|
||||
if (most_tools.get("tool_call_count") or 0) > 0:
|
||||
top.append(
|
||||
{
|
||||
"label": "Most tool calls",
|
||||
"session_id": most_tools["id"][:16],
|
||||
"value": f"{most_tools['tool_call_count']} calls",
|
||||
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d")
|
||||
if most_tools.get("started_at")
|
||||
else "?",
|
||||
}
|
||||
)
|
||||
top.append({
|
||||
"label": "Most tool calls",
|
||||
"session_id": most_tools["id"][:16],
|
||||
"value": f"{most_tools['tool_call_count']} calls",
|
||||
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d") if most_tools.get("started_at") else "?",
|
||||
})
|
||||
|
||||
return top
|
||||
|
||||
@@ -648,7 +631,7 @@ class InsightsEngine:
|
||||
# Formatting
|
||||
# =========================================================================
|
||||
|
||||
def format_terminal(self, report: dict) -> str:
|
||||
def format_terminal(self, report: Dict) -> str:
|
||||
"""Format the insights report for terminal display (CLI)."""
|
||||
if report.get("empty"):
|
||||
days = report.get("days", 30)
|
||||
@@ -686,17 +669,13 @@ class InsightsEngine:
|
||||
lines.append(" " + "─" * 56)
|
||||
lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}")
|
||||
lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}")
|
||||
lines.append(
|
||||
f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}"
|
||||
)
|
||||
lines.append(f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}")
|
||||
cost_str = f"${o['estimated_cost']:.2f}"
|
||||
if o.get("models_without_pricing"):
|
||||
cost_str += " *"
|
||||
lines.append(f" Total tokens: {o['total_tokens']:<12,} Est. cost: {cost_str}")
|
||||
if o["total_hours"] > 0:
|
||||
lines.append(
|
||||
f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}"
|
||||
)
|
||||
lines.append(f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}")
|
||||
lines.append(f" Avg msgs/session: {o['avg_messages_per_session']:.1f}")
|
||||
lines.append("")
|
||||
|
||||
@@ -713,7 +692,7 @@ class InsightsEngine:
|
||||
cost_cell = " N/A"
|
||||
lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}")
|
||||
if o.get("models_without_pricing"):
|
||||
lines.append(" * Cost N/A for custom/self-hosted models")
|
||||
lines.append(f" * Cost N/A for custom/self-hosted models")
|
||||
lines.append("")
|
||||
|
||||
# Platform breakdown
|
||||
@@ -779,7 +758,7 @@ class InsightsEngine:
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def format_gateway(self, report: dict) -> str:
|
||||
def format_gateway(self, report: Dict) -> str:
|
||||
"""Format the insights report for gateway/messaging (shorter)."""
|
||||
if report.get("empty"):
|
||||
days = report.get("days", 30)
|
||||
@@ -792,20 +771,14 @@ class InsightsEngine:
|
||||
lines.append(f"📊 **Hermes Insights** — Last {days} days\n")
|
||||
|
||||
# Overview
|
||||
lines.append(
|
||||
f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}"
|
||||
)
|
||||
lines.append(
|
||||
f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})"
|
||||
)
|
||||
lines.append(f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}")
|
||||
lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})")
|
||||
cost_note = ""
|
||||
if o.get("models_without_pricing"):
|
||||
cost_note = " _(excludes custom/self-hosted models)_"
|
||||
lines.append(f"**Est. cost:** ${o['estimated_cost']:.2f}{cost_note}")
|
||||
if o["total_hours"] > 0:
|
||||
lines.append(
|
||||
f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}"
|
||||
)
|
||||
lines.append(f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}")
|
||||
lines.append("")
|
||||
|
||||
# Models (top 5)
|
||||
@@ -813,9 +786,7 @@ class InsightsEngine:
|
||||
lines.append("**🤖 Models:**")
|
||||
for m in report["models"][:5]:
|
||||
cost_str = f"${m['cost']:.2f}" if m.get("has_pricing") else "N/A"
|
||||
lines.append(
|
||||
f" {m['model'][:25]} — {m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}"
|
||||
)
|
||||
lines.append(f" {m['model'][:25]} — {m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}")
|
||||
lines.append("")
|
||||
|
||||
# Platforms (if multi-platform)
|
||||
@@ -838,13 +809,9 @@ class InsightsEngine:
|
||||
hr = act["busiest_hour"]["hour"]
|
||||
ampm = "AM" if hr < 12 else "PM"
|
||||
display_hr = hr % 12 or 12
|
||||
lines.append(
|
||||
f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)"
|
||||
)
|
||||
lines.append(f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)")
|
||||
if act.get("active_days"):
|
||||
lines.append(
|
||||
f"**Active days:** {act['active_days']}",
|
||||
)
|
||||
lines.append(f"**Active days:** {act['active_days']}", )
|
||||
if act.get("max_streak", 0) > 1:
|
||||
lines.append(f"**Best streak:** {act['max_streak']} consecutive days")
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
@@ -18,7 +18,7 @@ from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_model_metadata_cache: dict[str, dict[str, Any]] = {}
|
||||
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_model_metadata_cache_time: float = 0
|
||||
_MODEL_CACHE_TTL = 3600
|
||||
|
||||
@@ -63,7 +63,7 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
}
|
||||
|
||||
|
||||
def fetch_model_metadata(force_refresh: bool = False) -> dict[str, dict[str, Any]]:
|
||||
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
|
||||
"""Fetch model metadata from OpenRouter (cached for 1 hour)."""
|
||||
global _model_metadata_cache, _model_metadata_cache_time
|
||||
|
||||
@@ -104,7 +104,7 @@ def _get_context_cache_path() -> Path:
|
||||
return hermes_home / "context_length_cache.yaml"
|
||||
|
||||
|
||||
def _load_context_cache() -> dict[str, int]:
|
||||
def _load_context_cache() -> Dict[str, int]:
|
||||
"""Load the model+provider → context_length cache from disk."""
|
||||
path = _get_context_cache_path()
|
||||
if not path.exists():
|
||||
@@ -139,14 +139,14 @@ def save_context_length(model: str, base_url: str, length: int) -> None:
|
||||
logger.debug("Failed to save context length cache: %s", e)
|
||||
|
||||
|
||||
def get_cached_context_length(model: str, base_url: str) -> int | None:
|
||||
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
|
||||
"""Look up a previously discovered context length for model+provider."""
|
||||
key = f"{model}@{base_url}"
|
||||
cache = _load_context_cache()
|
||||
return cache.get(key)
|
||||
|
||||
|
||||
def get_next_probe_tier(current_length: int) -> int | None:
|
||||
def get_next_probe_tier(current_length: int) -> Optional[int]:
|
||||
"""Return the next lower probe tier, or None if already at minimum."""
|
||||
for tier in CONTEXT_PROBE_TIERS:
|
||||
if tier < current_length:
|
||||
@@ -154,7 +154,7 @@ def get_next_probe_tier(current_length: int) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def parse_context_limit_from_error(error_msg: str) -> int | None:
|
||||
def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
|
||||
"""Try to extract the actual context limit from an API error message.
|
||||
|
||||
Many providers include the limit in their error text, e.g.:
|
||||
@@ -166,11 +166,11 @@ def parse_context_limit_from_error(error_msg: str) -> int | None:
|
||||
error_lower = error_msg.lower()
|
||||
# Pattern: look for numbers near context-related keywords
|
||||
patterns = [
|
||||
r"(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})",
|
||||
r"context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})",
|
||||
r"(\d{4,})\s*(?:token)?\s*(?:context|limit)",
|
||||
r">\s*(\d{4,})\s*(?:max|limit|token)", # "250000 tokens > 200000 maximum"
|
||||
r"(\d{4,})\s*(?:max(?:imum)?)\b", # "200000 maximum"
|
||||
r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
|
||||
r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
|
||||
r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
|
||||
r'>\s*(\d{4,})\s*(?:max|limit|token)', # "250000 tokens > 200000 maximum"
|
||||
r'(\d{4,})\s*(?:max(?:imum)?)\b', # "200000 maximum"
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, error_lower)
|
||||
@@ -218,7 +218,7 @@ def estimate_tokens_rough(text: str) -> int:
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
def estimate_messages_tokens_rough(messages: list[dict[str, Any]]) -> int:
|
||||
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
|
||||
"""Rough token estimate for a message list (pre-flight only)."""
|
||||
total_chars = sum(len(str(msg)) for msg in messages)
|
||||
return total_chars // 4
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,29 +18,21 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CONTEXT_THREAT_PATTERNS = [
|
||||
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
|
||||
(r"<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->", "html_comment_injection"),
|
||||
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
|
||||
(r'<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->', "html_comment_injection"),
|
||||
(r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', "hidden_div"),
|
||||
(r"translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)", "translate_execute"),
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"),
|
||||
(r'translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)', "translate_execute"),
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
|
||||
]
|
||||
|
||||
_CONTEXT_INVISIBLE_CHARS = {
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
}
|
||||
|
||||
|
||||
@@ -59,13 +52,10 @@ def _scan_context_content(content: str, filename: str) -> str:
|
||||
|
||||
if findings:
|
||||
logger.warning("Context file %s blocked: %s", filename, ", ".join(findings))
|
||||
return (
|
||||
f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]"
|
||||
)
|
||||
return f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]"
|
||||
|
||||
return content
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Constants
|
||||
# =========================================================================
|
||||
@@ -141,7 +131,10 @@ PLATFORM_HINTS = {
|
||||
"files arrive as downloadable documents. You can also include image "
|
||||
"URLs in markdown format  and they will be sent as photos."
|
||||
),
|
||||
"cli": ("You are a CLI AI Agent. Try not to use markdown but simple text renderable inside a terminal."),
|
||||
"cli": (
|
||||
"You are a CLI AI Agent. Try not to use markdown but simple text "
|
||||
"renderable inside a terminal."
|
||||
),
|
||||
}
|
||||
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
@@ -153,20 +146,18 @@ CONTEXT_TRUNCATE_TAIL_RATIO = 0.2
|
||||
# Skills index
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str:
|
||||
"""Read the description from a SKILL.md frontmatter, capped at max_chars."""
|
||||
try:
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
match = re.search(
|
||||
r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---",
|
||||
raw,
|
||||
re.MULTILINE | re.DOTALL,
|
||||
raw, re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
if match:
|
||||
desc = match.group(1).strip().strip("'\"")
|
||||
if len(desc) > max_chars:
|
||||
desc = desc[: max_chars - 3] + "..."
|
||||
desc = desc[:max_chars - 3] + "..."
|
||||
return desc
|
||||
except Exception:
|
||||
pass
|
||||
@@ -181,7 +172,6 @@ def _skill_is_platform_compatible(skill_file: Path) -> bool:
|
||||
"""
|
||||
try:
|
||||
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
|
||||
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = _parse_frontmatter(raw)
|
||||
return skill_matches_platform(frontmatter)
|
||||
@@ -270,7 +260,8 @@ def build_skills_system_prompt() -> str:
|
||||
"load it with skill_view(name) and follow its instructions. "
|
||||
"If a skill has issues, fix it with skill_manage(action='patch').\n"
|
||||
"\n"
|
||||
"<available_skills>\n" + "\n".join(index_lines) + "\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
"</available_skills>\n"
|
||||
"\n"
|
||||
"If none match, proceed normally without loading a skill."
|
||||
@@ -281,7 +272,6 @@ def build_skills_system_prompt() -> str:
|
||||
# Context files (SOUL.md, AGENTS.md, .cursorrules)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE_MAX_CHARS) -> str:
|
||||
"""Head/tail truncation with a marker in the middle."""
|
||||
if len(content) <= max_chars:
|
||||
@@ -294,7 +284,7 @@ def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE
|
||||
return head + marker + tail
|
||||
|
||||
|
||||
def build_context_files_prompt(cwd: str | None = None) -> str:
|
||||
def build_context_files_prompt(cwd: Optional[str] = None) -> str:
|
||||
"""Discover and load context files for the system prompt.
|
||||
|
||||
Discovery: AGENTS.md (recursive), .cursorrules / .cursor/rules/*.mdc,
|
||||
@@ -317,9 +307,7 @@ def build_context_files_prompt(cwd: str | None = None) -> str:
|
||||
if top_level_agents:
|
||||
agents_files = []
|
||||
for root, dirs, files in os.walk(cwd_path):
|
||||
dirs[:] = [
|
||||
d for d in dirs if not d.startswith(".") and d not in ("node_modules", "__pycache__", "venv", ".venv")
|
||||
]
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')]
|
||||
for f in files:
|
||||
if f.lower() == "agents.md":
|
||||
agents_files.append(Path(root) / f)
|
||||
@@ -396,7 +384,4 @@ def build_context_files_prompt(cwd: str | None = None) -> str:
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
return (
|
||||
"# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n"
|
||||
+ "\n".join(sections)
|
||||
)
|
||||
return "# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n" + "\n".join(sections)
|
||||
|
||||
@@ -9,7 +9,7 @@ Pure functions -- no class state, no AIAgent dependency.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
@@ -36,9 +36,9 @@ def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
|
||||
|
||||
def apply_anthropic_cache_control(
|
||||
api_messages: list[dict[str, Any]],
|
||||
api_messages: List[Dict[str, Any]],
|
||||
cache_ttl: str = "5m",
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Apply system_and_3 caching strategy to messages for Anthropic models.
|
||||
|
||||
Places up to 4 cache_control breakpoints: system prompt + last 3 non-system messages.
|
||||
|
||||
@@ -10,33 +10,34 @@ the first 6 and last 4 characters for debuggability.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Known API key prefixes -- match the prefix + contiguous token chars
|
||||
_PREFIX_PATTERNS = [
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
|
||||
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
|
||||
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
|
||||
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
|
||||
r"AIza[A-Za-z0-9_-]{30,}", # Google API keys
|
||||
r"pplx-[A-Za-z0-9]{10,}", # Perplexity
|
||||
r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai
|
||||
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
|
||||
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
|
||||
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
|
||||
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
|
||||
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
|
||||
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
|
||||
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
|
||||
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
|
||||
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
|
||||
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
|
||||
r"npm_[A-Za-z0-9]{10,}", # npm access token
|
||||
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
|
||||
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
|
||||
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
|
||||
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
|
||||
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
|
||||
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
|
||||
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
|
||||
r"AIza[A-Za-z0-9_-]{30,}", # Google API keys
|
||||
r"pplx-[A-Za-z0-9]{10,}", # Perplexity
|
||||
r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai
|
||||
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
|
||||
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
|
||||
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
|
||||
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
|
||||
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
|
||||
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
|
||||
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
|
||||
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
|
||||
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
|
||||
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
|
||||
r"npm_[A-Za-z0-9]{10,}", # npm access token
|
||||
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
|
||||
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
|
||||
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
|
||||
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
|
||||
]
|
||||
|
||||
# ENV assignment patterns: KEY=value where KEY contains a secret-like name
|
||||
@@ -65,7 +66,9 @@ _TELEGRAM_RE = re.compile(
|
||||
)
|
||||
|
||||
# Private key blocks: -----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY-----
|
||||
_PRIVATE_KEY_RE = re.compile(r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----")
|
||||
_PRIVATE_KEY_RE = re.compile(
|
||||
r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----"
|
||||
)
|
||||
|
||||
# Database connection strings: protocol://user:PASSWORD@host
|
||||
# Catches postgres, mysql, mongodb, redis, amqp URLs and redacts the password
|
||||
@@ -79,7 +82,9 @@ _DB_CONNSTR_RE = re.compile(
|
||||
_SIGNAL_PHONE_RE = re.compile(r"(\+[1-9]\d{6,14})(?![A-Za-z0-9])")
|
||||
|
||||
# Compile known prefix patterns into one alternation
|
||||
_PREFIX_RE = re.compile(r"(?<![A-Za-z0-9_-])(" + "|".join(_PREFIX_PATTERNS) + r")(?![A-Za-z0-9_-])")
|
||||
_PREFIX_RE = re.compile(
|
||||
r"(?<![A-Za-z0-9_-])(" + "|".join(_PREFIX_PATTERNS) + r")(?![A-Za-z0-9_-])"
|
||||
)
|
||||
|
||||
|
||||
def _mask_token(token: str) -> str:
|
||||
@@ -107,14 +112,12 @@ def redact_sensitive_text(text: str) -> str:
|
||||
def _redact_env(m):
|
||||
name, quote, value = m.group(1), m.group(2), m.group(3)
|
||||
return f"{name}={quote}{_mask_token(value)}{quote}"
|
||||
|
||||
text = _ENV_ASSIGN_RE.sub(_redact_env, text)
|
||||
|
||||
# JSON fields: "apiKey": "value"
|
||||
def _redact_json(m):
|
||||
key, value = m.group(1), m.group(2)
|
||||
return f'{key}: "{_mask_token(value)}"'
|
||||
|
||||
text = _JSON_FIELD_RE.sub(_redact_json, text)
|
||||
|
||||
# Authorization headers
|
||||
@@ -128,7 +131,6 @@ def redact_sensitive_text(text: str) -> str:
|
||||
prefix = m.group(1) or ""
|
||||
digits = m.group(2)
|
||||
return f"{prefix}{digits}:***"
|
||||
|
||||
text = _TELEGRAM_RE.sub(_redact_telegram, text)
|
||||
|
||||
# Private key blocks
|
||||
@@ -143,7 +145,6 @@ def redact_sensitive_text(text: str) -> str:
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:]
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
text = _SIGNAL_PHONE_RE.sub(_redact_phone, text)
|
||||
|
||||
return text
|
||||
@@ -152,7 +153,7 @@ def redact_sensitive_text(text: str) -> str:
|
||||
class RedactingFormatter(logging.Formatter):
|
||||
"""Log formatter that redacts secrets from all log messages."""
|
||||
|
||||
def __init__(self, fmt=None, datefmt=None, style="%", **kwargs):
|
||||
def __init__(self, fmt=None, datefmt=None, style='%', **kwargs):
|
||||
super().__init__(fmt, datefmt, style, **kwargs)
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
|
||||
@@ -6,14 +6,14 @@ can invoke skills via /skill-name commands.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_skill_commands: dict[str, dict[str, Any]] = {}
|
||||
_skill_commands: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def scan_skill_commands() -> dict[str, dict[str, Any]]:
|
||||
def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
"""Scan ~/.hermes/skills/ and return a mapping of /command -> skill info.
|
||||
|
||||
Returns:
|
||||
@@ -23,27 +23,26 @@ def scan_skill_commands() -> dict[str, dict[str, Any]]:
|
||||
_skill_commands = {}
|
||||
try:
|
||||
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
|
||||
|
||||
if not SKILLS_DIR.exists():
|
||||
return _skill_commands
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
if any(part in (".git", ".github", ".hub") for part in skill_md.parts):
|
||||
if any(part in ('.git', '.github', '.hub') for part in skill_md.parts):
|
||||
continue
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
content = skill_md.read_text(encoding='utf-8')
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
# Skip skills incompatible with the current OS platform
|
||||
if not skill_matches_platform(frontmatter):
|
||||
continue
|
||||
name = frontmatter.get("name", skill_md.parent.name)
|
||||
description = frontmatter.get("description", "")
|
||||
name = frontmatter.get('name', skill_md.parent.name)
|
||||
description = frontmatter.get('description', '')
|
||||
if not description:
|
||||
for line in body.strip().split("\n"):
|
||||
for line in body.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
if line and not line.startswith('#'):
|
||||
description = line[:80]
|
||||
break
|
||||
cmd_name = name.lower().replace(" ", "-").replace("_", "-")
|
||||
cmd_name = name.lower().replace(' ', '-').replace('_', '-')
|
||||
_skill_commands[f"/{cmd_name}"] = {
|
||||
"name": name,
|
||||
"description": description or f"Invoke the {name} skill",
|
||||
@@ -57,14 +56,14 @@ def scan_skill_commands() -> dict[str, dict[str, Any]]:
|
||||
return _skill_commands
|
||||
|
||||
|
||||
def get_skill_commands() -> dict[str, dict[str, Any]]:
|
||||
def get_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
"""Return the current skill commands mapping (scan first if empty)."""
|
||||
if not _skill_commands:
|
||||
scan_skill_commands()
|
||||
return _skill_commands
|
||||
|
||||
|
||||
def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> str | None:
|
||||
def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> Optional[str]:
|
||||
"""Build the user message content for a skill slash command invocation.
|
||||
|
||||
Args:
|
||||
@@ -84,7 +83,7 @@ def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") ->
|
||||
skill_name = skill_info["name"]
|
||||
|
||||
try:
|
||||
content = skill_md_path.read_text(encoding="utf-8")
|
||||
content = skill_md_path.read_text(encoding='utf-8')
|
||||
except Exception:
|
||||
return f"[Failed to load skill: {skill_name}]"
|
||||
|
||||
@@ -112,8 +111,6 @@ def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") ->
|
||||
|
||||
if user_instruction:
|
||||
parts.append("")
|
||||
parts.append(
|
||||
f"The user has provided the following instruction alongside the skill invocation: {user_instruction}"
|
||||
)
|
||||
parts.append(f"The user has provided the following instruction alongside the skill invocation: {user_instruction}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
@@ -8,7 +8,7 @@ the file-write logic live here.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,7 +27,8 @@ def has_incomplete_scratchpad(content: str) -> bool:
|
||||
return "<REASONING_SCRATCHPAD>" in content and "</REASONING_SCRATCHPAD>" not in content
|
||||
|
||||
|
||||
def save_trajectory(trajectory: list[dict[str, Any]], model: str, completed: bool, filename: str = None):
|
||||
def save_trajectory(trajectory: List[Dict[str, Any]], model: str,
|
||||
completed: bool, filename: str = None):
|
||||
"""Append a trajectory entry to a JSONL file.
|
||||
|
||||
Args:
|
||||
|
||||
572
batch_runner.py
572
batch_runner.py
File diff suppressed because it is too large
Load Diff
@@ -15,18 +15,18 @@ duplicate execution if multiple processes overlap.
|
||||
"""
|
||||
|
||||
from cron.jobs import (
|
||||
JOBS_FILE,
|
||||
create_job,
|
||||
get_job,
|
||||
list_jobs,
|
||||
remove_job,
|
||||
update_job,
|
||||
JOBS_FILE,
|
||||
)
|
||||
from cron.scheduler import tick
|
||||
|
||||
__all__ = [
|
||||
"create_job",
|
||||
"get_job",
|
||||
"get_job",
|
||||
"list_jobs",
|
||||
"remove_job",
|
||||
"update_job",
|
||||
|
||||
149
cron/jobs.py
149
cron/jobs.py
@@ -6,19 +6,18 @@ Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Optional, Dict, List, Any
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
try:
|
||||
from croniter import croniter
|
||||
|
||||
HAS_CRONITER = True
|
||||
except ImportError:
|
||||
HAS_CRONITER = False
|
||||
@@ -43,38 +42,37 @@ def ensure_dirs():
|
||||
# Schedule Parsing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def parse_duration(s: str) -> int:
|
||||
"""
|
||||
Parse duration string into minutes.
|
||||
|
||||
|
||||
Examples:
|
||||
"30m" → 30
|
||||
"2h" → 120
|
||||
"1d" → 1440
|
||||
"""
|
||||
s = s.strip().lower()
|
||||
match = re.match(r"^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$", s)
|
||||
match = re.match(r'^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$', s)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid duration: '{s}'. Use format like '30m', '2h', or '1d'")
|
||||
|
||||
|
||||
value = int(match.group(1))
|
||||
unit = match.group(2)[0] # First char: m, h, or d
|
||||
|
||||
multipliers = {"m": 1, "h": 60, "d": 1440}
|
||||
|
||||
multipliers = {'m': 1, 'h': 60, 'd': 1440}
|
||||
return value * multipliers[unit]
|
||||
|
||||
|
||||
def parse_schedule(schedule: str) -> dict[str, Any]:
|
||||
def parse_schedule(schedule: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse schedule string into structured format.
|
||||
|
||||
|
||||
Returns dict with:
|
||||
- kind: "once" | "interval" | "cron"
|
||||
- For "once": "run_at" (ISO timestamp)
|
||||
- For "interval": "minutes" (int)
|
||||
- For "cron": "expr" (cron expression)
|
||||
|
||||
|
||||
Examples:
|
||||
"30m" → once in 30 minutes
|
||||
"2h" → once in 2 hours
|
||||
@@ -86,17 +84,23 @@ def parse_schedule(schedule: str) -> dict[str, Any]:
|
||||
schedule = schedule.strip()
|
||||
original = schedule
|
||||
schedule_lower = schedule.lower()
|
||||
|
||||
|
||||
# "every X" pattern → recurring interval
|
||||
if schedule_lower.startswith("every "):
|
||||
duration_str = schedule[6:].strip()
|
||||
minutes = parse_duration(duration_str)
|
||||
return {"kind": "interval", "minutes": minutes, "display": f"every {minutes}m"}
|
||||
|
||||
return {
|
||||
"kind": "interval",
|
||||
"minutes": minutes,
|
||||
"display": f"every {minutes}m"
|
||||
}
|
||||
|
||||
# Check for cron expression (5 or 6 space-separated fields)
|
||||
# Cron fields: minute hour day month weekday [year]
|
||||
parts = schedule.split()
|
||||
if len(parts) >= 5 and all(re.match(r"^[\d\*\-,/]+$", p) for p in parts[:5]):
|
||||
if len(parts) >= 5 and all(
|
||||
re.match(r'^[\d\*\-,/]+$', p) for p in parts[:5]
|
||||
):
|
||||
if not HAS_CRONITER:
|
||||
raise ValueError("Cron expressions require 'croniter' package. Install with: pip install croniter")
|
||||
# Validate cron expression
|
||||
@@ -104,25 +108,37 @@ def parse_schedule(schedule: str) -> dict[str, Any]:
|
||||
croniter(schedule)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid cron expression '{schedule}': {e}")
|
||||
return {"kind": "cron", "expr": schedule, "display": schedule}
|
||||
|
||||
return {
|
||||
"kind": "cron",
|
||||
"expr": schedule,
|
||||
"display": schedule
|
||||
}
|
||||
|
||||
# ISO timestamp (contains T or looks like date)
|
||||
if "T" in schedule or re.match(r"^\d{4}-\d{2}-\d{2}", schedule):
|
||||
if 'T' in schedule or re.match(r'^\d{4}-\d{2}-\d{2}', schedule):
|
||||
try:
|
||||
# Parse and validate
|
||||
dt = datetime.fromisoformat(schedule.replace("Z", "+00:00"))
|
||||
return {"kind": "once", "run_at": dt.isoformat(), "display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}"}
|
||||
dt = datetime.fromisoformat(schedule.replace('Z', '+00:00'))
|
||||
return {
|
||||
"kind": "once",
|
||||
"run_at": dt.isoformat(),
|
||||
"display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}"
|
||||
}
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid timestamp '{schedule}': {e}")
|
||||
|
||||
|
||||
# Duration like "30m", "2h", "1d" → one-shot from now
|
||||
try:
|
||||
minutes = parse_duration(schedule)
|
||||
run_at = _hermes_now() + timedelta(minutes=minutes)
|
||||
return {"kind": "once", "run_at": run_at.isoformat(), "display": f"once in {original}"}
|
||||
return {
|
||||
"kind": "once",
|
||||
"run_at": run_at.isoformat(),
|
||||
"display": f"once in {original}"
|
||||
}
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid schedule '{original}'. Use:\n"
|
||||
f" - Duration: '30m', '2h', '1d' (one-shot)\n"
|
||||
@@ -145,7 +161,7 @@ def _ensure_aware(dt: datetime) -> datetime:
|
||||
return dt
|
||||
|
||||
|
||||
def compute_next_run(schedule: dict[str, Any], last_run_at: str | None = None) -> str | None:
|
||||
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Compute the next run time for a schedule.
|
||||
|
||||
@@ -183,27 +199,26 @@ def compute_next_run(schedule: dict[str, Any], last_run_at: str | None = None) -
|
||||
# Job CRUD Operations
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def load_jobs() -> list[dict[str, Any]]:
|
||||
def load_jobs() -> List[Dict[str, Any]]:
|
||||
"""Load all jobs from storage."""
|
||||
ensure_dirs()
|
||||
if not JOBS_FILE.exists():
|
||||
return []
|
||||
|
||||
|
||||
try:
|
||||
with open(JOBS_FILE, encoding="utf-8") as f:
|
||||
with open(JOBS_FILE, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data.get("jobs", [])
|
||||
except (OSError, json.JSONDecodeError):
|
||||
except (json.JSONDecodeError, IOError):
|
||||
return []
|
||||
|
||||
|
||||
def save_jobs(jobs: list[dict[str, Any]]):
|
||||
def save_jobs(jobs: List[Dict[str, Any]]):
|
||||
"""Save all jobs to storage."""
|
||||
ensure_dirs()
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix=".tmp", prefix=".jobs_")
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix='.tmp', prefix='.jobs_')
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
@@ -219,14 +234,14 @@ def save_jobs(jobs: list[dict[str, Any]]):
|
||||
def create_job(
|
||||
prompt: str,
|
||||
schedule: str,
|
||||
name: str | None = None,
|
||||
repeat: int | None = None,
|
||||
deliver: str | None = None,
|
||||
origin: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
name: Optional[str] = None,
|
||||
repeat: Optional[int] = None,
|
||||
deliver: Optional[str] = None,
|
||||
origin: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a new cron job.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: The prompt to run (must be self-contained)
|
||||
schedule: Schedule string (see parse_schedule)
|
||||
@@ -234,23 +249,23 @@ def create_job(
|
||||
repeat: How many times to run (None = forever, 1 = once)
|
||||
deliver: Where to deliver output ("origin", "local", "telegram", etc.)
|
||||
origin: Source info where job was created (for "origin" delivery)
|
||||
|
||||
|
||||
Returns:
|
||||
The created job dict
|
||||
"""
|
||||
parsed_schedule = parse_schedule(schedule)
|
||||
|
||||
|
||||
# Auto-set repeat=1 for one-shot schedules if not specified
|
||||
if parsed_schedule["kind"] == "once" and repeat is None:
|
||||
repeat = 1
|
||||
|
||||
|
||||
# Default delivery to origin if available, otherwise local
|
||||
if deliver is None:
|
||||
deliver = "origin" if origin else "local"
|
||||
|
||||
|
||||
job_id = uuid.uuid4().hex[:12]
|
||||
now = _hermes_now().isoformat()
|
||||
|
||||
|
||||
job = {
|
||||
"id": job_id,
|
||||
"name": name or prompt[:50].strip(),
|
||||
@@ -259,7 +274,7 @@ def create_job(
|
||||
"schedule_display": parsed_schedule.get("display", schedule),
|
||||
"repeat": {
|
||||
"times": repeat, # None = forever
|
||||
"completed": 0,
|
||||
"completed": 0
|
||||
},
|
||||
"enabled": True,
|
||||
"created_at": now,
|
||||
@@ -271,15 +286,15 @@ def create_job(
|
||||
"deliver": deliver,
|
||||
"origin": origin, # Tracks where job was created for "origin" delivery
|
||||
}
|
||||
|
||||
|
||||
jobs = load_jobs()
|
||||
jobs.append(job)
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
return job
|
||||
|
||||
|
||||
def get_job(job_id: str) -> dict[str, Any] | None:
|
||||
def get_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a job by ID."""
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
@@ -288,7 +303,7 @@ def get_job(job_id: str) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
def list_jobs(include_disabled: bool = False) -> list[dict[str, Any]]:
|
||||
def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]:
|
||||
"""List all jobs, optionally including disabled ones."""
|
||||
jobs = load_jobs()
|
||||
if not include_disabled:
|
||||
@@ -296,7 +311,7 @@ def list_jobs(include_disabled: bool = False) -> list[dict[str, Any]]:
|
||||
return jobs
|
||||
|
||||
|
||||
def update_job(job_id: str, updates: dict[str, Any]) -> dict[str, Any] | None:
|
||||
def update_job(job_id: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Update a job by ID."""
|
||||
jobs = load_jobs()
|
||||
for i, job in enumerate(jobs):
|
||||
@@ -318,10 +333,10 @@ def remove_job(job_id: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def mark_job_run(job_id: str, success: bool, error: str | None = None):
|
||||
def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
"""
|
||||
Mark a job as having been run.
|
||||
|
||||
|
||||
Updates last_run_at, last_status, increments completed count,
|
||||
computes next_run_at, and auto-deletes if repeat limit reached.
|
||||
"""
|
||||
@@ -332,11 +347,11 @@ def mark_job_run(job_id: str, success: bool, error: str | None = None):
|
||||
job["last_run_at"] = now
|
||||
job["last_status"] = "ok" if success else "error"
|
||||
job["last_error"] = error if not success else None
|
||||
|
||||
|
||||
# Increment completed count
|
||||
if job.get("repeat"):
|
||||
job["repeat"]["completed"] = job["repeat"].get("completed", 0) + 1
|
||||
|
||||
|
||||
# Check if we've hit the repeat limit
|
||||
times = job["repeat"].get("times")
|
||||
completed = job["repeat"]["completed"]
|
||||
@@ -345,38 +360,38 @@ def mark_job_run(job_id: str, success: bool, error: str | None = None):
|
||||
jobs.pop(i)
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
|
||||
# Compute next run
|
||||
job["next_run_at"] = compute_next_run(job["schedule"], now)
|
||||
|
||||
|
||||
# If no next run (one-shot completed), disable
|
||||
if job["next_run_at"] is None:
|
||||
job["enabled"] = False
|
||||
|
||||
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
def get_due_jobs() -> list[dict[str, Any]]:
|
||||
def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now."""
|
||||
now = _hermes_now()
|
||||
jobs = load_jobs()
|
||||
due = []
|
||||
|
||||
|
||||
for job in jobs:
|
||||
if not job.get("enabled", True):
|
||||
continue
|
||||
|
||||
|
||||
next_run = job.get("next_run_at")
|
||||
if not next_run:
|
||||
continue
|
||||
|
||||
|
||||
next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
|
||||
if next_run_dt <= now:
|
||||
due.append(job)
|
||||
|
||||
|
||||
return due
|
||||
|
||||
|
||||
@@ -385,11 +400,11 @@ def save_job_output(job_id: str, output: str):
|
||||
ensure_dirs()
|
||||
job_output_dir = OUTPUT_DIR / job_id
|
||||
job_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
output_file = job_output_dir / f"{timestamp}.md"
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(output)
|
||||
|
||||
|
||||
return output_file
|
||||
|
||||
@@ -23,7 +23,9 @@ except ImportError:
|
||||
import msvcrt
|
||||
except ImportError:
|
||||
msvcrt = None
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
@@ -42,7 +44,7 @@ _LOCK_DIR = _hermes_home / "cron"
|
||||
_LOCK_FILE = _LOCK_DIR / ".tick.lock"
|
||||
|
||||
|
||||
def _resolve_origin(job: dict) -> dict | None:
|
||||
def _resolve_origin(job: dict) -> Optional[dict]:
|
||||
"""Extract origin info from a job, returning {platform, chat_id, chat_name} or None."""
|
||||
origin = job.get("origin")
|
||||
if not origin:
|
||||
@@ -85,16 +87,11 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
# Fall back to home channel
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
if not chat_id:
|
||||
logger.warning(
|
||||
"Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL <channel_id>",
|
||||
job["id"],
|
||||
deliver,
|
||||
platform_name.upper(),
|
||||
)
|
||||
logger.warning("Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL <channel_id>", job["id"], deliver, platform_name.upper())
|
||||
return
|
||||
|
||||
from gateway.config import Platform, load_gateway_config
|
||||
from tools.send_message_tool import _send_to_platform
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
|
||||
platform_map = {
|
||||
"telegram": Platform.TELEGRAM,
|
||||
@@ -126,7 +123,6 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
# asyncio.run() fails if there's already a running loop in this thread;
|
||||
# spin up a new thread to avoid that.
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, content))
|
||||
result = future.result(timeout=30)
|
||||
@@ -141,26 +137,25 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
# Mirror the delivered content into the target's gateway session
|
||||
try:
|
||||
from gateway.mirror import mirror_to_session
|
||||
|
||||
mirror_to_session(platform_name, chat_id, content, source_label="cron")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
"""
|
||||
Execute a single cron job.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (success, full_output_doc, final_response, error_message)
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
job_id = job["id"]
|
||||
job_name = job["name"]
|
||||
prompt = job["prompt"]
|
||||
origin = _resolve_origin(job)
|
||||
|
||||
|
||||
logger.info("Running job '%s' (ID: %s)", job_name, job_id)
|
||||
logger.info("Prompt: %s", prompt[:100])
|
||||
|
||||
@@ -175,7 +170,6 @@ def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
# Re-read .env and config.yaml fresh every run so provider/key
|
||||
# changes take effect without a gateway restart.
|
||||
from dotenv import load_dotenv
|
||||
|
||||
try:
|
||||
load_dotenv(str(_hermes_home / ".env"), override=True, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
@@ -187,7 +181,6 @@ def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
_cfg = {}
|
||||
try:
|
||||
import yaml
|
||||
|
||||
_cfg_path = str(_hermes_home / "config.yaml")
|
||||
if os.path.exists(_cfg_path):
|
||||
with open(_cfg_path) as _f:
|
||||
@@ -217,13 +210,12 @@ def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
prefill_file = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") or _cfg.get("prefill_messages_file", "")
|
||||
if prefill_file:
|
||||
import json as _json
|
||||
|
||||
pfpath = Path(prefill_file).expanduser()
|
||||
if not pfpath.is_absolute():
|
||||
pfpath = _hermes_home / pfpath
|
||||
if pfpath.exists():
|
||||
try:
|
||||
with open(pfpath, encoding="utf-8") as _pf:
|
||||
with open(pfpath, "r", encoding="utf-8") as _pf:
|
||||
prefill_messages = _json.load(_pf)
|
||||
if not isinstance(prefill_messages, list):
|
||||
prefill_messages = None
|
||||
@@ -237,10 +229,9 @@ def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
pr = _cfg.get("provider_routing", {})
|
||||
|
||||
from hermes_cli.runtime_provider import (
|
||||
format_runtime_provider_error,
|
||||
resolve_runtime_provider,
|
||||
format_runtime_provider_error,
|
||||
)
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(
|
||||
requested=os.getenv("HERMES_INFERENCE_PROVIDER"),
|
||||
@@ -263,20 +254,20 @@ def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
quiet_mode=True,
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}",
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if not final_response:
|
||||
final_response = "(No response generated)"
|
||||
|
||||
|
||||
output = f"""# Cron Job: {job_name}
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
**Schedule:** {job.get("schedule_display", "N/A")}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
|
||||
## Prompt
|
||||
|
||||
@@ -286,19 +277,19 @@ def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
|
||||
{final_response}
|
||||
"""
|
||||
|
||||
|
||||
logger.info("Job '%s' completed successfully", job_name)
|
||||
return True, output, final_response, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
||||
logger.error("Job '%s' failed: %s", job_name, error_msg)
|
||||
|
||||
|
||||
output = f"""# Cron Job: {job_name} (FAILED)
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
**Schedule:** {job.get("schedule_display", "N/A")}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
|
||||
## Prompt
|
||||
|
||||
@@ -323,13 +314,13 @@ def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
def tick(verbose: bool = True) -> int:
|
||||
"""
|
||||
Check and run all due jobs.
|
||||
|
||||
|
||||
Uses a file lock so only one tick runs at a time, even if the gateway's
|
||||
in-process ticker and a standalone daemon or manual tick overlap.
|
||||
|
||||
|
||||
Args:
|
||||
verbose: Whether to print status messages
|
||||
|
||||
|
||||
Returns:
|
||||
Number of jobs executed (0 if another tick is already running)
|
||||
"""
|
||||
@@ -343,7 +334,7 @@ def tick(verbose: bool = True) -> int:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
elif msvcrt:
|
||||
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_NBLCK, 1)
|
||||
except OSError:
|
||||
except (OSError, IOError):
|
||||
logger.debug("Tick skipped — another instance holds the lock")
|
||||
if lock_fd is not None:
|
||||
lock_fd.close()
|
||||
@@ -353,11 +344,11 @@ def tick(verbose: bool = True) -> int:
|
||||
due_jobs = get_due_jobs()
|
||||
|
||||
if verbose and not due_jobs:
|
||||
logger.info("%s - No jobs due", _hermes_now().strftime("%H:%M:%S"))
|
||||
logger.info("%s - No jobs due", _hermes_now().strftime('%H:%M:%S'))
|
||||
return 0
|
||||
|
||||
if verbose:
|
||||
logger.info("%s - %s job(s) due", _hermes_now().strftime("%H:%M:%S"), len(due_jobs))
|
||||
logger.info("%s - %s job(s) due", _hermes_now().strftime('%H:%M:%S'), len(due_jobs))
|
||||
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
@@ -369,9 +360,7 @@ def tick(verbose: bool = True) -> int:
|
||||
logger.info("Output saved to: %s", output_file)
|
||||
|
||||
# Deliver the final response to the origin/target chat
|
||||
deliver_content = (
|
||||
final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}"
|
||||
)
|
||||
deliver_content = final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}"
|
||||
if deliver_content:
|
||||
try:
|
||||
_deliver_result(job, deliver_content)
|
||||
@@ -382,7 +371,7 @@ def tick(verbose: bool = True) -> int:
|
||||
executed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing job %s: %s", job["id"], e)
|
||||
logger.error("Error processing job %s: %s", job['id'], e)
|
||||
mark_job_run(job["id"], False, str(e))
|
||||
|
||||
return executed
|
||||
@@ -392,7 +381,7 @@ def tick(verbose: bool = True) -> int:
|
||||
elif msvcrt:
|
||||
try:
|
||||
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
except OSError:
|
||||
except (OSError, IOError):
|
||||
pass
|
||||
lock_fd.close()
|
||||
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
# datagen-config-examples/web_research.yaml
|
||||
#
|
||||
# Batch data generation config for WebResearchEnv.
|
||||
# Generates tool-calling trajectories for multi-step web research tasks.
|
||||
#
|
||||
# Usage:
|
||||
# python batch_runner.py \
|
||||
# --config datagen-config-examples/web_research.yaml \
|
||||
# --run_name web_research_v1
|
||||
|
||||
environment: web-research
|
||||
|
||||
# Toolsets available to the agent during data generation
|
||||
toolsets:
|
||||
- web
|
||||
- file
|
||||
|
||||
# How many parallel workers to use
|
||||
num_workers: 4
|
||||
|
||||
# Questions per batch
|
||||
batch_size: 20
|
||||
|
||||
# Total trajectories to generate (comment out to run full dataset)
|
||||
max_items: 500
|
||||
|
||||
# Model to use for generation (override with --model flag)
|
||||
model: openrouter/nousresearch/hermes-3-llama-3.1-405b
|
||||
|
||||
# System prompt additions (ephemeral — not saved to trajectories)
|
||||
ephemeral_system_prompt: |
|
||||
You are a highly capable research agent. When asked a factual question,
|
||||
always use web_search to find current, accurate information before answering.
|
||||
Cite at least 2 sources. Be concise and accurate.
|
||||
|
||||
# Output directory
|
||||
output_dir: data/web_research_v1
|
||||
|
||||
# Trajectory compression settings (for fitting into training token budgets)
|
||||
compression:
|
||||
enabled: true
|
||||
target_max_tokens: 16000
|
||||
|
||||
# Eval settings
|
||||
eval_every: 100 # Run eval every N trajectories
|
||||
eval_size: 25 # Number of held-out questions per eval run
|
||||
@@ -1,643 +0,0 @@
|
||||
"""
|
||||
WebResearchEnv — RL Environment for Multi-Step Web Research
|
||||
============================================================
|
||||
|
||||
Trains models to do accurate, efficient, multi-source web research.
|
||||
|
||||
Reward signals:
|
||||
- Answer correctness (LLM judge, 0.0–1.0)
|
||||
- Source diversity (used ≥2 distinct domains)
|
||||
- Efficiency (penalizes excessive tool calls)
|
||||
- Tool usage (bonus for actually using web tools)
|
||||
|
||||
Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
|
||||
HuggingFace: google/frames-benchmark
|
||||
Fallback: built-in sample questions (no HF token needed)
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenAI-compatible server)
|
||||
python environments/web_research_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type openai
|
||||
|
||||
# Process mode (offline data generation)
|
||||
python environments/web_research_env.py process \\
|
||||
--env.data_path_to_save_groups data/web_research.jsonl
|
||||
|
||||
# Standalone eval
|
||||
python environments/web_research_env.py evaluate \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel
|
||||
|
||||
Built by: github.com/jackx707
|
||||
Inspired by: GroceryMind — production Hermes agent doing live web research
|
||||
across German grocery stores (firecrawl + hermes-agent)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
# Ensure hermes-agent root is on path
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional HuggingFace datasets import
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback sample dataset (used when HuggingFace is unavailable)
|
||||
# Multi-hop questions requiring real web search to answer.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in fallback question pool. Each entry carries:
#   question   - the multi-hop research question shown to the agent
#   answer     - the reference answer text
#   difficulty - "easy" | "medium" | "hard"
#   hops       - number of reasoning hops the question is labelled with
# NOTE: several reference answers are deliberately open-ended and defer to a
# live source ("should be verified ..."), so they are time-sensitive rather
# than fixed facts.
SAMPLE_QUESTIONS = [
    {
        "question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
        "answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
        "answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What programming language was used to write the original version of the web framework used by Instagram?",
        "answer": "Django, which Instagram was built on, is written in Python.",
        "difficulty": "easy",
        "hops": 2,
    },
    {
        "question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
        "answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
        "difficulty": "hard",
        "hops": 3,
    },
    {
        "question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
        "answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "How many employees does the parent company of Instagram have?",
        "answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
        "answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "Which company acquired the startup founded by the creator of Oculus VR?",
        "answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the market cap of the company that owns the most popular search engine in Russia?",
        "answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
        "answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
        "difficulty": "hard",
        "hops": 2,
    },
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnvConfig(HermesAgentEnvConfig):
    """Settings for the web research RL environment.

    Adds reward weights, tool-call efficiency thresholds, evaluation
    hold-out sizing, and the dataset name on top of ``HermesAgentEnvConfig``.
    """

    # --- Reward shaping -------------------------------------------------
    correctness_weight: float = Field(
        0.6,
        description="Weight for answer correctness in reward (LLM judge score).",
    )
    tool_usage_weight: float = Field(
        0.2,
        description="Weight for tool usage signal (did the model actually use web tools?).",
    )
    efficiency_weight: float = Field(
        0.2,
        description="Weight for efficiency signal (penalizes excessive tool calls).",
    )
    diversity_bonus: float = Field(
        0.1,
        description="Bonus reward for citing ≥2 distinct domains.",
    )

    # --- Tool-call efficiency thresholds ---------------------------------
    efficient_max_calls: int = Field(
        5,
        description="Maximum tool calls before efficiency penalty begins.",
    )
    heavy_penalty_calls: int = Field(
        10,
        description="Tool call count where efficiency penalty steepens.",
    )

    # --- Evaluation hold-out ---------------------------------------------
    eval_size: int = Field(
        20,
        description="Number of held-out items for evaluation.",
    )
    eval_split_ratio: float = Field(
        0.1,
        description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
    )

    # --- Dataset source ---------------------------------------------------
    dataset_name: str = Field(
        "google/frames-benchmark",
        description="HuggingFace dataset name for research questions.",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
RL environment for training multi-step web research skills.
|
||||
|
||||
The model is given a factual question requiring 2-3 hops of web research
|
||||
and must use web_search / web_extract tools to find and synthesize the answer.
|
||||
|
||||
Reward is multi-signal:
|
||||
60% — answer correctness (LLM judge)
|
||||
20% — tool usage (did the model actually search the web?)
|
||||
20% — efficiency (penalizes >5 tool calls)
|
||||
|
||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
||||
"""
|
||||
|
||||
name = "web-research"
|
||||
env_config_cls = WebResearchEnvConfig
|
||||
|
||||
# Default toolsets for this environment — web + file for saving notes
|
||||
default_toolsets = ["web", "file"]
|
||||
|
||||
@classmethod
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
    """Build the default environment config and API server list.

    Returns:
        A ``(env_config, server_configs)`` pair: the environment settings and
        a single-element list describing the default OpenRouter endpoint,
        whose API key is read from ``OPENROUTER_API_KEY`` (empty if unset).
    """
    research_prompt = (
        "You are a highly capable research agent. When asked a factual question, "
        "always use web_search to find current, accurate information before answering. "
        "Cite at least 2 sources. Be concise and accurate."
    )

    default_server = APIServerConfig(
        base_url="https://openrouter.ai/api/v1",
        model_name="anthropic/claude-sonnet-4.5",
        server_type="openai",
        api_key=os.getenv("OPENROUTER_API_KEY", ""),
        health_check=False,
    )

    env_config = WebResearchEnvConfig(
        enabled_toolsets=["web", "file"],
        max_agent_turns=15,
        agent_temperature=1.0,
        system_prompt=research_prompt,
        group_size=4,
        total_steps=1000,
        steps_per_eval=100,
        use_wandb=True,
        wandb_name="web-research",
    )

    return env_config, [default_server]
|
||||
|
||||
def __init__(self, *args, **kwargs):
    """Initialise the base environment, then reset local state.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor.
    """
    super().__init__(*args, **kwargs)

    # Cursor used by get_next_item() to cycle through the training pool.
    self._index: int = 0

    # Training and held-out question pools; both start empty.
    self._items: list[dict] = []
    self._eval_items: list[dict] = []

    # Per-signal buffers accumulated for metrics reporting (wandb).
    self._reward_buffer: list[float] = []
    self._correctness_buffer: list[float] = []
    self._tool_usage_buffer: list[float] = []
    self._efficiency_buffer: list[float] = []
    self._diversity_buffer: list[float] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Setup — load dataset
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def setup(self) -> None:
    """Populate the training and evaluation question pools.

    When the HuggingFace ``datasets`` package is importable, the dataset
    named by ``self.config.dataset_name`` is loaded (``test`` split),
    shuffled, and split into eval/train pools. Any failure on that path --
    including an empty dataset -- is logged and the built-in
    ``SAMPLE_QUESTIONS`` are used instead with an 80/20 train/eval split.

    Side effects:
        Sets ``self._items`` (train) and ``self._eval_items`` (eval).
    """
    if HF_AVAILABLE:
        try:
            logger.info("Loading FRAMES benchmark from HuggingFace...")
            ds = load_dataset(self.config.dataset_name, split="test")
            rows = [
                {
                    "question": row["Prompt"],
                    "answer": row["Answer"],
                    "difficulty": row.get("reasoning_types", "unknown"),
                    "hops": 2,
                }
                for row in ds
            ]
            if not rows:
                # Caught below so we fall back to the built-in samples
                # instead of ending up with an empty training pool.
                raise ValueError("dataset contains no rows")

            # Hold out the larger of the fixed size and the ratio-based size,
            # but always leave at least one item for training.
            eval_size = max(
                self.config.eval_size,
                int(len(rows) * self.config.eval_split_ratio),
            )
            eval_size = min(eval_size, len(rows) - 1)

            random.shuffle(rows)
            self._eval_items = rows[:eval_size]
            self._items = rows[eval_size:]
            logger.info(
                "Loaded %d train / %d eval items from FRAMES benchmark.",
                len(self._items),
                len(self._eval_items),
            )
            return
        except Exception as e:
            logger.warning(
                "Could not load FRAMES from HuggingFace: %s. Using built-in samples.",
                e,
            )

    # Fallback: shuffle a copy so the module-level SAMPLE_QUESTIONS list is
    # not reordered as a side effect of setting up one environment.
    samples = list(SAMPLE_QUESTIONS)
    random.shuffle(samples)
    split = max(1, len(samples) * 8 // 10)
    self._items = samples[:split]
    self._eval_items = samples[split:]
    logger.info(
        "Using built-in sample dataset: %d train / %d eval items.",
        len(self._items),
        len(self._eval_items),
    )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. get_next_item — return the next question
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_next_item(self) -> dict:
    """Hand out training items round-robin.

    Returns the item at the current cursor position (wrapping around at the
    end of ``self._items``) and advances the cursor by one.

    Raises:
        RuntimeError: if the training pool is empty.
    """
    pool = self._items
    if not pool:
        raise RuntimeError("Dataset is empty. Did you call setup()?")

    # Modulo keeps the ever-growing cursor inside the pool bounds.
    position = self._index % len(pool)
    self._index += 1
    return pool[position]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. format_prompt — build the user-facing prompt
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def format_prompt(self, item: dict) -> str:
    """Build the task prompt for one dataset item.

    The prompt has three paragraphs separated by blank lines: a fixed
    instruction to research via web search, the question taken from
    ``item["question"]``, and a bulleted list of answer requirements.
    """
    preamble = (
        "Research the following question thoroughly using web search. "
        "You MUST search the web to find current, accurate information — "
        "do not rely solely on your training data."
    )
    question_line = f"Question: {item['question']}"
    requirements = "\n".join(
        [
            "Requirements:",
            "- Use web_search and/or web_extract tools to find information",
            "- Search at least 2 different sources",
            "- Provide a concise, accurate answer (2-4 sentences)",
            "- Cite the sources you used",
        ]
    )
    return "\n\n".join((preamble, question_line, requirements))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. compute_reward — multi-signal scoring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def compute_reward(
    self,
    item: dict,
    result: AgentResult,
    ctx: ToolContext,
) -> float:
    """Score one rollout with a weighted, multi-signal reward in [0, 1].

    The reward is

        correctness_weight * correctness   (LLM judge vs. ground truth)
      + tool_usage_weight  * tool_used     (1.0 if any web tool was called)
      + efficiency_weight  * efficiency    (penalises excess calls)
      + diversity_bonus                    (answer cites >= 2 distinct domains)

    clamped to [0, 1]. Each component is also appended to the per-signal
    buffers that ``wandb_log`` averages and clears.

    Args:
        item: dataset item; ``item["question"]`` and ``item["answer"]`` are read.
        result: rollout result; ``final_response``, ``tool_calls`` (optional)
            and ``turns_used`` are read.
        ctx: unused by this implementation; kept for the caller's interface.

    Returns:
        The clamped total reward.
    """
    cfg = self.config

    final_response: str = result.final_response or ""
    # tool_calls may be missing or None on the result object.
    tool_calls = getattr(result, "tool_calls", None) or []
    tools_used: list[str] = [tc.tool_name for tc in tool_calls]
    # Prefer the reported turn count; fall back to the number of tool calls.
    tool_call_count: int = result.turns_used or len(tools_used)

    # ---- Signal 1: Answer correctness (LLM judge) ----------------
    correctness = await self._llm_judge(
        question=item["question"],
        expected=item["answer"],
        model_answer=final_response,
    )

    # ---- Signal 2: Web tool usage --------------------------------
    web_tools = {"web_search", "web_extract", "search", "firecrawl"}
    tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0

    # ---- Signal 3: Efficiency ------------------------------------
    excess_calls = tool_call_count - cfg.efficient_max_calls
    if excess_calls <= 0:
        efficiency = 1.0
    elif tool_call_count <= cfg.heavy_penalty_calls:
        # Clamp like the heavy branch so a wide gap between the two
        # thresholds can never yield (and log) a negative efficiency.
        efficiency = max(0.0, 1.0 - excess_calls * 0.08)
    else:
        efficiency = max(0.0, 1.0 - excess_calls * 0.12)

    # ---- Bonus: Source diversity ---------------------------------
    domains = self._extract_domains(final_response)
    diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0

    # ---- Combine ------------------------------------------------
    reward = (
        cfg.correctness_weight * correctness
        + cfg.tool_usage_weight * tool_used
        + cfg.efficiency_weight * efficiency
        + diversity
    )
    reward = min(1.0, max(0.0, reward))  # clamp to [0, 1]

    # Track for wandb
    self._reward_buffer.append(reward)
    self._correctness_buffer.append(correctness)
    self._tool_usage_buffer.append(tool_used)
    self._efficiency_buffer.append(efficiency)
    self._diversity_buffer.append(diversity)

    logger.debug(
        f"Reward breakdown — correctness={correctness:.2f}, "
        f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
        f"diversity={diversity:.1f} → total={reward:.3f}"
    )

    return reward
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. evaluate — run on held-out eval split
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
    """Score the held-out split and report the results via ``evaluate_log``.

    At most ``config.eval_size`` items from ``self._eval_items`` are used.
    Each question is sent through a single ``self.server.chat_completion``
    call (temperature 0, no tool loop is run here) and the reply is graded
    with ``_llm_judge``. A failure on one item is logged and recorded with
    correctness 0.0 so the remaining items are still evaluated.

    ``*args`` / ``**kwargs`` are accepted for interface compatibility and
    ignored.
    """
    import time

    items = self._eval_items
    if not items:
        logger.warning("No eval items available.")
        return

    eval_size = min(self.config.eval_size, len(items))
    eval_items = items[:eval_size]

    logger.info(f"Running eval on {len(eval_items)} questions...")
    start_time = time.time()
    samples = []

    for item in eval_items:
        try:
            prompt = self.format_prompt(item)
            completion = await self.server.chat_completion(
                messages=[
                    {"role": "system", "content": self.config.system_prompt or ""},
                    {"role": "user", "content": prompt},
                ],
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.0,
                split="eval",
            )

            # message.content can be None; always record a string response.
            response_content = (
                completion.choices[0].message.content if completion.choices else ""
            ) or ""

            # Score the response
            correctness = await self._llm_judge(
                question=item["question"],
                expected=item["answer"],
                model_answer=response_content,
            )

            samples.append({
                "prompt": item["question"],
                "response": response_content,
                "expected": item["answer"],
                "correctness": correctness,
            })

        except Exception as e:
            # Best-effort per item: record the failure and keep going.
            logger.error(f"Eval error on item: {e}")
            samples.append({
                "prompt": item["question"],
                "response": f"ERROR: {e}",
                "expected": item["answer"],
                "correctness": 0.0,
            })

    end_time = time.time()

    # Compute metrics
    correctness_scores = [s["correctness"] for s in samples]
    eval_metrics = {
        "eval/mean_correctness": (
            sum(correctness_scores) / len(correctness_scores)
            if correctness_scores else 0.0
        ),
        "eval/n_items": len(samples),
    }

    await self.evaluate_log(
        metrics=eval_metrics,
        samples=samples,
        start_time=start_time,
        end_time=end_time,
    )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. wandb_log — custom metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
    """Add averaged reward-breakdown metrics and forward them to the parent logger.

    If any rollouts were scored since the last call, the per-signal buffers
    are averaged into ``train/*`` keys of ``wandb_metrics`` (created when
    ``None``), two threshold rates are added, and all buffers are emptied.
    The metrics dict is then passed to ``super().wandb_log``.
    """
    metrics = {} if wandb_metrics is None else wandb_metrics

    rollouts = len(self._reward_buffer)
    if rollouts:
        def mean(values):
            # All buffers are averaged over the number of recorded rewards.
            return sum(values) / rollouts

        metrics["train/mean_reward"] = mean(self._reward_buffer)
        metrics["train/mean_correctness"] = mean(self._correctness_buffer)
        metrics["train/mean_tool_usage"] = mean(self._tool_usage_buffer)
        metrics["train/mean_efficiency"] = mean(self._efficiency_buffer)
        metrics["train/mean_diversity"] = mean(self._diversity_buffer)
        metrics["train/total_rollouts"] = rollouts

        # Accuracy buckets
        correct = sum(1 for score in self._correctness_buffer if score >= 0.7)
        used_tools = sum(1 for flag in self._tool_usage_buffer if flag > 0)
        metrics["train/correct_rate"] = correct / rollouts
        metrics["train/tool_usage_rate"] = used_tools / rollouts

        # Start the next logging window from scratch.
        for buffer in (
            self._reward_buffer,
            self._correctness_buffer,
            self._tool_usage_buffer,
            self._efficiency_buffer,
            self._diversity_buffer,
        ):
            buffer.clear()

    await super().wandb_log(metrics)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _llm_judge(
    self,
    question: str,
    expected: str,
    model_answer: str,
) -> float:
    """Grade ``model_answer`` against ``expected`` with the server's LLM.

    An empty or whitespace-only answer scores 0.0 without calling the model.
    Otherwise the judge is asked for a JSON verdict; if the call fails or no
    valid score can be parsed, the keyword-overlap heuristic is used instead.

    Returns:
        A score as a float (0.0-1.0 when it comes from the parsed verdict).
    """
    if not (model_answer and model_answer.strip()):
        return 0.0

    header = "You are an impartial judge evaluating the quality of an AI research answer.\n\n"
    case = (
        f"Question: {question}\n\n"
        f"Reference answer: {expected}\n\n"
        f"Model answer: {model_answer}\n\n"
    )
    rubric = (
        "Score the model answer on a scale from 0.0 to 1.0 where:\n"
        " 1.0 = fully correct and complete\n"
        " 0.7 = mostly correct with minor gaps\n"
        " 0.4 = partially correct\n"
        " 0.1 = mentions relevant topic but wrong or very incomplete\n"
        " 0.0 = completely wrong or no answer\n\n"
        "Consider: factual accuracy, completeness, and relevance.\n"
        'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
    )
    judge_prompt = header + case + rubric

    try:
        verdict = await self.server.chat_completion(
            messages=[{"role": "user", "content": judge_prompt}],
            n=1,
            max_tokens=150,
            temperature=0.0,
            split="eval",
        )
        if verdict.choices:
            raw = verdict.choices[0].message.content
        else:
            raw = ""
        score = self._parse_judge_json(raw)
        if score is not None:
            return float(score)
    except Exception as e:
        logger.debug(f"LLM judge failed: {e}. Using heuristic.")

    # Judge unavailable or unparseable: fall back to keyword overlap.
    return self._heuristic_score(expected, model_answer)
|
||||
|
||||
@staticmethod
def _parse_judge_json(text: str) -> Optional[float]:
    """Extract the judge's score from its reply.

    The reply is first parsed as JSON (after removing Markdown code fences)
    and the ``"score"`` field is read. If that fails, a regex looks for a
    ``"score": <number>`` fragment anywhere in the raw text.

    Args:
        text: raw judge reply; a non-string (e.g. ``None``) yields ``None``.

    Returns:
        The score if one was found and lies in [0.0, 1.0], otherwise ``None``.
        This function never raises on malformed input.
    """
    if not isinstance(text, str):
        return None

    try:
        clean = re.sub(r"```(?:json)?", "", text).strip()
        data = json.loads(clean)
        score = float(data.get("score", -1))
    except (ValueError, TypeError, AttributeError, OverflowError, RecursionError):
        # Not a clean JSON object with a numeric score: look for the
        # fragment inside surrounding prose instead.
        match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
        if match is None:
            return None
        try:
            score = float(match.group(1))
        except ValueError:
            # The capture can be something like "0.5.1" or "."
            return None

    return score if 0.0 <= score <= 1.0 else None
|
||||
|
||||
@staticmethod
def _heuristic_score(expected: str, model_answer: str) -> float:
    """Score an answer by keyword overlap with the reference (fallback grader).

    Both texts are lower-cased and reduced to the set of words longer than
    two characters that are not common stopwords. The result blends Jaccard
    similarity (weight 0.4) with recall of the reference words (weight 0.6),
    capped at 1.0. If the reference has no usable words, 0.5 is returned.
    """
    ignored = frozenset((
        "the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
        "at", "to", "for", "with", "and", "or", "but", "it", "its",
        "this", "that", "as", "by", "from", "be", "has", "have", "had",
    ))

    def content_words(text: str) -> set:
        words = re.findall(r'\b\w+\b', text.lower())
        return {w for w in words if len(w) > 2 and w not in ignored}

    reference = content_words(expected)
    candidate = content_words(model_answer)

    if not reference:
        # Nothing to compare against: return a neutral score.
        return 0.5

    shared = len(reference & candidate)
    combined = len(reference | candidate)

    jaccard = shared / combined if combined > 0 else 0.0
    recall = shared / len(reference)
    return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
||||
|
||||
@staticmethod
def _extract_domains(text: str) -> set:
    """Collect the distinct domains of all http(s) URLs found in ``text``.

    Domains are lower-cased and a single leading ``www.`` prefix is removed
    so that ``www.example.com`` and ``example.com`` count as one source.
    Sentence punctuation trailing a URL is ignored, and URLs that cannot be
    parsed are skipped.

    Returns:
        A set of domain strings (empty when no URL is present).
    """
    domains = set()
    for url in re.findall(r'https?://[^\s\)>\]"\']+', text):
        try:
            # Drop punctuation that belongs to the sentence, not the URL.
            netloc = urlparse(url.rstrip(".,;:!?")).netloc.lower()
        except ValueError:
            # urlparse rejects e.g. a truncated IPv6 literal; skip that URL.
            continue
        # str.lstrip("www.") would strip any leading 'w'/'.' characters
        # ("web.archive.org" -> "eb.archive.org"); remove the prefix only.
        if netloc.startswith("www."):
            netloc = netloc[len("www."):]
        if netloc:
            domains.add(netloc)
    return domains
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
WebResearchEnv.cli()
|
||||
@@ -9,18 +9,19 @@ to various messaging platforms (Telegram, Discord, WhatsApp) with:
|
||||
- Platform-specific toolsets (different capabilities per platform)
|
||||
"""
|
||||
|
||||
from .config import GatewayConfig, HomeChannel, PlatformConfig, SessionResetPolicy, load_gateway_config
|
||||
from .delivery import DeliveryRouter, DeliveryTarget
|
||||
from .config import GatewayConfig, PlatformConfig, HomeChannel, load_gateway_config
|
||||
from .session import (
|
||||
SessionContext,
|
||||
SessionStore,
|
||||
SessionResetPolicy,
|
||||
build_session_context_prompt,
|
||||
)
|
||||
from .delivery import DeliveryRouter, DeliveryTarget
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"GatewayConfig",
|
||||
"PlatformConfig",
|
||||
"PlatformConfig",
|
||||
"HomeChannel",
|
||||
"load_gateway_config",
|
||||
# Session
|
||||
|
||||
@@ -10,7 +10,7 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,8 +21,7 @@ DIRECTORY_PATH = Path.home() / ".hermes" / "channel_directory.json"
|
||||
# Build / refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_channel_directory(adapters: dict[Any, Any]) -> dict[str, Any]:
|
||||
def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Build a channel directory from connected platform adapters and session data.
|
||||
|
||||
@@ -30,7 +29,7 @@ def build_channel_directory(adapters: dict[Any, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
from gateway.config import Platform
|
||||
|
||||
platforms: dict[str, list[dict[str, str]]] = {}
|
||||
platforms: Dict[str, List[Dict[str, str]]] = {}
|
||||
|
||||
for platform, adapter in adapters.items():
|
||||
try:
|
||||
@@ -61,7 +60,7 @@ def build_channel_directory(adapters: dict[Any, Any]) -> dict[str, Any]:
|
||||
return directory
|
||||
|
||||
|
||||
def _build_discord(adapter) -> list[dict[str, str]]:
|
||||
def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
"""Enumerate all text channels the Discord bot can see."""
|
||||
channels = []
|
||||
client = getattr(adapter, "_client", None)
|
||||
@@ -75,14 +74,12 @@ def _build_discord(adapter) -> list[dict[str, str]]:
|
||||
|
||||
for guild in client.guilds:
|
||||
for ch in guild.text_channels:
|
||||
channels.append(
|
||||
{
|
||||
"id": str(ch.id),
|
||||
"name": ch.name,
|
||||
"guild": guild.name,
|
||||
"type": "channel",
|
||||
}
|
||||
)
|
||||
channels.append({
|
||||
"id": str(ch.id),
|
||||
"name": ch.name,
|
||||
"guild": guild.name,
|
||||
"type": "channel",
|
||||
})
|
||||
# Listing DM-capable users we've interacted with is not
|
||||
# feasible via guild enumeration; those come from sessions.
|
||||
|
||||
@@ -91,7 +88,7 @@ def _build_discord(adapter) -> list[dict[str, str]]:
|
||||
return channels
|
||||
|
||||
|
||||
def _build_slack(adapter) -> list[dict[str, str]]:
|
||||
def _build_slack(adapter) -> List[Dict[str, str]]:
|
||||
"""List Slack channels the bot has joined."""
|
||||
channels = []
|
||||
# Slack adapter may expose a web client
|
||||
@@ -100,6 +97,7 @@ def _build_slack(adapter) -> list[dict[str, str]]:
|
||||
return _build_from_sessions("slack")
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
from tools.send_message_tool import _send_slack # noqa: F401
|
||||
# Use the Slack Web API directly if available
|
||||
except Exception:
|
||||
@@ -109,7 +107,7 @@ def _build_slack(adapter) -> list[dict[str, str]]:
|
||||
return _build_from_sessions("slack")
|
||||
|
||||
|
||||
def _build_from_sessions(platform_name: str) -> list[dict[str, str]]:
|
||||
def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
|
||||
"""Pull known channels/contacts from sessions.json origin data."""
|
||||
sessions_path = Path.home() / ".hermes" / "sessions" / "sessions.json"
|
||||
if not sessions_path.exists():
|
||||
@@ -129,13 +127,11 @@ def _build_from_sessions(platform_name: str) -> list[dict[str, str]]:
|
||||
if not chat_id or chat_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(chat_id)
|
||||
entries.append(
|
||||
{
|
||||
"id": str(chat_id),
|
||||
"name": origin.get("chat_name") or origin.get("user_name") or str(chat_id),
|
||||
"type": session.get("chat_type", "dm"),
|
||||
}
|
||||
)
|
||||
entries.append({
|
||||
"id": str(chat_id),
|
||||
"name": origin.get("chat_name") or origin.get("user_name") or str(chat_id),
|
||||
"type": session.get("chat_type", "dm"),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug("Channel directory: failed to read sessions for %s: %s", platform_name, e)
|
||||
|
||||
@@ -146,8 +142,7 @@ def _build_from_sessions(platform_name: str) -> list[dict[str, str]]:
|
||||
# Read / resolve
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_directory() -> dict[str, Any]:
|
||||
def load_directory() -> Dict[str, Any]:
|
||||
"""Load the cached channel directory from disk."""
|
||||
if not DIRECTORY_PATH.exists():
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
@@ -158,7 +153,7 @@ def load_directory() -> dict[str, Any]:
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
|
||||
|
||||
def resolve_channel_name(platform_name: str, name: str) -> str | None:
|
||||
def resolve_channel_name(platform_name: str, name: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve a human-friendly channel name to a numeric ID.
|
||||
|
||||
@@ -211,8 +206,8 @@ def format_directory_for_display() -> str:
|
||||
|
||||
# Group Discord channels by guild
|
||||
if plat_name == "discord":
|
||||
guilds: dict[str, list] = {}
|
||||
dms: list = []
|
||||
guilds: Dict[str, List] = {}
|
||||
dms: List = []
|
||||
for ch in channels:
|
||||
guild = ch.get("guild")
|
||||
if guild:
|
||||
|
||||
@@ -8,20 +8,19 @@ Handles loading and validating configuration for:
|
||||
- Delivery preferences
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Any
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Platform(Enum):
|
||||
"""Supported messaging platforms."""
|
||||
|
||||
LOCAL = "local"
|
||||
TELEGRAM = "telegram"
|
||||
DISCORD = "discord"
|
||||
@@ -35,24 +34,23 @@ class Platform(Enum):
|
||||
class HomeChannel:
|
||||
"""
|
||||
Default destination for a platform.
|
||||
|
||||
|
||||
When a cron job specifies deliver="telegram" without a specific chat ID,
|
||||
messages are sent to this home channel.
|
||||
"""
|
||||
|
||||
platform: Platform
|
||||
chat_id: str
|
||||
name: str # Human-readable name for display
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"platform": self.platform.value,
|
||||
"chat_id": self.chat_id,
|
||||
"name": self.name,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "HomeChannel":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "HomeChannel":
|
||||
return cls(
|
||||
platform=Platform(data["platform"]),
|
||||
chat_id=str(data["chat_id"]),
|
||||
@@ -64,27 +62,26 @@ class HomeChannel:
|
||||
class SessionResetPolicy:
|
||||
"""
|
||||
Controls when sessions reset (lose context).
|
||||
|
||||
|
||||
Modes:
|
||||
- "daily": Reset at a specific hour each day
|
||||
- "idle": Reset after N minutes of inactivity
|
||||
- "both": Whichever triggers first (daily boundary OR idle timeout)
|
||||
- "none": Never auto-reset (context managed only by compression)
|
||||
"""
|
||||
|
||||
mode: str = "both" # "daily", "idle", "both", or "none"
|
||||
at_hour: int = 4 # Hour for daily reset (0-23, local time)
|
||||
idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"mode": self.mode,
|
||||
"at_hour": self.at_hour,
|
||||
"idle_minutes": self.idle_minutes,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionResetPolicy":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionResetPolicy":
|
||||
return cls(
|
||||
mode=data.get("mode", "both"),
|
||||
at_hour=data.get("at_hour", 4),
|
||||
@@ -95,16 +92,15 @@ class SessionResetPolicy:
|
||||
@dataclass
|
||||
class PlatformConfig:
|
||||
"""Configuration for a single messaging platform."""
|
||||
|
||||
enabled: bool = False
|
||||
token: str | None = None # Bot token (Telegram, Discord)
|
||||
api_key: str | None = None # API key if different from token
|
||||
home_channel: HomeChannel | None = None
|
||||
|
||||
token: Optional[str] = None # Bot token (Telegram, Discord)
|
||||
api_key: Optional[str] = None # API key if different from token
|
||||
home_channel: Optional[HomeChannel] = None
|
||||
|
||||
# Platform-specific settings
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"enabled": self.enabled,
|
||||
"extra": self.extra,
|
||||
@@ -116,13 +112,13 @@ class PlatformConfig:
|
||||
if self.home_channel:
|
||||
result["home_channel"] = self.home_channel.to_dict()
|
||||
return result
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "PlatformConfig":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "PlatformConfig":
|
||||
home_channel = None
|
||||
if "home_channel" in data:
|
||||
home_channel = HomeChannel.from_dict(data["home_channel"])
|
||||
|
||||
|
||||
return cls(
|
||||
enabled=data.get("enabled", False),
|
||||
token=data.get("token"),
|
||||
@@ -136,80 +132,89 @@ class PlatformConfig:
|
||||
class GatewayConfig:
|
||||
"""
|
||||
Main gateway configuration.
|
||||
|
||||
|
||||
Manages all platform connections, session policies, and delivery settings.
|
||||
"""
|
||||
|
||||
# Platform configurations
|
||||
platforms: dict[Platform, PlatformConfig] = field(default_factory=dict)
|
||||
|
||||
platforms: Dict[Platform, PlatformConfig] = field(default_factory=dict)
|
||||
|
||||
# Session reset policies by type
|
||||
default_reset_policy: SessionResetPolicy = field(default_factory=SessionResetPolicy)
|
||||
reset_by_type: dict[str, SessionResetPolicy] = field(default_factory=dict)
|
||||
reset_by_platform: dict[Platform, SessionResetPolicy] = field(default_factory=dict)
|
||||
|
||||
reset_by_type: Dict[str, SessionResetPolicy] = field(default_factory=dict)
|
||||
reset_by_platform: Dict[Platform, SessionResetPolicy] = field(default_factory=dict)
|
||||
|
||||
# Reset trigger commands
|
||||
reset_triggers: list[str] = field(default_factory=lambda: ["/new", "/reset"])
|
||||
|
||||
reset_triggers: List[str] = field(default_factory=lambda: ["/new", "/reset"])
|
||||
|
||||
# Storage paths
|
||||
sessions_dir: Path = field(default_factory=lambda: Path.home() / ".hermes" / "sessions")
|
||||
|
||||
|
||||
# Delivery settings
|
||||
always_log_local: bool = True # Always save cron outputs to local files
|
||||
|
||||
def get_connected_platforms(self) -> list[Platform]:
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
for platform, config in self.platforms.items():
|
||||
if not config.enabled:
|
||||
continue
|
||||
# Platforms that use token/api_key auth
|
||||
if (
|
||||
config.token
|
||||
or config.api_key
|
||||
or platform == Platform.WHATSAPP
|
||||
or platform == Platform.SIGNAL
|
||||
and config.extra.get("http_url")
|
||||
):
|
||||
if config.token or config.api_key:
|
||||
connected.append(platform)
|
||||
# WhatsApp uses enabled flag only (bridge handles auth)
|
||||
elif platform == Platform.WHATSAPP:
|
||||
connected.append(platform)
|
||||
# Signal uses extra dict for config (http_url + account)
|
||||
elif platform == Platform.SIGNAL and config.extra.get("http_url"):
|
||||
connected.append(platform)
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> HomeChannel | None:
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
|
||||
"""Get the home channel for a platform."""
|
||||
config = self.platforms.get(platform)
|
||||
if config:
|
||||
return config.home_channel
|
||||
return None
|
||||
|
||||
def get_reset_policy(self, platform: Platform | None = None, session_type: str | None = None) -> SessionResetPolicy:
|
||||
|
||||
def get_reset_policy(
|
||||
self,
|
||||
platform: Optional[Platform] = None,
|
||||
session_type: Optional[str] = None
|
||||
) -> SessionResetPolicy:
|
||||
"""
|
||||
Get the appropriate reset policy for a session.
|
||||
|
||||
|
||||
Priority: platform override > type override > default
|
||||
"""
|
||||
# Platform-specific override takes precedence
|
||||
if platform and platform in self.reset_by_platform:
|
||||
return self.reset_by_platform[platform]
|
||||
|
||||
|
||||
# Type-specific override (dm, group, thread)
|
||||
if session_type and session_type in self.reset_by_type:
|
||||
return self.reset_by_type[session_type]
|
||||
|
||||
|
||||
return self.default_reset_policy
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"platforms": {p.value: c.to_dict() for p, c in self.platforms.items()},
|
||||
"platforms": {
|
||||
p.value: c.to_dict() for p, c in self.platforms.items()
|
||||
},
|
||||
"default_reset_policy": self.default_reset_policy.to_dict(),
|
||||
"reset_by_type": {k: v.to_dict() for k, v in self.reset_by_type.items()},
|
||||
"reset_by_platform": {p.value: v.to_dict() for p, v in self.reset_by_platform.items()},
|
||||
"reset_by_type": {
|
||||
k: v.to_dict() for k, v in self.reset_by_type.items()
|
||||
},
|
||||
"reset_by_platform": {
|
||||
p.value: v.to_dict() for p, v in self.reset_by_platform.items()
|
||||
},
|
||||
"reset_triggers": self.reset_triggers,
|
||||
"sessions_dir": str(self.sessions_dir),
|
||||
"always_log_local": self.always_log_local,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "GatewayConfig":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GatewayConfig":
|
||||
platforms = {}
|
||||
for platform_name, platform_data in data.get("platforms", {}).items():
|
||||
try:
|
||||
@@ -217,11 +222,11 @@ class GatewayConfig:
|
||||
platforms[platform] = PlatformConfig.from_dict(platform_data)
|
||||
except ValueError:
|
||||
pass # Skip unknown platforms
|
||||
|
||||
|
||||
reset_by_type = {}
|
||||
for type_name, policy_data in data.get("reset_by_type", {}).items():
|
||||
reset_by_type[type_name] = SessionResetPolicy.from_dict(policy_data)
|
||||
|
||||
|
||||
reset_by_platform = {}
|
||||
for platform_name, policy_data in data.get("reset_by_platform", {}).items():
|
||||
try:
|
||||
@@ -229,15 +234,15 @@ class GatewayConfig:
|
||||
reset_by_platform[platform] = SessionResetPolicy.from_dict(policy_data)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
default_policy = SessionResetPolicy()
|
||||
if "default_reset_policy" in data:
|
||||
default_policy = SessionResetPolicy.from_dict(data["default_reset_policy"])
|
||||
|
||||
|
||||
sessions_dir = Path.home() / ".hermes" / "sessions"
|
||||
if "sessions_dir" in data:
|
||||
sessions_dir = Path(data["sessions_dir"])
|
||||
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
@@ -252,7 +257,7 @@ class GatewayConfig:
|
||||
def load_gateway_config() -> GatewayConfig:
|
||||
"""
|
||||
Load gateway configuration from multiple sources.
|
||||
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. Environment variables
|
||||
2. ~/.hermes/gateway.json
|
||||
@@ -260,23 +265,22 @@ def load_gateway_config() -> GatewayConfig:
|
||||
4. Defaults
|
||||
"""
|
||||
config = GatewayConfig()
|
||||
|
||||
|
||||
# Try loading from ~/.hermes/gateway.json
|
||||
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
|
||||
if gateway_config_path.exists():
|
||||
try:
|
||||
with open(gateway_config_path) as f:
|
||||
with open(gateway_config_path, "r") as f:
|
||||
data = json.load(f)
|
||||
config = GatewayConfig.from_dict(data)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load {gateway_config_path}: {e}")
|
||||
|
||||
|
||||
# Bridge session_reset from config.yaml (the user-facing config file)
|
||||
# into the gateway config. config.yaml takes precedence over gateway.json
|
||||
# for session reset policy since that's where hermes setup writes it.
|
||||
try:
|
||||
import yaml
|
||||
|
||||
config_yaml_path = Path.home() / ".hermes" / "config.yaml"
|
||||
if config_yaml_path.exists():
|
||||
with open(config_yaml_path) as f:
|
||||
@@ -289,12 +293,14 @@ def load_gateway_config() -> GatewayConfig:
|
||||
|
||||
# Override with environment variables
|
||||
_apply_env_overrides(config)
|
||||
|
||||
|
||||
# --- Validate loaded values ---
|
||||
policy = config.default_reset_policy
|
||||
|
||||
if not (0 <= policy.at_hour <= 23):
|
||||
logger.warning("Invalid at_hour=%s (must be 0-23). Using default 4.", policy.at_hour)
|
||||
logger.warning(
|
||||
"Invalid at_hour=%s (must be 0-23). Using default 4.", policy.at_hour
|
||||
)
|
||||
policy.at_hour = 4
|
||||
|
||||
if policy.idle_minutes is None or policy.idle_minutes <= 0:
|
||||
@@ -317,9 +323,9 @@ def load_gateway_config() -> GatewayConfig:
|
||||
env_name = _token_env_names.get(platform)
|
||||
if env_name and pconfig.token is not None and not pconfig.token.strip():
|
||||
logger.warning(
|
||||
"%s is enabled but %s is empty. The adapter will likely fail to connect.",
|
||||
platform.value,
|
||||
env_name,
|
||||
"%s is enabled but %s is empty. "
|
||||
"The adapter will likely fail to connect.",
|
||||
platform.value, env_name,
|
||||
)
|
||||
|
||||
return config
|
||||
@@ -327,7 +333,7 @@ def load_gateway_config() -> GatewayConfig:
|
||||
|
||||
def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"""Apply environment variable overrides to config."""
|
||||
|
||||
|
||||
# Telegram
|
||||
telegram_token = os.getenv("TELEGRAM_BOT_TOKEN")
|
||||
if telegram_token:
|
||||
@@ -335,7 +341,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].enabled = True
|
||||
config.platforms[Platform.TELEGRAM].token = telegram_token
|
||||
|
||||
|
||||
telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL")
|
||||
if telegram_home and Platform.TELEGRAM in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM].home_channel = HomeChannel(
|
||||
@@ -343,7 +349,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
chat_id=telegram_home,
|
||||
name=os.getenv("TELEGRAM_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
|
||||
# Discord
|
||||
discord_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||
if discord_token:
|
||||
@@ -351,7 +357,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.DISCORD] = PlatformConfig()
|
||||
config.platforms[Platform.DISCORD].enabled = True
|
||||
config.platforms[Platform.DISCORD].token = discord_token
|
||||
|
||||
|
||||
discord_home = os.getenv("DISCORD_HOME_CHANNEL")
|
||||
if discord_home and Platform.DISCORD in config.platforms:
|
||||
config.platforms[Platform.DISCORD].home_channel = HomeChannel(
|
||||
@@ -359,14 +365,14 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
chat_id=discord_home,
|
||||
name=os.getenv("DISCORD_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
|
||||
# WhatsApp (typically uses different auth mechanism)
|
||||
whatsapp_enabled = os.getenv("WHATSAPP_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
if whatsapp_enabled:
|
||||
if Platform.WHATSAPP not in config.platforms:
|
||||
config.platforms[Platform.WHATSAPP] = PlatformConfig()
|
||||
config.platforms[Platform.WHATSAPP].enabled = True
|
||||
|
||||
|
||||
# Slack
|
||||
slack_token = os.getenv("SLACK_BOT_TOKEN")
|
||||
if slack_token:
|
||||
@@ -382,7 +388,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
chat_id=slack_home,
|
||||
name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""),
|
||||
)
|
||||
|
||||
|
||||
# Signal
|
||||
signal_url = os.getenv("SIGNAL_HTTP_URL")
|
||||
signal_account = os.getenv("SIGNAL_ACCOUNT")
|
||||
@@ -390,13 +396,11 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
if Platform.SIGNAL not in config.platforms:
|
||||
config.platforms[Platform.SIGNAL] = PlatformConfig()
|
||||
config.platforms[Platform.SIGNAL].enabled = True
|
||||
config.platforms[Platform.SIGNAL].extra.update(
|
||||
{
|
||||
"http_url": signal_url,
|
||||
"account": signal_account,
|
||||
"ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"),
|
||||
}
|
||||
)
|
||||
config.platforms[Platform.SIGNAL].extra.update({
|
||||
"http_url": signal_url,
|
||||
"account": signal_account,
|
||||
"ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"),
|
||||
})
|
||||
signal_home = os.getenv("SIGNAL_HOME_CHANNEL")
|
||||
if signal_home:
|
||||
config.platforms[Platform.SIGNAL].home_channel = HomeChannel(
|
||||
@@ -423,7 +427,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.default_reset_policy.idle_minutes = int(idle_minutes)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
reset_hour = os.getenv("SESSION_RESET_HOUR")
|
||||
if reset_hour:
|
||||
try:
|
||||
@@ -436,6 +440,6 @@ def save_gateway_config(config: GatewayConfig) -> None:
|
||||
"""Save gateway configuration to ~/.hermes/gateway.json."""
|
||||
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
|
||||
gateway_config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
with open(gateway_config_path, "w") as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
|
||||
@@ -9,17 +9,18 @@ Routes messages to the appropriate destination based on:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_PLATFORM_OUTPUT = 4000
|
||||
TRUNCATED_VISIBLE = 3800
|
||||
|
||||
from .config import GatewayConfig, Platform
|
||||
from .config import Platform, GatewayConfig
|
||||
from .session import SessionSource
|
||||
|
||||
|
||||
@@ -27,24 +28,23 @@ from .session import SessionSource
|
||||
class DeliveryTarget:
|
||||
"""
|
||||
A single delivery target.
|
||||
|
||||
|
||||
Represents where a message should be sent:
|
||||
- "origin" → back to source
|
||||
- "local" → save to local files
|
||||
- "telegram" → Telegram home channel
|
||||
- "telegram:123456" → specific Telegram chat
|
||||
"""
|
||||
|
||||
platform: Platform
|
||||
chat_id: str | None = None # None means use home channel
|
||||
chat_id: Optional[str] = None # None means use home channel
|
||||
is_origin: bool = False
|
||||
is_explicit: bool = False # True if chat_id was explicitly specified
|
||||
|
||||
|
||||
@classmethod
|
||||
def parse(cls, target: str, origin: SessionSource | None = None) -> "DeliveryTarget":
|
||||
def parse(cls, target: str, origin: Optional[SessionSource] = None) -> "DeliveryTarget":
|
||||
"""
|
||||
Parse a delivery target string.
|
||||
|
||||
|
||||
Formats:
|
||||
- "origin" → back to source
|
||||
- "local" → local files only
|
||||
@@ -52,7 +52,7 @@ class DeliveryTarget:
|
||||
- "telegram:123456" → specific Telegram chat
|
||||
"""
|
||||
target = target.strip().lower()
|
||||
|
||||
|
||||
if target == "origin":
|
||||
if origin:
|
||||
return cls(
|
||||
@@ -63,10 +63,10 @@ class DeliveryTarget:
|
||||
else:
|
||||
# Fallback to local if no origin
|
||||
return cls(platform=Platform.LOCAL, is_origin=True)
|
||||
|
||||
|
||||
if target == "local":
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
|
||||
# Check for platform:chat_id format
|
||||
if ":" in target:
|
||||
platform_str, chat_id = target.split(":", 1)
|
||||
@@ -76,7 +76,7 @@ class DeliveryTarget:
|
||||
except ValueError:
|
||||
# Unknown platform, treat as local
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
|
||||
# Just a platform name (use home channel)
|
||||
try:
|
||||
platform = Platform(target)
|
||||
@@ -84,7 +84,7 @@ class DeliveryTarget:
|
||||
except ValueError:
|
||||
# Unknown platform, treat as local
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Convert back to string format."""
|
||||
if self.is_origin:
|
||||
@@ -99,15 +99,15 @@ class DeliveryTarget:
|
||||
class DeliveryRouter:
|
||||
"""
|
||||
Routes messages to appropriate destinations.
|
||||
|
||||
|
||||
Handles the logic of resolving delivery targets and dispatching
|
||||
messages to the right platform adapters.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GatewayConfig, adapters: dict[Platform, Any] = None):
|
||||
|
||||
def __init__(self, config: GatewayConfig, adapters: Dict[Platform, Any] = None):
|
||||
"""
|
||||
Initialize the delivery router.
|
||||
|
||||
|
||||
Args:
|
||||
config: Gateway configuration
|
||||
adapters: Dict mapping platforms to their adapter instances
|
||||
@@ -115,27 +115,31 @@ class DeliveryRouter:
|
||||
self.config = config
|
||||
self.adapters = adapters or {}
|
||||
self.output_dir = Path.home() / ".hermes" / "cron" / "output"
|
||||
|
||||
def resolve_targets(self, deliver: str | list[str], origin: SessionSource | None = None) -> list[DeliveryTarget]:
|
||||
|
||||
def resolve_targets(
|
||||
self,
|
||||
deliver: Union[str, List[str]],
|
||||
origin: Optional[SessionSource] = None
|
||||
) -> List[DeliveryTarget]:
|
||||
"""
|
||||
Resolve delivery specification to concrete targets.
|
||||
|
||||
|
||||
Args:
|
||||
deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc.
|
||||
origin: The source where the request originated (for "origin" target)
|
||||
|
||||
|
||||
Returns:
|
||||
List of resolved delivery targets
|
||||
"""
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
|
||||
|
||||
targets = []
|
||||
seen_platforms = set()
|
||||
|
||||
|
||||
for target_str in deliver:
|
||||
target = DeliveryTarget.parse(target_str, origin)
|
||||
|
||||
|
||||
# Resolve home channel if needed
|
||||
if target.chat_id is None and target.platform != Platform.LOCAL:
|
||||
home = self.config.get_home_channel(target.platform)
|
||||
@@ -144,96 +148,109 @@ class DeliveryRouter:
|
||||
else:
|
||||
# No home channel configured, skip this platform
|
||||
continue
|
||||
|
||||
|
||||
# Deduplicate
|
||||
key = (target.platform, target.chat_id)
|
||||
if key not in seen_platforms:
|
||||
seen_platforms.add(key)
|
||||
targets.append(target)
|
||||
|
||||
|
||||
# Always include local if configured
|
||||
if self.config.always_log_local:
|
||||
local_key = (Platform.LOCAL, None)
|
||||
if local_key not in seen_platforms:
|
||||
targets.append(DeliveryTarget(platform=Platform.LOCAL))
|
||||
|
||||
|
||||
return targets
|
||||
|
||||
|
||||
async def deliver(
|
||||
self,
|
||||
content: str,
|
||||
targets: list[DeliveryTarget],
|
||||
job_id: str | None = None,
|
||||
job_name: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
targets: List[DeliveryTarget],
|
||||
job_id: Optional[str] = None,
|
||||
job_name: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Deliver content to all specified targets.
|
||||
|
||||
|
||||
Args:
|
||||
content: The message/output to deliver
|
||||
targets: List of delivery targets
|
||||
job_id: Optional job ID (for cron jobs)
|
||||
job_name: Optional job name
|
||||
metadata: Additional metadata to include
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with delivery results per target
|
||||
"""
|
||||
results = {}
|
||||
|
||||
|
||||
for target in targets:
|
||||
try:
|
||||
if target.platform == Platform.LOCAL:
|
||||
result = self._deliver_local(content, job_id, job_name, metadata)
|
||||
else:
|
||||
result = await self._deliver_to_platform(target, content, metadata)
|
||||
|
||||
results[target.to_string()] = {"success": True, "result": result}
|
||||
|
||||
results[target.to_string()] = {
|
||||
"success": True,
|
||||
"result": result
|
||||
}
|
||||
except Exception as e:
|
||||
results[target.to_string()] = {"success": False, "error": str(e)}
|
||||
|
||||
results[target.to_string()] = {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _deliver_local(
|
||||
self, content: str, job_id: str | None, job_name: str | None, metadata: dict[str, Any] | None
|
||||
) -> dict[str, Any]:
|
||||
self,
|
||||
content: str,
|
||||
job_id: Optional[str],
|
||||
job_name: Optional[str],
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Save content to local files."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
if job_id:
|
||||
output_path = self.output_dir / job_id / f"{timestamp}.md"
|
||||
else:
|
||||
output_path = self.output_dir / "misc" / f"{timestamp}.md"
|
||||
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Build the output document
|
||||
lines = []
|
||||
if job_name:
|
||||
lines.append(f"# {job_name}")
|
||||
else:
|
||||
lines.append("# Delivery Output")
|
||||
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
|
||||
if job_id:
|
||||
lines.append(f"**Job ID:** {job_id}")
|
||||
|
||||
|
||||
if metadata:
|
||||
for key, value in metadata.items():
|
||||
lines.append(f"**{key}:** {value}")
|
||||
|
||||
|
||||
lines.append("")
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
lines.append(content)
|
||||
|
||||
|
||||
output_path.write_text("\n".join(lines))
|
||||
|
||||
return {"path": str(output_path), "timestamp": timestamp}
|
||||
|
||||
|
||||
return {
|
||||
"path": str(output_path),
|
||||
"timestamp": timestamp
|
||||
}
|
||||
|
||||
def _save_full_output(self, content: str, job_id: str) -> Path:
|
||||
"""Save full cron output to disk and return the file path."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
@@ -244,33 +261,41 @@ class DeliveryRouter:
|
||||
return path
|
||||
|
||||
async def _deliver_to_platform(
|
||||
self, target: DeliveryTarget, content: str, metadata: dict[str, Any] | None
|
||||
) -> dict[str, Any]:
|
||||
self,
|
||||
target: DeliveryTarget,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Deliver content to a messaging platform."""
|
||||
adapter = self.adapters.get(target.platform)
|
||||
|
||||
|
||||
if not adapter:
|
||||
raise ValueError(f"No adapter configured for {target.platform.value}")
|
||||
|
||||
|
||||
if not target.chat_id:
|
||||
raise ValueError(f"No chat ID for {target.platform.value} delivery")
|
||||
|
||||
|
||||
# Guard: truncate oversized cron output to stay within platform limits
|
||||
if len(content) > MAX_PLATFORM_OUTPUT:
|
||||
job_id = (metadata or {}).get("job_id", "unknown")
|
||||
saved_path = self._save_full_output(content, job_id)
|
||||
logger.info("Cron output truncated (%d chars) — full output: %s", len(content), saved_path)
|
||||
content = content[:TRUNCATED_VISIBLE] + f"\n\n... [truncated, full output saved to {saved_path}]"
|
||||
|
||||
content = (
|
||||
content[:TRUNCATED_VISIBLE]
|
||||
+ f"\n\n... [truncated, full output saved to {saved_path}]"
|
||||
)
|
||||
|
||||
return await adapter.send(target.chat_id, content, metadata=metadata)
|
||||
|
||||
|
||||
def parse_deliver_spec(
|
||||
deliver: str | list[str] | None, origin: SessionSource | None = None, default: str = "origin"
|
||||
) -> str | list[str]:
|
||||
deliver: Optional[Union[str, List[str]]],
|
||||
origin: Optional[SessionSource] = None,
|
||||
default: str = "origin"
|
||||
) -> Union[str, List[str]]:
|
||||
"""
|
||||
Normalize a delivery specification.
|
||||
|
||||
|
||||
If None or empty, returns the default.
|
||||
"""
|
||||
if not deliver:
|
||||
@@ -278,14 +303,17 @@ def parse_deliver_spec(
|
||||
return deliver
|
||||
|
||||
|
||||
def build_delivery_context_for_tool(config: GatewayConfig, origin: SessionSource | None = None) -> dict[str, Any]:
|
||||
def build_delivery_context_for_tool(
|
||||
config: GatewayConfig,
|
||||
origin: Optional[SessionSource] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build context for the schedule_cronjob tool to understand delivery options.
|
||||
|
||||
|
||||
This is passed to the tool so it can validate and explain delivery targets.
|
||||
"""
|
||||
connected = config.get_connected_platforms()
|
||||
|
||||
|
||||
options = {
|
||||
"origin": {
|
||||
"description": "Back to where this job was created",
|
||||
@@ -294,9 +322,9 @@ def build_delivery_context_for_tool(config: GatewayConfig, origin: SessionSource
|
||||
"local": {
|
||||
"description": "Save to local files only",
|
||||
"available": True,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
for platform in connected:
|
||||
home = config.get_home_channel(platform)
|
||||
options[platform.value] = {
|
||||
@@ -304,7 +332,7 @@ def build_delivery_context_for_tool(config: GatewayConfig, origin: SessionSource
|
||||
"available": True,
|
||||
"home_channel": home.to_dict() if home else None,
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
"origin": origin.to_dict() if origin else None,
|
||||
"options": options,
|
||||
|
||||
@@ -21,12 +21,12 @@ Errors in hooks are caught and logged but never block the main pipeline.
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
HOOKS_DIR = Path(os.path.expanduser("~/.hermes/hooks"))
|
||||
|
||||
|
||||
@@ -42,11 +42,11 @@ class HookRegistry:
|
||||
|
||||
def __init__(self):
|
||||
# event_type -> [handler_fn, ...]
|
||||
self._handlers: dict[str, list[Callable]] = {}
|
||||
self._loaded_hooks: list[dict] = [] # metadata for listing
|
||||
self._handlers: Dict[str, List[Callable]] = {}
|
||||
self._loaded_hooks: List[dict] = [] # metadata for listing
|
||||
|
||||
@property
|
||||
def loaded_hooks(self) -> list[dict]:
|
||||
def loaded_hooks(self) -> List[dict]:
|
||||
"""Return metadata about all loaded hooks."""
|
||||
return list(self._loaded_hooks)
|
||||
|
||||
@@ -84,7 +84,9 @@ class HookRegistry:
|
||||
continue
|
||||
|
||||
# Dynamically load the handler module
|
||||
spec = importlib.util.spec_from_file_location(f"hermes_hook_{hook_name}", handler_path)
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
f"hermes_hook_{hook_name}", handler_path
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
print(f"[hooks] Skipping {hook_name}: could not load handler.py", flush=True)
|
||||
continue
|
||||
@@ -101,21 +103,19 @@ class HookRegistry:
|
||||
for event in events:
|
||||
self._handlers.setdefault(event, []).append(handle_fn)
|
||||
|
||||
self._loaded_hooks.append(
|
||||
{
|
||||
"name": hook_name,
|
||||
"description": manifest.get("description", ""),
|
||||
"events": events,
|
||||
"path": str(hook_dir),
|
||||
}
|
||||
)
|
||||
self._loaded_hooks.append({
|
||||
"name": hook_name,
|
||||
"description": manifest.get("description", ""),
|
||||
"events": events,
|
||||
"path": str(hook_dir),
|
||||
})
|
||||
|
||||
print(f"[hooks] Loaded hook '{hook_name}' for events: {events}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[hooks] Error loading hook {hook_dir.name}: {e}", flush=True)
|
||||
|
||||
async def emit(self, event_type: str, context: dict[str, Any] | None = None) -> None:
|
||||
async def emit(self, event_type: str, context: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""
|
||||
Fire all handlers registered for an event.
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,7 +61,7 @@ def mirror_to_session(
|
||||
return False
|
||||
|
||||
|
||||
def _find_session_id(platform: str, chat_id: str) -> str | None:
|
||||
def _find_session_id(platform: str, chat_id: str) -> Optional[str]:
|
||||
"""
|
||||
Find the active session_id for a platform + chat_id pair.
|
||||
|
||||
@@ -112,7 +113,6 @@ def _append_to_sqlite(session_id: str, message: dict) -> None:
|
||||
"""Append a message to the SQLite session database."""
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = SessionDB()
|
||||
db.append_message(
|
||||
session_id=session_id,
|
||||
|
||||
@@ -23,19 +23,21 @@ import os
|
||||
import secrets
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Unambiguous alphabet -- excludes 0/O, 1/I to prevent confusion
|
||||
ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
CODE_LENGTH = 8
|
||||
|
||||
# Timing constants
|
||||
CODE_TTL_SECONDS = 3600 # Codes expire after 1 hour
|
||||
RATE_LIMIT_SECONDS = 600 # 1 request per user per 10 minutes
|
||||
LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures
|
||||
CODE_TTL_SECONDS = 3600 # Codes expire after 1 hour
|
||||
RATE_LIMIT_SECONDS = 600 # 1 request per user per 10 minutes
|
||||
LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures
|
||||
|
||||
# Limits
|
||||
MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform
|
||||
MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout
|
||||
MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform
|
||||
MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout
|
||||
|
||||
PAIRING_DIR = Path(os.path.expanduser("~/.hermes/pairing"))
|
||||
|
||||
@@ -121,7 +123,9 @@ class PairingStore:
|
||||
|
||||
# ----- Pending codes -----
|
||||
|
||||
def generate_code(self, platform: str, user_id: str, user_name: str = "") -> str | None:
|
||||
def generate_code(
|
||||
self, platform: str, user_id: str, user_name: str = ""
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Generate a pairing code for a new user.
|
||||
|
||||
@@ -161,7 +165,7 @@ class PairingStore:
|
||||
|
||||
return code
|
||||
|
||||
def approve_code(self, platform: str, code: str) -> dict | None:
|
||||
def approve_code(self, platform: str, code: str) -> Optional[dict]:
|
||||
"""
|
||||
Approve a pairing code. Adds the user to the approved list.
|
||||
|
||||
@@ -195,15 +199,13 @@ class PairingStore:
|
||||
pending = self._load_json(self._pending_path(p))
|
||||
for code, info in pending.items():
|
||||
age_min = int((time.time() - info["created_at"]) / 60)
|
||||
results.append(
|
||||
{
|
||||
"platform": p,
|
||||
"code": code,
|
||||
"user_id": info["user_id"],
|
||||
"user_name": info.get("user_name", ""),
|
||||
"age_minutes": age_min,
|
||||
}
|
||||
)
|
||||
results.append({
|
||||
"platform": p,
|
||||
"code": code,
|
||||
"user_id": info["user_id"],
|
||||
"user_name": info.get("user_name", ""),
|
||||
"age_minutes": age_min,
|
||||
})
|
||||
return results
|
||||
|
||||
def clear_pending(self, platform: str = None) -> int:
|
||||
@@ -249,11 +251,8 @@ class PairingStore:
|
||||
lockout_key = f"_lockout:{platform}"
|
||||
limits[lockout_key] = time.time() + LOCKOUT_SECONDS
|
||||
limits[fail_key] = 0 # Reset counter
|
||||
print(
|
||||
f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s "
|
||||
f"after {MAX_FAILED_ATTEMPTS} failed attempts",
|
||||
flush=True,
|
||||
)
|
||||
print(f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s "
|
||||
f"after {MAX_FAILED_ATTEMPTS} failed attempts", flush=True)
|
||||
self._save_json(self._rate_limit_path(), limits)
|
||||
|
||||
# ----- Cleanup -----
|
||||
@@ -263,7 +262,10 @@ class PairingStore:
|
||||
path = self._pending_path(platform)
|
||||
pending = self._load_json(path)
|
||||
now = time.time()
|
||||
expired = [code for code, info in pending.items() if (now - info["created_at"]) > CODE_TTL_SECONDS]
|
||||
expired = [
|
||||
code for code, info in pending.items()
|
||||
if (now - info["created_at"]) > CODE_TTL_SECONDS
|
||||
]
|
||||
if expired:
|
||||
for code in expired:
|
||||
del pending[code]
|
||||
|
||||
@@ -303,8 +303,8 @@ Optional but valuable:
|
||||
After implementing everything, verify with:
|
||||
|
||||
```bash
|
||||
# All checks pass (lint + test)
|
||||
make check
|
||||
# All tests pass
|
||||
python -m pytest tests/ -q
|
||||
|
||||
# Grep for your platform name to find any missed integration points
|
||||
grep -r "telegram\|discord\|whatsapp\|slack" gateway/ tools/ agent/ cron/ hermes_cli/ toolsets.py \
|
||||
|
||||
@@ -13,20 +13,20 @@ import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from pathlib import Path as _Path
|
||||
from typing import Any
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image cache utilities
|
||||
#
|
||||
@@ -251,7 +251,6 @@ def cleanup_document_cache(max_age_hours: int = 24) -> int:
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Types of incoming messages."""
|
||||
|
||||
TEXT = "text"
|
||||
LOCATION = "location"
|
||||
PHOTO = "photo"
|
||||
@@ -267,43 +266,42 @@ class MessageType(Enum):
|
||||
class MessageEvent:
|
||||
"""
|
||||
Incoming message from a platform.
|
||||
|
||||
|
||||
Normalized representation that all adapters produce.
|
||||
"""
|
||||
|
||||
# Message content
|
||||
text: str
|
||||
message_type: MessageType = MessageType.TEXT
|
||||
|
||||
|
||||
# Source information
|
||||
source: SessionSource = None
|
||||
|
||||
|
||||
# Original platform data
|
||||
raw_message: Any = None
|
||||
message_id: str | None = None
|
||||
|
||||
message_id: Optional[str] = None
|
||||
|
||||
# Media attachments
|
||||
media_urls: list[str] = field(default_factory=list)
|
||||
media_types: list[str] = field(default_factory=list)
|
||||
|
||||
media_urls: List[str] = field(default_factory=list)
|
||||
media_types: List[str] = field(default_factory=list)
|
||||
|
||||
# Reply context
|
||||
reply_to_message_id: str | None = None
|
||||
|
||||
reply_to_message_id: Optional[str] = None
|
||||
|
||||
# Timestamps
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
def is_command(self) -> bool:
|
||||
"""Check if this is a command message (e.g., /new, /reset)."""
|
||||
return self.text.startswith("/")
|
||||
|
||||
def get_command(self) -> str | None:
|
||||
|
||||
def get_command(self) -> Optional[str]:
|
||||
"""Extract command name if this is a command message."""
|
||||
if not self.is_command():
|
||||
return None
|
||||
# Split on space and get first word, strip the /
|
||||
parts = self.text.split(maxsplit=1)
|
||||
return parts[0][1:].lower() if parts else None
|
||||
|
||||
|
||||
def get_command_args(self) -> str:
|
||||
"""Get the arguments after a command."""
|
||||
if not self.is_command():
|
||||
@@ -312,88 +310,91 @@ class MessageEvent:
|
||||
return parts[1] if len(parts) > 1 else ""
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class SendResult:
|
||||
"""Result of sending a message."""
|
||||
|
||||
success: bool
|
||||
message_id: str | None = None
|
||||
error: str | None = None
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
raw_response: Any = None
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
MessageHandler = Callable[[MessageEvent], Awaitable[str | None]]
|
||||
MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]]
|
||||
|
||||
|
||||
class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
Base class for platform adapters.
|
||||
|
||||
|
||||
Subclasses implement platform-specific logic for:
|
||||
- Connecting and authenticating
|
||||
- Receiving messages
|
||||
- Sending messages/responses
|
||||
- Handling media
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, config: PlatformConfig, platform: Platform):
|
||||
self.config = config
|
||||
self.platform = platform
|
||||
self._message_handler: MessageHandler | None = None
|
||||
self._message_handler: Optional[MessageHandler] = None
|
||||
self._running = False
|
||||
|
||||
|
||||
# Track active message handlers per session for interrupt support
|
||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: dict[str, MessageEvent] = {}
|
||||
|
||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable name for this adapter."""
|
||||
return self.platform.value.title()
|
||||
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if adapter is currently connected."""
|
||||
return self._running
|
||||
|
||||
|
||||
def set_message_handler(self, handler: MessageHandler) -> None:
|
||||
"""
|
||||
Set the handler for incoming messages.
|
||||
|
||||
|
||||
The handler receives a MessageEvent and should return
|
||||
an optional response string.
|
||||
"""
|
||||
self._message_handler = handler
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Connect to the platform and start receiving messages.
|
||||
|
||||
|
||||
Returns True if connection was successful.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the platform."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def send(
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a message to a chat.
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID to send to
|
||||
content: Message content (may be markdown)
|
||||
reply_to: Optional message ID to reply to
|
||||
metadata: Additional platform-specific options
|
||||
|
||||
|
||||
Returns:
|
||||
SendResult with success status and message ID
|
||||
"""
|
||||
@@ -415,21 +416,21 @@ class BasePlatformAdapter(ABC):
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""
|
||||
Send a typing indicator.
|
||||
|
||||
|
||||
Override in subclasses if the platform supports it.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an image natively via the platform API.
|
||||
|
||||
|
||||
Override in subclasses to send images as proper attachments
|
||||
instead of plain-text URLs. Default falls back to sending the
|
||||
URL as a text message.
|
||||
@@ -437,91 +438,87 @@ class BasePlatformAdapter(ABC):
|
||||
# Fallback: send URL as text (subclasses override for native images)
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an animated GIF natively via the platform API.
|
||||
|
||||
|
||||
Override in subclasses to send GIFs as proper animations
|
||||
(e.g., Telegram send_animation) so they auto-play inline.
|
||||
Default falls back to send_image.
|
||||
"""
|
||||
return await self.send_image(chat_id=chat_id, image_url=animation_url, caption=caption, reply_to=reply_to)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _is_animation_url(url: str) -> bool:
|
||||
"""Check if a URL points to an animated GIF (vs a static image)."""
|
||||
lower = url.lower().split("?")[0] # Strip query params
|
||||
return lower.endswith(".gif")
|
||||
lower = url.lower().split('?')[0] # Strip query params
|
||||
return lower.endswith('.gif')
|
||||
|
||||
@staticmethod
|
||||
def extract_images(content: str) -> tuple[list[tuple[str, str]], str]:
|
||||
def extract_images(content: str) -> Tuple[List[Tuple[str, str]], str]:
|
||||
"""
|
||||
Extract image URLs from markdown and HTML image tags in a response.
|
||||
|
||||
|
||||
Finds patterns like:
|
||||
- 
|
||||
- <img src="https://example.com/image.png">
|
||||
- <img src="https://example.com/image.png"></img>
|
||||
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (url, alt_text) pairs, cleaned content with image tags removed).
|
||||
"""
|
||||
images = []
|
||||
cleaned = content
|
||||
|
||||
|
||||
# Match markdown images: 
|
||||
md_pattern = r"!\[([^\]]*)\]\((https?://[^\s\)]+)\)"
|
||||
md_pattern = r'!\[([^\]]*)\]\((https?://[^\s\)]+)\)'
|
||||
for match in re.finditer(md_pattern, content):
|
||||
alt_text = match.group(1)
|
||||
url = match.group(2)
|
||||
# Only extract URLs that look like actual images
|
||||
if any(
|
||||
url.lower().endswith(ext) or ext in url.lower()
|
||||
for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp", "fal.media", "fal-cdn", "replicate.delivery"]
|
||||
):
|
||||
if any(url.lower().endswith(ext) or ext in url.lower() for ext in
|
||||
['.png', '.jpg', '.jpeg', '.gif', '.webp', 'fal.media', 'fal-cdn', 'replicate.delivery']):
|
||||
images.append((url, alt_text))
|
||||
|
||||
|
||||
# Match HTML img tags: <img src="url"> or <img src="url"></img> or <img src="url"/>
|
||||
html_pattern = r'<img\s+src=["\']?(https?://[^\s"\'<>]+)["\']?\s*/?>\s*(?:</img>)?'
|
||||
for match in re.finditer(html_pattern, content):
|
||||
url = match.group(1)
|
||||
images.append((url, ""))
|
||||
|
||||
|
||||
# Remove only the matched image tags from content (not all markdown images)
|
||||
if images:
|
||||
extracted_urls = {url for url, _ in images}
|
||||
|
||||
def _remove_if_extracted(match):
|
||||
url = match.group(2) if match.lastindex >= 2 else match.group(1)
|
||||
return "" if url in extracted_urls else match.group(0)
|
||||
|
||||
return '' if url in extracted_urls else match.group(0)
|
||||
cleaned = re.sub(md_pattern, _remove_if_extracted, cleaned)
|
||||
cleaned = re.sub(html_pattern, _remove_if_extracted, cleaned)
|
||||
# Clean up leftover blank lines
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
|
||||
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
return images, cleaned
|
||||
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an audio file as a native voice message via the platform API.
|
||||
|
||||
|
||||
Override in subclasses to send audio as voice bubbles (Telegram)
|
||||
or file attachments (Discord). Default falls back to sending the
|
||||
file path as text.
|
||||
@@ -535,8 +532,8 @@ class BasePlatformAdapter(ABC):
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a video natively via the platform API.
|
||||
@@ -553,9 +550,9 @@ class BasePlatformAdapter(ABC):
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a document/file natively via the platform API.
|
||||
@@ -572,8 +569,8 @@ class BasePlatformAdapter(ABC):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a local image file natively via the platform API.
|
||||
@@ -588,45 +585,45 @@ class BasePlatformAdapter(ABC):
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
@staticmethod
|
||||
def extract_media(content: str) -> tuple[list[tuple[str, bool]], str]:
|
||||
def extract_media(content: str) -> Tuple[List[Tuple[str, bool]], str]:
|
||||
"""
|
||||
Extract MEDIA:<path> tags and [[audio_as_voice]] directives from response text.
|
||||
|
||||
|
||||
The TTS tool returns responses like:
|
||||
[[audio_as_voice]]
|
||||
MEDIA:/path/to/audio.ogg
|
||||
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (path, is_voice) pairs, cleaned content with tags removed).
|
||||
"""
|
||||
media = []
|
||||
cleaned = content
|
||||
|
||||
|
||||
# Check for [[audio_as_voice]] directive
|
||||
has_voice_tag = "[[audio_as_voice]]" in content
|
||||
cleaned = cleaned.replace("[[audio_as_voice]]", "")
|
||||
|
||||
|
||||
# Extract MEDIA:<path> tags (path may contain spaces)
|
||||
media_pattern = r"MEDIA:(\S+)"
|
||||
media_pattern = r'MEDIA:(\S+)'
|
||||
for match in re.finditer(media_pattern, content):
|
||||
path = match.group(1).strip()
|
||||
if path:
|
||||
media.append((path, has_voice_tag))
|
||||
|
||||
|
||||
# Remove MEDIA tags from content
|
||||
if media:
|
||||
cleaned = re.sub(media_pattern, "", cleaned)
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
|
||||
|
||||
cleaned = re.sub(media_pattern, '', cleaned)
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
return media, cleaned
|
||||
|
||||
|
||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None:
|
||||
"""
|
||||
Continuously send typing indicator until cancelled.
|
||||
|
||||
|
||||
Telegram/Discord typing status expires after ~5 seconds, so we refresh every 2
|
||||
to recover quickly after progress messages interrupt it.
|
||||
"""
|
||||
@@ -636,20 +633,20 @@ class BasePlatformAdapter(ABC):
|
||||
await asyncio.sleep(interval)
|
||||
except asyncio.CancelledError:
|
||||
pass # Normal cancellation when handler completes
|
||||
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
|
||||
|
||||
This method returns quickly by spawning background tasks.
|
||||
This allows new messages to be processed even while an agent is running,
|
||||
enabling interruption support.
|
||||
"""
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
|
||||
session_key = event.source.chat_id
|
||||
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
# Store this as a pending message - it will interrupt the running agent
|
||||
@@ -658,10 +655,10 @@ class BasePlatformAdapter(ABC):
|
||||
# Signal the interrupt (the processing task checks this)
|
||||
self._active_sessions[session_key].set()
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
|
||||
# Spawn background task to process this message
|
||||
asyncio.create_task(self._process_message_background(event, session_key))
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_human_delay() -> float:
|
||||
"""
|
||||
@@ -688,40 +685,35 @@ class BasePlatformAdapter(ABC):
|
||||
# Create interrupt event for this session
|
||||
interrupt_event = asyncio.Event()
|
||||
self._active_sessions[session_key] = interrupt_event
|
||||
|
||||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
|
||||
|
||||
|
||||
try:
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
|
||||
|
||||
# Send response if any
|
||||
if not response:
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
|
||||
|
||||
# Extract image URLs and send them as native platform attachments
|
||||
images, text_content = self.extract_images(response)
|
||||
if images:
|
||||
logger.info(
|
||||
"[%s] extract_images found %d image(s) in response (%d chars)",
|
||||
self.name,
|
||||
len(images),
|
||||
len(response),
|
||||
)
|
||||
|
||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
if text_content:
|
||||
logger.info(
|
||||
"[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id
|
||||
)
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id, content=text_content, reply_to=event.message_id
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
reply_to=event.message_id
|
||||
)
|
||||
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
@@ -729,14 +721,14 @@ class BasePlatformAdapter(ABC):
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id,
|
||||
reply_to=event.message_id
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
|
||||
# Send extracted images as native attachments
|
||||
if images:
|
||||
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
|
||||
@@ -744,12 +736,7 @@ class BasePlatformAdapter(ABC):
|
||||
if human_delay > 0:
|
||||
await asyncio.sleep(human_delay)
|
||||
try:
|
||||
logger.info(
|
||||
"[%s] Sending image: %s (alt=%s)",
|
||||
self.name,
|
||||
image_url[:80],
|
||||
alt_text[:30] if alt_text else "",
|
||||
)
|
||||
logger.info("[%s] Sending image: %s (alt=%s)", self.name, image_url[:80], alt_text[:30] if alt_text else "")
|
||||
# Route animated GIFs through send_animation for proper playback
|
||||
if self._is_animation_url(image_url):
|
||||
img_result = await self.send_animation(
|
||||
@@ -767,11 +754,11 @@ class BasePlatformAdapter(ABC):
|
||||
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
|
||||
except Exception as img_err:
|
||||
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
|
||||
|
||||
|
||||
# Send extracted media files — route by file type
|
||||
_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"}
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
|
||||
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.3gp'}
|
||||
_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'}
|
||||
|
||||
for media_path, is_voice in media_files:
|
||||
if human_delay > 0:
|
||||
@@ -803,7 +790,7 @@ class BasePlatformAdapter(ABC):
|
||||
print(f"[{self.name}] Failed to send media ({ext}): {media_result.error}")
|
||||
except Exception as media_err:
|
||||
print(f"[{self.name}] Error sending media: {media_err}")
|
||||
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
pending_event = self._pending_messages.pop(session_key)
|
||||
@@ -819,11 +806,10 @@ class BasePlatformAdapter(ABC):
|
||||
# Process pending message in new background task
|
||||
await self._process_message_background(pending_event, session_key)
|
||||
return # Already cleaned up
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error handling message: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Stop typing indicator
|
||||
@@ -835,26 +821,26 @@ class BasePlatformAdapter(ABC):
|
||||
# Clean up session tracking
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
|
||||
def has_pending_interrupt(self, session_key: str) -> bool:
    """Report whether the interrupt event tracked for ``session_key`` is set.

    Returns False when the session has no entry in ``_active_sessions``.
    """
    if session_key not in self._active_sessions:
        return False
    return self._active_sessions[session_key].is_set()
|
||||
|
||||
def get_pending_message(self, session_key: str) -> Optional["MessageEvent"]:
    """Get and clear any pending message for a session.

    Pops the entry stored under ``session_key`` in ``_pending_messages``
    and returns it; returns None when nothing is stored for that session.
    """
    return self._pending_messages.pop(session_key, None)
|
||||
|
||||
|
||||
def build_source(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_name: str | None = None,
|
||||
chat_name: Optional[str] = None,
|
||||
chat_type: str = "dm",
|
||||
user_id: str | None = None,
|
||||
user_name: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
chat_topic: str | None = None,
|
||||
user_id_alt: str | None = None,
|
||||
chat_id_alt: str | None = None,
|
||||
user_id: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
chat_topic: Optional[str] = None,
|
||||
user_id_alt: Optional[str] = None,
|
||||
chat_id_alt: Optional[str] = None,
|
||||
) -> SessionSource:
|
||||
"""Helper to build a SessionSource for this platform."""
|
||||
# Normalize empty topic to None
|
||||
@@ -872,30 +858,30 @@ class BasePlatformAdapter(ABC):
|
||||
user_id_alt=user_id_alt,
|
||||
chat_id_alt=chat_id_alt,
|
||||
)
|
||||
|
||||
|
||||
@abstractmethod
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
    """
    Get information about a chat/channel.

    Abstract: this base declaration has no implementation of its own.

    Args:
        chat_id: Identifier of the chat/channel to describe.

    Returns dict with at least:
    - name: Chat name
    - type: "dm", "group", "channel"
    """
|
||||
|
||||
|
||||
def format_message(self, content: str) -> str:
    """
    Prepare message text for delivery on this platform.

    This base version is a pass-through: the text is handed back
    unmodified. Subclasses override it to apply platform-specific
    formatting (e.g., Telegram MarkdownV2, Discord markdown).
    """
    return content
|
||||
|
||||
def truncate_message(self, content: str, max_length: int = 4096) -> list[str]:
|
||||
|
||||
def truncate_message(self, content: str, max_length: int = 4096) -> List[str]:
|
||||
"""
|
||||
Split a long message into chunks, preserving code block boundaries.
|
||||
|
||||
@@ -914,14 +900,14 @@ class BasePlatformAdapter(ABC):
|
||||
if len(content) <= max_length:
|
||||
return [content]
|
||||
|
||||
INDICATOR_RESERVE = 10 # room for " (XX/XX)"
|
||||
INDICATOR_RESERVE = 10 # room for " (XX/XX)"
|
||||
FENCE_CLOSE = "\n```"
|
||||
|
||||
chunks: list[str] = []
|
||||
chunks: List[str] = []
|
||||
remaining = content
|
||||
# When the previous chunk ended mid-code-block, this holds the
|
||||
# language tag (possibly "") so we can reopen the fence.
|
||||
carry_lang: str | None = None
|
||||
carry_lang: Optional[str] = None
|
||||
|
||||
while remaining:
|
||||
# If we're continuing a code block from the previous chunk,
|
||||
@@ -979,6 +965,8 @@ class BasePlatformAdapter(ABC):
|
||||
# Append chunk indicators when the response spans multiple messages
|
||||
if len(chunks) > 1:
|
||||
total = len(chunks)
|
||||
chunks = [f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)]
|
||||
chunks = [
|
||||
f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)
|
||||
]
|
||||
|
||||
return chunks
|
||||
|
||||
@@ -10,16 +10,14 @@ Uses discord.py library for:
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import discord
|
||||
from discord import Intents
|
||||
from discord import Message as DiscordMessage
|
||||
from discord import Message as DiscordMessage, Intents
|
||||
from discord.ext import commands
|
||||
|
||||
DISCORD_AVAILABLE = True
|
||||
except ImportError:
|
||||
DISCORD_AVAILABLE = False
|
||||
@@ -30,7 +28,6 @@ except ImportError:
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
@@ -39,8 +36,8 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_audio_from_url,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
)
|
||||
|
||||
|
||||
@@ -52,7 +49,7 @@ def check_discord_requirements() -> bool:
|
||||
class DiscordAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Discord bot adapter.
|
||||
|
||||
|
||||
Handles:
|
||||
- Receiving messages from servers and DMs
|
||||
- Sending responses with Discord markdown
|
||||
@@ -62,26 +59,26 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
- Auto-threading for long conversations
|
||||
- Reaction-based feedback
|
||||
"""
|
||||
|
||||
|
||||
# Discord message limits
|
||||
MAX_MESSAGE_LENGTH = 2000
|
||||
|
||||
|
||||
def __init__(self, config: PlatformConfig):
    """Set up adapter state for the Discord platform; does not connect.

    Args:
        config: Platform configuration passed through to the base adapter.
    """
    super().__init__(config, Platform.DISCORD)
    # Bot client; None until connect() creates it.
    self._client: Optional[commands.Bot] = None
    # Set by the on_ready handler; connect() waits on it.
    self._ready_event = asyncio.Event()
    self._allowed_user_ids: set = set()  # For button approval authorization
|
||||
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
print(f"[{self.name}] discord.py not installed. Run: pip install discord.py")
|
||||
return False
|
||||
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Set up intents -- members intent needed for username-to-ID resolution
|
||||
intents = Intents.default()
|
||||
@@ -89,28 +86,30 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
intents.dm_messages = True
|
||||
intents.guild_messages = True
|
||||
intents.members = True
|
||||
|
||||
|
||||
# Create bot
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
)
|
||||
|
||||
|
||||
# Parse allowed user entries (may contain usernames or IDs)
|
||||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed_env:
|
||||
self._allowed_user_ids = {uid.strip() for uid in allowed_env.split(",") if uid.strip()}
|
||||
|
||||
self._allowed_user_ids = {
|
||||
uid.strip() for uid in allowed_env.split(",") if uid.strip()
|
||||
}
|
||||
|
||||
adapter_self = self # capture for closure
|
||||
|
||||
|
||||
# Register event handlers
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
print(f"[{adapter_self.name}] Connected as {adapter_self._client.user}")
|
||||
|
||||
|
||||
# Resolve any usernames in the allowed list to numeric IDs
|
||||
await adapter_self._resolve_allowed_usernames()
|
||||
|
||||
|
||||
# Sync slash commands with Discord
|
||||
try:
|
||||
synced = await adapter_self._client.tree.sync()
|
||||
@@ -118,33 +117,33 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{adapter_self.name}] Slash command sync failed: {e}")
|
||||
adapter_self._ready_event.set()
|
||||
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Ignore bot's own messages
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
await self._handle_message(message)
|
||||
|
||||
|
||||
# Register slash commands
|
||||
self._register_slash_commands()
|
||||
|
||||
|
||||
# Start the bot in background
|
||||
asyncio.create_task(self._client.start(self.config.token))
|
||||
|
||||
|
||||
# Wait for ready
|
||||
await asyncio.wait_for(self._ready_event.wait(), timeout=30)
|
||||
|
||||
|
||||
self._running = True
|
||||
return True
|
||||
|
||||
except TimeoutError:
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print(f"[{self.name}] Timeout waiting for connection")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Discord."""
|
||||
if self._client:
|
||||
@@ -152,55 +151,59 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
|
||||
self._running = False
|
||||
self._client = None
|
||||
self._ready_event.clear()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
|
||||
async def send(
    self,
    chat_id: str,
    content: str,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
    """Send a message to a Discord channel.

    The content is passed through ``format_message`` and split with
    ``truncate_message`` at ``MAX_MESSAGE_LENGTH``; each chunk is sent as
    its own message.

    Args:
        chat_id: Channel ID (converted with ``int``).
        content: Message text to send.
        reply_to: Optional ID of a message to reply to. Only the first
            chunk is sent as a reply.
        metadata: Accepted but not used by this implementation.

    Returns:
        SendResult with ``success=True``, the first sent message's ID, and
        all sent IDs under ``raw_response["message_ids"]``; or
        ``success=False`` with an error string when not connected, the
        channel cannot be found, or any exception is raised while sending.
    """
    if not self._client:
        return SendResult(success=False, error="Not connected")

    try:
        # Get the channel, falling back to fetch_channel when get_channel
        # returns nothing.
        channel = self._client.get_channel(int(chat_id))
        if not channel:
            channel = await self._client.fetch_channel(int(chat_id))

        if not channel:
            return SendResult(success=False, error=f"Channel {chat_id} not found")

        # Format and split message if needed
        formatted = self.format_message(content)
        chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)

        message_ids = []
        reference = None

        if reply_to:
            try:
                reference = await channel.fetch_message(int(reply_to))
            except Exception as e:
                # Best effort: if the target cannot be fetched, send
                # without a reply reference instead of failing.
                logger.debug("Could not fetch reply-to message: %s", e)

        for i, chunk in enumerate(chunks):
            msg = await channel.send(
                content=chunk,
                reference=reference if i == 0 else None,
            )
            message_ids.append(str(msg.id))

        return SendResult(
            success=True,
            message_id=message_ids[0] if message_ids else None,
            raw_response={"message_ids": message_ids},
        )

    except Exception as e:
        return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -220,7 +223,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
msg = await channel.fetch_message(int(message_id))
|
||||
formatted = self.format_message(content)
|
||||
if len(formatted) > self.MAX_MESSAGE_LENGTH:
|
||||
formatted = formatted[: self.MAX_MESSAGE_LENGTH - 3] + "..."
|
||||
formatted = formatted[:self.MAX_MESSAGE_LENGTH - 3] + "..."
|
||||
await msg.edit(content=formatted)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as e:
|
||||
@@ -230,28 +233,28 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
|
||||
# Determine filename from path
|
||||
filename = os.path.basename(audio_path)
|
||||
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
@@ -259,36 +262,36 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
||||
filename = os.path.basename(image_path)
|
||||
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
@@ -296,7 +299,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
@@ -305,31 +308,31 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
# Download the image and send as a Discord file attachment
|
||||
# (Discord renders attachments inline, unlike plain URLs)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||||
|
||||
|
||||
image_data = await resp.read()
|
||||
|
||||
|
||||
# Determine filename from URL or content type
|
||||
content_type = resp.headers.get("content-type", "image/png")
|
||||
ext = "png"
|
||||
@@ -339,24 +342,23 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
ext = "gif"
|
||||
elif "webp" in content_type:
|
||||
ext = "webp"
|
||||
|
||||
|
||||
import io
|
||||
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, falling back to URL. Run: pip install aiohttp")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send image attachment, falling back to URL: {e}")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
@@ -366,20 +368,20 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
await channel.typing()
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Discord channel."""
|
||||
if not self._client:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
|
||||
if not channel:
|
||||
return {"name": str(chat_id), "type": "dm"}
|
||||
|
||||
|
||||
# Determine channel type
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
@@ -395,7 +397,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
chat_type = "channel"
|
||||
name = getattr(channel, "name", str(chat_id))
|
||||
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"type": chat_type,
|
||||
@@ -404,7 +406,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
|
||||
async def _resolve_allowed_usernames(self) -> None:
|
||||
"""
|
||||
Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs.
|
||||
@@ -451,10 +453,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
uid = str(member.id)
|
||||
numeric_ids.add(uid)
|
||||
resolved_count += 1
|
||||
matched_name = (
|
||||
name_lower
|
||||
if name_lower in to_resolve
|
||||
else (display_lower if display_lower in to_resolve else global_lower)
|
||||
matched_name = name_lower if name_lower in to_resolve else (
|
||||
display_lower if display_lower in to_resolve else global_lower
|
||||
)
|
||||
to_resolve.discard(matched_name)
|
||||
print(f"[{self.name}] Resolved '{matched_name}' -> {uid} ({member.name}#{member.discriminator})")
|
||||
@@ -474,12 +474,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
def format_message(self, content: str) -> str:
    """
    Prepare message text for Discord.

    The text is returned unchanged: Discord uses its own markdown
    variant, which is fairly standard, so no special escaping is applied.
    """
    return content
|
||||
|
||||
|
||||
def _register_slash_commands(self) -> None:
|
||||
"""Register Discord slash commands on the command tree."""
|
||||
if not self._client:
|
||||
@@ -694,7 +694,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_name = interaction.channel.name
|
||||
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
|
||||
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
|
||||
|
||||
|
||||
# Get channel topic (if available)
|
||||
chat_topic = getattr(interaction.channel, "topic", None)
|
||||
|
||||
@@ -715,7 +715,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
raw_message=interaction,
|
||||
)
|
||||
|
||||
async def send_exec_approval(self, chat_id: str, command: str, approval_id: str) -> SendResult:
|
||||
async def send_exec_approval(
|
||||
self, chat_id: str, command: str, approval_id: str
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a button-based exec approval prompt for a dangerous command.
|
||||
|
||||
@@ -757,28 +759,28 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# bot responds to every message without needing a mention.
|
||||
# DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement
|
||||
# globally (all channels become free-response). Default: "true".
|
||||
|
||||
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
# Check if this channel is in the free-response list
|
||||
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
|
||||
# Global override: if DISCORD_REQUIRE_MENTION=false, all channels are free
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
|
||||
|
||||
is_free_channel = channel_id in free_channels
|
||||
|
||||
|
||||
if require_mention and not is_free_channel:
|
||||
# Must be @mentioned to respond
|
||||
if self._client.user not in message.mentions:
|
||||
return # Silently ignore messages that don't mention the bot
|
||||
|
||||
|
||||
# Strip the bot mention from the message text so the agent sees clean input
|
||||
if self._client.user and self._client.user in message.mentions:
|
||||
message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||||
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if message.content.startswith("/"):
|
||||
@@ -796,7 +798,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
|
||||
# Determine chat type
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
@@ -809,15 +811,15 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_name = getattr(message.channel, "name", str(message.channel.id))
|
||||
if hasattr(message.channel, "guild") and message.channel.guild:
|
||||
chat_name = f"{message.channel.guild.name} / #{chat_name}"
|
||||
|
||||
|
||||
# Get thread ID if in a thread
|
||||
thread_id = None
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
thread_id = str(message.channel.id)
|
||||
|
||||
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
|
||||
chat_topic = getattr(message.channel, "topic", None)
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(message.channel.id),
|
||||
@@ -828,7 +830,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
|
||||
# Build media URLs -- download image attachments to local cache so the
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
@@ -867,7 +869,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Other attachments: keep the original URL
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
message_type=msg_type,
|
||||
@@ -879,7 +881,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
|
||||
timestamp=message.created_at,
|
||||
)
|
||||
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
@@ -909,14 +911,20 @@ if DISCORD_AVAILABLE:
|
||||
return True # No allowlist = anyone can approve
|
||||
return str(interaction.user.id) in self.allowed_user_ids
|
||||
|
||||
async def _resolve(self, interaction: discord.Interaction, action: str, color: discord.Color):
|
||||
async def _resolve(
|
||||
self, interaction: discord.Interaction, action: str, color: discord.Color
|
||||
):
|
||||
"""Resolve the approval and update the message."""
|
||||
if self.resolved:
|
||||
await interaction.response.send_message("This approval has already been resolved~", ephemeral=True)
|
||||
await interaction.response.send_message(
|
||||
"This approval has already been resolved~", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
if not self._check_auth(interaction):
|
||||
await interaction.response.send_message("You're not authorized to approve commands~", ephemeral=True)
|
||||
await interaction.response.send_message(
|
||||
"You're not authorized to approve commands~", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
self.resolved = True
|
||||
@@ -936,7 +944,6 @@ if DISCORD_AVAILABLE:
|
||||
# Store the approval decision
|
||||
try:
|
||||
from tools.approval import approve_permanent
|
||||
|
||||
if action == "allow_once":
|
||||
pass # One-time approval handled by gateway
|
||||
elif action == "allow_always":
|
||||
@@ -945,15 +952,21 @@ if DISCORD_AVAILABLE:
|
||||
pass
|
||||
|
||||
@discord.ui.button(label="Allow Once", style=discord.ButtonStyle.green)
|
||||
async def allow_once(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
async def allow_once(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
await self._resolve(interaction, "allow_once", discord.Color.green())
|
||||
|
||||
@discord.ui.button(label="Always Allow", style=discord.ButtonStyle.blurple)
|
||||
async def allow_always(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
async def allow_always(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
await self._resolve(interaction, "allow_always", discord.Color.blue())
|
||||
|
||||
@discord.ui.button(label="Deny", style=discord.ButtonStyle.red)
|
||||
async def deny(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
async def deny(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
await self._resolve(interaction, "deny", discord.Color.red())
|
||||
|
||||
async def on_timeout(self):
|
||||
|
||||
@@ -19,11 +19,10 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
@@ -67,10 +66,10 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
super().__init__(config, Platform.HOMEASSISTANT)
|
||||
|
||||
# Connection state
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._rest_session: aiohttp.ClientSession | None = None
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._session: Optional["aiohttp.ClientSession"] = None
|
||||
self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None
|
||||
self._rest_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._listen_task: Optional[asyncio.Task] = None
|
||||
self._msg_id: int = 0
|
||||
|
||||
# Configuration from extra
|
||||
@@ -81,13 +80,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
self._hass_token: str = token
|
||||
|
||||
# Event filtering
|
||||
self._watch_domains: set[str] = set(extra.get("watch_domains", []))
|
||||
self._watch_entities: set[str] = set(extra.get("watch_entities", []))
|
||||
self._ignore_entities: set[str] = set(extra.get("ignore_entities", []))
|
||||
self._watch_domains: Set[str] = set(extra.get("watch_domains", []))
|
||||
self._watch_entities: Set[str] = set(extra.get("watch_entities", []))
|
||||
self._ignore_entities: Set[str] = set(extra.get("ignore_entities", []))
|
||||
self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30))
|
||||
|
||||
# Cooldown tracking: entity_id -> last_event_timestamp
|
||||
self._last_event_time: dict[str, float] = {}
|
||||
self._last_event_time: Dict[str, float] = {}
|
||||
|
||||
def _next_id(self) -> int:
|
||||
"""Return the next WebSocket message ID."""
|
||||
@@ -142,12 +141,10 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
# Step 2: Send auth
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"type": "auth",
|
||||
"access_token": self._hass_token,
|
||||
}
|
||||
)
|
||||
await self._ws.send_json({
|
||||
"type": "auth",
|
||||
"access_token": self._hass_token,
|
||||
})
|
||||
|
||||
# Step 3: Wait for auth_ok
|
||||
msg = await self._ws.receive_json()
|
||||
@@ -158,13 +155,11 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
|
||||
# Step 4: Subscribe to state_changed events
|
||||
sub_id = self._next_id()
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"id": sub_id,
|
||||
"type": "subscribe_events",
|
||||
"event_type": "state_changed",
|
||||
}
|
||||
)
|
||||
await self._ws.send_json({
|
||||
"id": sub_id,
|
||||
"type": "subscribe_events",
|
||||
"event_type": "state_changed",
|
||||
})
|
||||
|
||||
# Verify subscription acknowledgement
|
||||
msg = await self._ws.receive_json()
|
||||
@@ -250,7 +245,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
|
||||
break
|
||||
|
||||
async def _handle_ha_event(self, event: dict[str, Any]) -> None:
|
||||
async def _handle_ha_event(self, event: Dict[str, Any]) -> None:
|
||||
"""Process a state_changed event from Home Assistant."""
|
||||
event_data = event.get("data", {})
|
||||
entity_id: str = event_data.get("entity_id", "")
|
||||
@@ -307,9 +302,9 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
@staticmethod
|
||||
def _format_state_change(
|
||||
entity_id: str,
|
||||
old_state: dict[str, Any],
|
||||
new_state: dict[str, Any],
|
||||
) -> str | None:
|
||||
old_state: Dict[str, Any],
|
||||
new_state: Dict[str, Any],
|
||||
) -> Optional[str]:
|
||||
"""Convert a state_changed event into a human-readable description."""
|
||||
if not new_state:
|
||||
return None
|
||||
@@ -336,7 +331,10 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
|
||||
if domain == "sensor":
|
||||
unit = new_state.get("attributes", {}).get("unit_of_measurement", "")
|
||||
return f"[Home Assistant] {friendly_name}: changed from {old_val}{unit} to {new_val}{unit}"
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: changed from "
|
||||
f"{old_val}{unit} to {new_val}{unit}"
|
||||
)
|
||||
|
||||
if domain == "binary_sensor":
|
||||
return (
|
||||
@@ -346,13 +344,22 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
)
|
||||
|
||||
if domain in ("light", "switch", "fan"):
|
||||
return f"[Home Assistant] {friendly_name}: turned {'on' if new_val == 'on' else 'off'}"
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: turned "
|
||||
f"{'on' if new_val == 'on' else 'off'}"
|
||||
)
|
||||
|
||||
if domain == "alarm_control_panel":
|
||||
return f"[Home Assistant] {friendly_name}: alarm state changed from '{old_val}' to '{new_val}'"
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: alarm state changed from "
|
||||
f"'{old_val}' to '{new_val}'"
|
||||
)
|
||||
|
||||
# Generic fallback
|
||||
return f"[Home Assistant] {friendly_name} ({entity_id}): changed from '{old_val}' to '{new_val}'"
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name} ({entity_id}): "
|
||||
f"changed from '{old_val}' to '{new_val}'"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound messaging
|
||||
@@ -362,8 +369,8 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send a notification via HA REST API (persistent_notification.create).
|
||||
|
||||
@@ -377,7 +384,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
}
|
||||
payload = {
|
||||
"title": "Hermes Agent",
|
||||
"message": content[: self.MAX_MESSAGE_LENGTH],
|
||||
"message": content[:self.MAX_MESSAGE_LENGTH],
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -394,22 +401,20 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
else:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp,
|
||||
):
|
||||
if resp.status < 300:
|
||||
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
|
||||
else:
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
) as resp:
|
||||
if resp.status < 300:
|
||||
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
|
||||
else:
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
return SendResult(success=False, error="Timeout sending notification to HA")
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
@@ -418,7 +423,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
"""No typing indicator for Home Assistant."""
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Return basic info about the HA event channel."""
|
||||
return {
|
||||
"name": "Home Assistant Events",
|
||||
|
||||
@@ -19,9 +19,9 @@ import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Dict, List, Optional, Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
@@ -32,9 +32,9 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_bytes,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
@@ -59,7 +59,6 @@ _PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
|
||||
if not phone:
|
||||
@@ -69,7 +68,7 @@ def _redact_phone(phone: str) -> str:
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
|
||||
def _parse_comma_list(value: str) -> list[str]:
|
||||
def _parse_comma_list(value: str) -> List[str]:
|
||||
"""Split a comma-separated string into a list, stripping whitespace."""
|
||||
return [v.strip() for v in value.split(",") if v.strip()]
|
||||
|
||||
@@ -111,7 +110,7 @@ def _render_mentions(text: str, mentions: list) -> str:
|
||||
Signal encodes @mentions as the Unicode object replacement character
|
||||
with out-of-band metadata containing the mentioned user's UUID/number.
|
||||
"""
|
||||
if not mentions or "\ufffc" not in text:
|
||||
if not mentions or "\uFFFC" not in text:
|
||||
return text
|
||||
# Sort mentions by start position (reverse) to replace from end to start
|
||||
# so indices don't shift as we replace
|
||||
@@ -122,7 +121,7 @@ def _render_mentions(text: str, mentions: list) -> str:
|
||||
# Use the mention's number or UUID as the replacement
|
||||
identifier = mention.get("number") or mention.get("uuid") or "user"
|
||||
replacement = f"@{identifier}"
|
||||
text = text[:start] + replacement + text[start + length :]
|
||||
text = text[:start] + replacement + text[start + length:]
|
||||
return text
|
||||
|
||||
|
||||
@@ -135,7 +134,6 @@ def check_signal_requirements() -> bool:
|
||||
# Signal Adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SignalAdapter(BasePlatformAdapter):
|
||||
"""Signal messenger adapter using signal-cli HTTP daemon."""
|
||||
|
||||
@@ -154,25 +152,22 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self.group_allow_from = set(_parse_comma_list(group_allowed_str))
|
||||
|
||||
# HTTP client
|
||||
self.client: httpx.AsyncClient | None = None
|
||||
self.client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
# Background tasks
|
||||
self._sse_task: asyncio.Task | None = None
|
||||
self._health_monitor_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._sse_task: Optional[asyncio.Task] = None
|
||||
self._health_monitor_task: Optional[asyncio.Task] = None
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._running = False
|
||||
self._last_sse_activity = 0.0
|
||||
self._sse_response: httpx.Response | None = None
|
||||
self._sse_response: Optional[httpx.Response] = None
|
||||
|
||||
# Normalize account for self-message filtering
|
||||
self._account_normalized = self.account.strip()
|
||||
|
||||
logger.info(
|
||||
"Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url,
|
||||
_redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled",
|
||||
)
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
@@ -246,8 +241,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
logger.debug("Signal SSE: connecting to %s", url)
|
||||
async with self.client.stream(
|
||||
"GET",
|
||||
url,
|
||||
"GET", url,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
timeout=None,
|
||||
) as response:
|
||||
@@ -312,7 +306,9 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if elapsed > HEALTH_CHECK_STALE_THRESHOLD:
|
||||
logger.warning("Signal: SSE idle for %.0fs, checking daemon health", elapsed)
|
||||
try:
|
||||
resp = await self.client.get(f"{self.http_url}/api/v1/check", timeout=10.0)
|
||||
resp = await self.client.get(
|
||||
f"{self.http_url}/api/v1/check", timeout=10.0
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
# Daemon is alive but SSE is idle — update activity to
|
||||
# avoid repeated warnings (connection may just be quiet)
|
||||
@@ -349,7 +345,11 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
# Extract sender info
|
||||
sender = envelope_data.get("sourceNumber") or envelope_data.get("sourceUuid") or envelope_data.get("source")
|
||||
sender = (
|
||||
envelope_data.get("sourceNumber")
|
||||
or envelope_data.get("sourceUuid")
|
||||
or envelope_data.get("source")
|
||||
)
|
||||
sender_name = envelope_data.get("sourceName", "")
|
||||
sender_uuid = envelope_data.get("sourceUuid", "")
|
||||
|
||||
@@ -367,7 +367,10 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
# Get data message — also check editMessage (edited messages contain
|
||||
# their updated dataMessage inside editMessage.dataMessage)
|
||||
data_message = envelope_data.get("dataMessage") or (envelope_data.get("editMessage") or {}).get("dataMessage")
|
||||
data_message = (
|
||||
envelope_data.get("dataMessage")
|
||||
or (envelope_data.get("editMessage") or {}).get("dataMessage")
|
||||
)
|
||||
if not data_message:
|
||||
return
|
||||
|
||||
@@ -448,11 +451,11 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
ts_ms = envelope_data.get("timestamp", 0)
|
||||
if ts_ms:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=UTC)
|
||||
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
|
||||
except (ValueError, OSError):
|
||||
timestamp = datetime.now(tz=UTC)
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
else:
|
||||
timestamp = datetime.now(tz=UTC)
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Build and dispatch event
|
||||
event = MessageEvent(
|
||||
@@ -465,7 +468,8 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
logger.debug("Signal: message from %s in %s: %s", _redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
logger.debug("Signal: message from %s in %s: %s",
|
||||
_redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
@@ -475,13 +479,10 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _fetch_attachment(self, attachment_id: str) -> tuple:
|
||||
"""Fetch an attachment via JSON-RPC and cache it. Returns (path, ext)."""
|
||||
result = await self._rpc(
|
||||
"getAttachment",
|
||||
{
|
||||
"account": self.account,
|
||||
"attachmentId": attachment_id,
|
||||
},
|
||||
)
|
||||
result = await self._rpc("getAttachment", {
|
||||
"account": self.account,
|
||||
"attachmentId": attachment_id,
|
||||
})
|
||||
|
||||
if not result:
|
||||
return None, ""
|
||||
@@ -546,13 +547,13 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to_message_id: str | None = None,
|
||||
reply_to_message_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a text message."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
params: dict[str, Any] = {
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": text,
|
||||
}
|
||||
@@ -570,7 +571,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send a typing indicator."""
|
||||
params: dict[str, Any] = {
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
}
|
||||
|
||||
@@ -585,7 +586,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send an image. Supports http(s):// and file:// URLs."""
|
||||
@@ -610,7 +611,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if file_size > SIGNAL_MAX_ATTACHMENT_SIZE:
|
||||
return SendResult(success=False, error=f"Image too large ({file_size} bytes)")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
@@ -630,8 +631,8 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: str | None = None,
|
||||
filename: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment."""
|
||||
@@ -640,7 +641,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if not Path(file_path).exists():
|
||||
return SendResult(success=False, error="File not found")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
@@ -689,7 +690,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
# Chat Info
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a chat/contact."""
|
||||
if chat_id.startswith("group:"):
|
||||
return {
|
||||
@@ -699,13 +700,10 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
}
|
||||
|
||||
# Try to resolve contact name
|
||||
result = await self._rpc(
|
||||
"getContact",
|
||||
{
|
||||
"account": self.account,
|
||||
"contactAddress": chat_id,
|
||||
},
|
||||
)
|
||||
result = await self._rpc("getContact", {
|
||||
"account": self.account,
|
||||
"contactAddress": chat_id,
|
||||
})
|
||||
|
||||
name = chat_id
|
||||
if result and isinstance(result, dict):
|
||||
|
||||
@@ -10,14 +10,12 @@ Uses slack-bolt (Python) with Socket Mode for:
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
try:
|
||||
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
SLACK_AVAILABLE = True
|
||||
except ImportError:
|
||||
SLACK_AVAILABLE = False
|
||||
@@ -27,17 +25,16 @@ except ImportError:
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
)
|
||||
|
||||
|
||||
@@ -66,9 +63,9 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.SLACK)
|
||||
self._app: AsyncApp | None = None
|
||||
self._handler: AsyncSocketModeHandler | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
self._app: Optional[AsyncApp] = None
|
||||
self._handler: Optional[AsyncSocketModeHandler] = None
|
||||
self._bot_user_id: Optional[str] = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Slack via Socket Mode."""
|
||||
@@ -99,13 +96,6 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
async def handle_message_event(event, say):
|
||||
await self._handle_slack_message(event)
|
||||
|
||||
# Acknowledge app_mention events to prevent Bolt 404 errors.
|
||||
# The "message" handler above already processes @mentions in
|
||||
# channels, so this is intentionally a no-op to avoid duplicates.
|
||||
@self._app.event("app_mention")
|
||||
async def handle_app_mention(event, say):
|
||||
pass
|
||||
|
||||
# Register slash command handler
|
||||
@self._app.command("/hermes")
|
||||
async def handle_hermes_command(ack, command):
|
||||
@@ -135,8 +125,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send a message to a Slack channel or DM."""
|
||||
if not self._app:
|
||||
@@ -193,8 +183,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file to Slack by uploading it."""
|
||||
if not self._app:
|
||||
@@ -202,7 +192,6 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
@@ -223,8 +212,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image to Slack by uploading the URL as a file."""
|
||||
if not self._app:
|
||||
@@ -248,7 +237,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# Fall back to sending the URL as text
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
@@ -257,8 +246,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an audio file to Slack."""
|
||||
if not self._app:
|
||||
@@ -277,66 +266,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a video file to Slack."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not os.path.exists(video_path):
|
||||
return SendResult(success=False, error=f"Video file not found: {video_path}")
|
||||
|
||||
try:
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
file=video_path,
|
||||
filename=os.path.basename(video_path),
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send video: {e}")
|
||||
return await super().send_video(chat_id, video_path, caption, reply_to)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment to Slack."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
display_name = file_name or os.path.basename(file_path)
|
||||
|
||||
try:
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
file=file_path,
|
||||
filename=display_name,
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send document: {e}")
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Slack channel."""
|
||||
if not self._app:
|
||||
return {"name": chat_id, "type": "unknown"}
|
||||
@@ -417,56 +347,6 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
msg_type = MessageType.VOICE
|
||||
except Exception as e:
|
||||
print(f"[Slack] Failed to cache audio: {e}", flush=True)
|
||||
elif url:
|
||||
# Try to handle as a document attachment
|
||||
try:
|
||||
original_filename = f.get("name", "")
|
||||
ext = ""
|
||||
if original_filename:
|
||||
_, ext = os.path.splitext(original_filename)
|
||||
ext = ext.lower()
|
||||
|
||||
# Fallback: reverse-lookup from MIME type
|
||||
if not ext and mimetype:
|
||||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||||
ext = mime_to_ext.get(mimetype, "")
|
||||
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
continue # Skip unsupported file types silently
|
||||
|
||||
# Check file size (Slack limit: 20 MB for bots)
|
||||
file_size = f.get("size", 0)
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if not file_size or file_size > MAX_DOC_BYTES:
|
||||
print(f"[Slack] Document too large or unknown size: {file_size}", flush=True)
|
||||
continue
|
||||
|
||||
# Download and cache
|
||||
raw_bytes = await self._download_slack_file_bytes(url)
|
||||
cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}")
|
||||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
msg_type = MessageType.DOCUMENT
|
||||
print(f"[Slack] Cached user document: {cached_path}", flush=True)
|
||||
|
||||
# Inject text content for .txt/.md files (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if text:
|
||||
text = f"{injection}\n\n{text}"
|
||||
else:
|
||||
text = injection
|
||||
except UnicodeDecodeError:
|
||||
pass # Binary content, skip injection
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Slack] Failed to cache document: {e}", flush=True)
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
@@ -498,20 +378,16 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
# Map subcommands to gateway commands
|
||||
subcommand_map = {
|
||||
"new": "/reset",
|
||||
"reset": "/reset",
|
||||
"status": "/status",
|
||||
"stop": "/stop",
|
||||
"new": "/reset", "reset": "/reset",
|
||||
"status": "/status", "stop": "/stop",
|
||||
"help": "/help",
|
||||
"model": "/model",
|
||||
"personality": "/personality",
|
||||
"retry": "/retry",
|
||||
"undo": "/undo",
|
||||
"model": "/model", "personality": "/personality",
|
||||
"retry": "/retry", "undo": "/undo",
|
||||
}
|
||||
first_word = text.split()[0] if text else ""
|
||||
if first_word in subcommand_map:
|
||||
# Preserve arguments after the subcommand
|
||||
rest = text[len(first_word) :].strip()
|
||||
rest = text[len(first_word):].strip()
|
||||
text = f"{subcommand_map[first_word]} {rest}".strip() if rest else subcommand_map[first_word]
|
||||
elif text:
|
||||
pass # Treat as a regular question
|
||||
@@ -547,22 +423,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str) -> bytes:
|
||||
"""Download a Slack file and return raw bytes."""
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
@@ -7,26 +7,24 @@ Uses python-telegram-bot library for:
|
||||
- Handling media and commands
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from telegram import Bot, Message, Update
|
||||
from telegram.constants import ChatType, ParseMode
|
||||
from telegram import Update, Bot, Message
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
CommandHandler,
|
||||
MessageHandler as TelegramMessageHandler,
|
||||
ContextTypes,
|
||||
filters,
|
||||
)
|
||||
from telegram.ext import (
|
||||
MessageHandler as TelegramMessageHandler,
|
||||
)
|
||||
|
||||
from telegram.constants import ParseMode, ChatType
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
@@ -44,24 +42,22 @@ except ImportError:
|
||||
# don't crash during class definition when the library isn't installed.
|
||||
class _MockContextTypes:
|
||||
DEFAULT_TYPE = Any
|
||||
|
||||
ContextTypes = _MockContextTypes
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
@@ -72,12 +68,12 @@ def check_telegram_requirements() -> bool:
|
||||
|
||||
# Matches every character that MarkdownV2 requires to be backslash-escaped
|
||||
# when it appears outside a code span or fenced code block.
|
||||
_MDV2_ESCAPE_RE = re.compile(r"([_*\[\]()~`>#\+\-=|{}.!\\])")
|
||||
_MDV2_ESCAPE_RE = re.compile(r'([_*\[\]()~`>#\+\-=|{}.!\\])')
|
||||
|
||||
|
||||
def _escape_mdv2(text: str) -> str:
|
||||
"""Escape Telegram MarkdownV2 special characters with a preceding backslash."""
|
||||
return _MDV2_ESCAPE_RE.sub(r"\\\1", text)
|
||||
return _MDV2_ESCAPE_RE.sub(r'\\\1', text)
|
||||
|
||||
|
||||
def _strip_mdv2(text: str) -> str:
|
||||
@@ -87,108 +83,103 @@ def _strip_mdv2(text: str) -> str:
|
||||
doesn't show stray asterisks from header/bold conversion.
|
||||
"""
|
||||
# Remove escape backslashes before special characters
|
||||
cleaned = re.sub(r"\\([_*\[\]()~`>#\+\-=|{}.!\\])", r"\1", text)
|
||||
cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
|
||||
# Remove MarkdownV2 bold markers that format_message converted from **bold**
|
||||
cleaned = re.sub(r"\*([^*]+)\*", r"\1", cleaned)
|
||||
cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
|
||||
|
||||
Handles:
|
||||
- Receiving messages from users and groups
|
||||
- Sending responses with Telegram markdown
|
||||
- Forum topics (thread_id support)
|
||||
- Media messages
|
||||
"""
|
||||
|
||||
|
||||
# Telegram message limits
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.TELEGRAM)
|
||||
self._app: Application | None = None
|
||||
self._bot: Bot | None = None
|
||||
|
||||
self._app: Optional[Application] = None
|
||||
self._bot: Optional[Bot] = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Telegram and start polling for updates."""
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
print(f"[{self.name}] python-telegram-bot not installed. Run: pip install python-telegram-bot")
|
||||
return False
|
||||
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Build the application
|
||||
self._app = Application.builder().token(self.config.token).build()
|
||||
self._bot = self._app.bot
|
||||
|
||||
|
||||
# Register handlers
|
||||
self._app.add_handler(TelegramMessageHandler(filters.TEXT & ~filters.COMMAND, self._handle_text_message))
|
||||
self._app.add_handler(TelegramMessageHandler(filters.COMMAND, self._handle_command))
|
||||
self._app.add_handler(
|
||||
TelegramMessageHandler(
|
||||
filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION), self._handle_location_message
|
||||
)
|
||||
)
|
||||
self._app.add_handler(
|
||||
TelegramMessageHandler(
|
||||
filters.PHOTO
|
||||
| filters.VIDEO
|
||||
| filters.AUDIO
|
||||
| filters.VOICE
|
||||
| filters.Document.ALL
|
||||
| filters.Sticker.ALL,
|
||||
self._handle_media_message,
|
||||
)
|
||||
)
|
||||
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.TEXT & ~filters.COMMAND,
|
||||
self._handle_text_message
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.COMMAND,
|
||||
self._handle_command
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION),
|
||||
self._handle_location_message
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.PHOTO | filters.VIDEO | filters.AUDIO | filters.VOICE | filters.Document.ALL | filters.Sticker.ALL,
|
||||
self._handle_media_message
|
||||
))
|
||||
|
||||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
|
||||
|
||||
# Register bot commands so Telegram shows a hint menu when users type /
|
||||
try:
|
||||
from telegram import BotCommand
|
||||
|
||||
await self._bot.set_my_commands(
|
||||
[
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("model", "Show or change the model"),
|
||||
BotCommand("personality", "Set a personality"),
|
||||
BotCommand("retry", "Retry your last message"),
|
||||
BotCommand("undo", "Remove the last exchange"),
|
||||
BotCommand("status", "Show session info"),
|
||||
BotCommand("stop", "Stop the running agent"),
|
||||
BotCommand("sethome", "Set this chat as the home channel"),
|
||||
BotCommand("compress", "Compress conversation context"),
|
||||
BotCommand("title", "Set or show the session title"),
|
||||
BotCommand("resume", "Resume a previously-named session"),
|
||||
BotCommand("usage", "Show token usage for this session"),
|
||||
BotCommand("provider", "Show available providers"),
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
)
|
||||
await self._bot.set_my_commands([
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("model", "Show or change the model"),
|
||||
BotCommand("personality", "Set a personality"),
|
||||
BotCommand("retry", "Retry your last message"),
|
||||
BotCommand("undo", "Remove the last exchange"),
|
||||
BotCommand("status", "Show session info"),
|
||||
BotCommand("stop", "Stop the running agent"),
|
||||
BotCommand("sethome", "Set this chat as the home channel"),
|
||||
BotCommand("compress", "Compress conversation context"),
|
||||
BotCommand("title", "Set or show the session title"),
|
||||
BotCommand("resume", "Resume a previously-named session"),
|
||||
BotCommand("usage", "Show token usage for this session"),
|
||||
BotCommand("provider", "Show available providers"),
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
])
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Could not register command menu: {e}")
|
||||
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Connected and polling for updates")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop polling and disconnect."""
|
||||
if self._app:
|
||||
@@ -198,27 +189,31 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
await self._app.shutdown()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
|
||||
self._running = False
|
||||
self._app = None
|
||||
self._bot = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
|
||||
async def send(
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Telegram chat."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
|
||||
message_ids = []
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Try Markdown first, fall back to plain text if it fails
|
||||
try:
|
||||
@@ -232,9 +227,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning(
|
||||
"[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error
|
||||
)
|
||||
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
|
||||
# Strip MDV2 escape backslashes so the user doesn't
|
||||
# see raw backslashes littered through the message.
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
@@ -248,13 +241,13 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
raise # Re-raise if not a parse error
|
||||
message_ids.append(str(msg.message_id))
|
||||
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids},
|
||||
raw_response={"message_ids": message_ids}
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -291,19 +284,18 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a native Telegram voice message or audio file."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
# .ogg files -> send as voice (round playable bubble)
|
||||
if audio_path.endswith(".ogg") or audio_path.endswith(".opus"):
|
||||
@@ -325,24 +317,23 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send voice/audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
||||
with open(image_path, "rb") as image_file:
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
@@ -359,17 +350,17 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Telegram photo.
|
||||
|
||||
|
||||
Tries URL-based send first (fast, works for <5MB images).
|
||||
Falls back to downloading and uploading as file (supports up to 10MB).
|
||||
"""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs (up to ~5MB)
|
||||
msg = await self._bot.send_photo(
|
||||
@@ -384,12 +375,11 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Fallback: download and upload as file (supports up to 10MB)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.get(image_url)
|
||||
resp.raise_for_status()
|
||||
image_data = resp.content
|
||||
|
||||
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_data,
|
||||
@@ -401,18 +391,18 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
logger.error("[%s] File upload send_photo also failed: %s", self.name, e2)
|
||||
# Final fallback: send URL as text
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an animated GIF natively as a Telegram animation (auto-plays inline)."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
msg = await self._bot.send_animation(
|
||||
chat_id=int(chat_id),
|
||||
@@ -430,18 +420,21 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
"""Send typing indicator."""
|
||||
if self._bot:
|
||||
try:
|
||||
await self._bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
||||
await self._bot.send_chat_action(
|
||||
chat_id=int(chat_id),
|
||||
action="typing"
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Telegram chat."""
|
||||
if not self._bot:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
chat = await self._bot.get_chat(int(chat_id))
|
||||
|
||||
|
||||
chat_type = "dm"
|
||||
if chat.type == ChatType.GROUP:
|
||||
chat_type = "group"
|
||||
@@ -451,7 +444,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
chat_type = "forum"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
|
||||
return {
|
||||
"name": chat.title or chat.full_name or str(chat_id),
|
||||
"type": chat_type,
|
||||
@@ -460,7 +453,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Convert standard markdown to Telegram MarkdownV2 format.
|
||||
@@ -487,36 +480,38 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
# 1) Protect fenced code blocks (``` ... ```)
|
||||
text = re.sub(
|
||||
r"(```(?:[^\n]*\n)?[\s\S]*?```)",
|
||||
r'(```(?:[^\n]*\n)?[\s\S]*?```)',
|
||||
lambda m: _ph(m.group(0)),
|
||||
text,
|
||||
)
|
||||
|
||||
# 2) Protect inline code (`...`)
|
||||
text = re.sub(r"(`[^`]+`)", lambda m: _ph(m.group(0)), text)
|
||||
text = re.sub(r'(`[^`]+`)', lambda m: _ph(m.group(0)), text)
|
||||
|
||||
# 3) Convert markdown links – escape the display text; inside the URL
|
||||
# only ')' and '\' need escaping per the MarkdownV2 spec.
|
||||
def _convert_link(m):
|
||||
display = _escape_mdv2(m.group(1))
|
||||
url = m.group(2).replace("\\", "\\\\").replace(")", "\\)")
|
||||
return _ph(f"[{display}]({url})")
|
||||
url = m.group(2).replace('\\', '\\\\').replace(')', '\\)')
|
||||
return _ph(f'[{display}]({url})')
|
||||
|
||||
text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _convert_link, text)
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', _convert_link, text)
|
||||
|
||||
# 4) Convert markdown headers (## Title) → bold *Title*
|
||||
def _convert_header(m):
|
||||
inner = m.group(1).strip()
|
||||
# Strip redundant bold markers that may appear inside a header
|
||||
inner = re.sub(r"\*\*(.+?)\*\*", r"\1", inner)
|
||||
return _ph(f"*{_escape_mdv2(inner)}*")
|
||||
inner = re.sub(r'\*\*(.+?)\*\*', r'\1', inner)
|
||||
return _ph(f'*{_escape_mdv2(inner)}*')
|
||||
|
||||
text = re.sub(r"^#{1,6}\s+(.+)$", _convert_header, text, flags=re.MULTILINE)
|
||||
text = re.sub(
|
||||
r'^#{1,6}\s+(.+)$', _convert_header, text, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
# 5) Convert bold: **text** → *text* (MarkdownV2 bold)
|
||||
text = re.sub(
|
||||
r"\*\*(.+?)\*\*",
|
||||
lambda m: _ph(f"*{_escape_mdv2(m.group(1))}*"),
|
||||
r'\*\*(.+?)\*\*',
|
||||
lambda m: _ph(f'*{_escape_mdv2(m.group(1))}*'),
|
||||
text,
|
||||
)
|
||||
|
||||
@@ -524,8 +519,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# [^*\n]+ prevents matching across newlines (which would corrupt
|
||||
# bullet lists using * markers and multi-line content).
|
||||
text = re.sub(
|
||||
r"\*([^*\n]+)\*",
|
||||
lambda m: _ph(f"_{_escape_mdv2(m.group(1))}_"),
|
||||
r'\*([^*\n]+)\*',
|
||||
lambda m: _ph(f'_{_escape_mdv2(m.group(1))}_'),
|
||||
text,
|
||||
)
|
||||
|
||||
@@ -538,23 +533,23 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
text = text.replace(key, placeholders[key])
|
||||
|
||||
return text
|
||||
|
||||
|
||||
async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming text messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.TEXT)
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming command messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.COMMAND)
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming location/venue pin messages."""
|
||||
if not update.message:
|
||||
@@ -594,9 +589,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
"""Handle incoming media messages, downloading images to local cache."""
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
|
||||
msg = update.message
|
||||
|
||||
|
||||
# Determine media type
|
||||
if msg.sticker:
|
||||
msg_type = MessageType.STICKER
|
||||
@@ -612,19 +607,19 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
msg_type = MessageType.DOCUMENT
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
|
||||
event = self._build_message_event(msg, msg_type)
|
||||
|
||||
|
||||
# Add caption as text
|
||||
if msg.caption:
|
||||
event.text = msg.caption
|
||||
|
||||
|
||||
# Handle stickers: describe via vision tool with caching
|
||||
if msg.sticker:
|
||||
await self._handle_sticker(msg, event)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
||||
|
||||
# Download photo to local image cache so the vision tool can access it
|
||||
# even after Telegram's ephemeral file URLs expire (~1 hour).
|
||||
if msg.photo:
|
||||
@@ -648,7 +643,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
print(f"[Telegram] Cached user photo: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache photo: {e}", flush=True)
|
||||
|
||||
|
||||
# Download voice/audio messages to cache for STT transcription
|
||||
if msg.voice:
|
||||
try:
|
||||
@@ -690,7 +685,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Check if supported
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys()))
|
||||
event.text = f"Unsupported document type '{ext or 'unknown'}'. Supported types: {supported_list}"
|
||||
event.text = (
|
||||
f"Unsupported document type '{ext or 'unknown'}'. "
|
||||
f"Supported types: {supported_list}"
|
||||
)
|
||||
print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
@@ -698,7 +696,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Check file size (Telegram Bot API limit: 20 MB)
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if not doc.file_size or doc.file_size > MAX_DOC_BYTES:
|
||||
event.text = "The document is too large or its size could not be verified. Maximum: 20 MB."
|
||||
event.text = (
|
||||
"The document is too large or its size could not be verified. "
|
||||
"Maximum: 20 MB."
|
||||
)
|
||||
print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
@@ -719,20 +720,20 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if event.text:
|
||||
event.text = f"{injection}\n\n{event.text}"
|
||||
else:
|
||||
event.text = injection
|
||||
except UnicodeDecodeError:
|
||||
print("[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
|
||||
print(f"[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache document: {e}", flush=True)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None:
|
||||
"""
|
||||
Describe a Telegram sticker via vision analysis, with caching.
|
||||
@@ -742,11 +743,11 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
a placeholder noting the emoji.
|
||||
"""
|
||||
from gateway.sticker_cache import (
|
||||
STICKER_VISION_PROMPT,
|
||||
build_animated_sticker_injection,
|
||||
build_sticker_injection,
|
||||
cache_sticker_description,
|
||||
get_cached_description,
|
||||
cache_sticker_description,
|
||||
build_sticker_injection,
|
||||
build_animated_sticker_injection,
|
||||
STICKER_VISION_PROMPT,
|
||||
)
|
||||
|
||||
sticker = msg.sticker
|
||||
@@ -774,9 +775,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
cached_path = cache_image_from_bytes(bytes(image_bytes), ext=".webp")
|
||||
print(f"[Telegram] Analyzing sticker: {cached_path}", flush=True)
|
||||
|
||||
import json as _json
|
||||
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
import json as _json
|
||||
|
||||
result_json = await vision_analyze_tool(
|
||||
image_url=cached_path,
|
||||
@@ -792,29 +792,27 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Vision failed -- use emoji as fallback
|
||||
event.text = build_sticker_injection(
|
||||
f"a sticker with emoji {emoji}" if emoji else "a sticker",
|
||||
emoji,
|
||||
set_name,
|
||||
emoji, set_name,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Sticker analysis error: {e}", flush=True)
|
||||
event.text = build_sticker_injection(
|
||||
f"a sticker with emoji {emoji}" if emoji else "a sticker",
|
||||
emoji,
|
||||
set_name,
|
||||
emoji, set_name,
|
||||
)
|
||||
|
||||
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message."""
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
|
||||
# Determine chat type
|
||||
chat_type = "dm"
|
||||
if chat.type in (ChatType.GROUP, ChatType.SUPERGROUP):
|
||||
chat_type = "group"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(chat.id),
|
||||
@@ -824,7 +822,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
user_name=user.full_name if user else None,
|
||||
thread_id=str(message.message_thread_id) if message.message_thread_id else None,
|
||||
)
|
||||
|
||||
|
||||
return MessageEvent(
|
||||
text=message.text or "",
|
||||
message_type=msg_type,
|
||||
|
||||
@@ -16,6 +16,7 @@ with different backends via a bridge pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
@@ -23,7 +24,7 @@ import subprocess
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,9 +36,7 @@ def _kill_port_process(port: int) -> None:
|
||||
# Use netstat to find the PID bound to this port, then taskkill
|
||||
result = subprocess.run(
|
||||
["netstat", "-ano", "-p", "TCP"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
for line in result.stdout.splitlines():
|
||||
parts = line.split()
|
||||
@@ -47,29 +46,24 @@ def _kill_port_process(port: int) -> None:
|
||||
try:
|
||||
subprocess.run(
|
||||
["taskkill", "/PID", parts[4], "/F"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
capture_output=True, timeout=5,
|
||||
)
|
||||
except subprocess.SubprocessError:
|
||||
pass
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["fuser", f"{port}/tcp"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
capture_output=True, timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
subprocess.run(
|
||||
["fuser", "-k", f"{port}/tcp"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
capture_output=True, timeout=5,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
@@ -78,20 +72,25 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_audio_from_url,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
)
|
||||
|
||||
|
||||
def check_whatsapp_requirements() -> bool:
|
||||
"""
|
||||
Check if WhatsApp dependencies are available.
|
||||
|
||||
|
||||
WhatsApp requires a Node.js bridge for most implementations.
|
||||
"""
|
||||
# Check for Node.js
|
||||
try:
|
||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=5)
|
||||
result = subprocess.run(
|
||||
["node", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
@@ -100,61 +99,62 @@ def check_whatsapp_requirements() -> bool:
|
||||
class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
WhatsApp adapter.
|
||||
|
||||
|
||||
This implementation uses a simple HTTP bridge pattern where:
|
||||
1. A Node.js process runs the WhatsApp Web client
|
||||
2. Messages are forwarded via HTTP/IPC to this Python adapter
|
||||
3. Responses are sent back through the bridge
|
||||
|
||||
|
||||
The actual Node.js bridge implementation can vary:
|
||||
- whatsapp-web.js based
|
||||
- Baileys based
|
||||
- Business API based
|
||||
|
||||
|
||||
Configuration:
|
||||
- bridge_script: Path to the Node.js bridge script
|
||||
- bridge_port: Port for HTTP communication (default: 3000)
|
||||
- session_path: Path to store WhatsApp session data
|
||||
"""
|
||||
|
||||
|
||||
# WhatsApp message limits
|
||||
MAX_MESSAGE_LENGTH = 65536 # WhatsApp allows longer messages
|
||||
|
||||
|
||||
# Default bridge location relative to the hermes-agent install
|
||||
_DEFAULT_BRIDGE_DIR = Path(__file__).resolve().parents[2] / "scripts" / "whatsapp-bridge"
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WHATSAPP)
|
||||
self._bridge_process: subprocess.Popen | None = None
|
||||
self._bridge_process: Optional[subprocess.Popen] = None
|
||||
self._bridge_port: int = config.extra.get("bridge_port", 3000)
|
||||
self._bridge_script: str | None = config.extra.get(
|
||||
self._bridge_script: Optional[str] = config.extra.get(
|
||||
"bridge_script",
|
||||
str(self._DEFAULT_BRIDGE_DIR / "bridge.js"),
|
||||
)
|
||||
self._session_path: Path = Path(
|
||||
config.extra.get("session_path", Path.home() / ".hermes" / "whatsapp" / "session")
|
||||
)
|
||||
self._session_path: Path = Path(config.extra.get(
|
||||
"session_path",
|
||||
Path.home() / ".hermes" / "whatsapp" / "session"
|
||||
))
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Path | None = None
|
||||
|
||||
self._bridge_log: Optional[Path] = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Start the WhatsApp bridge.
|
||||
|
||||
|
||||
This launches the Node.js bridge process and waits for it to be ready.
|
||||
"""
|
||||
if not check_whatsapp_requirements():
|
||||
logger.warning("[%s] Node.js not found. WhatsApp requires Node.js.", self.name)
|
||||
return False
|
||||
|
||||
|
||||
bridge_path = Path(self._bridge_script)
|
||||
if not bridge_path.exists():
|
||||
logger.warning("[%s] Bridge script not found: %s", self.name, bridge_path)
|
||||
return False
|
||||
|
||||
|
||||
logger.info("[%s] Bridge found at %s", self.name, bridge_path)
|
||||
|
||||
|
||||
# Auto-install npm dependencies if node_modules doesn't exist
|
||||
bridge_dir = bridge_path.parent
|
||||
if not (bridge_dir / "node_modules").exists():
|
||||
@@ -174,17 +174,16 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to install dependencies: {e}")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Ensure session directory exists
|
||||
self._session_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Kill any orphaned bridge from a previous gateway run
|
||||
_kill_port_process(self._bridge_port)
|
||||
import time
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
# Start the bridge process in its own process group.
|
||||
# Route output to a log file so QR codes, errors, and reconnection
|
||||
# messages are preserved for troubleshooting.
|
||||
@@ -196,23 +195,19 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
[
|
||||
"node",
|
||||
str(bridge_path),
|
||||
"--port",
|
||||
str(self._bridge_port),
|
||||
"--session",
|
||||
str(self._session_path),
|
||||
"--mode",
|
||||
whatsapp_mode,
|
||||
"--port", str(self._bridge_port),
|
||||
"--session", str(self._session_path),
|
||||
"--mode", whatsapp_mode,
|
||||
],
|
||||
stdout=bridge_log_fh,
|
||||
stderr=bridge_log_fh,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
|
||||
# Wait for the bridge to connect to WhatsApp.
|
||||
# Phase 1: wait for the HTTP server to come up (up to 15s).
|
||||
# Phase 2: wait for WhatsApp status: connected (up to 15s more).
|
||||
import aiohttp
|
||||
|
||||
http_ready = False
|
||||
data = {}
|
||||
for attempt in range(15):
|
||||
@@ -223,18 +218,17 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
http_ready = True
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
http_ready = True
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -243,7 +237,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Check log: {self._bridge_log}")
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
||||
|
||||
# Phase 2: HTTP is up but WhatsApp may still be connecting.
|
||||
# Give it more time to authenticate with saved credentials.
|
||||
if data.get("status") != "connected":
|
||||
@@ -256,17 +250,16 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
@@ -275,19 +268,19 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] ⚠ WhatsApp not connected after 30s")
|
||||
print(f"[{self.name}] Bridge log: {self._bridge_log}")
|
||||
print(f"[{self.name}] If session expired, re-pair: hermes whatsapp")
|
||||
|
||||
|
||||
# Start message polling task
|
||||
asyncio.create_task(self._poll_messages())
|
||||
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Bridge started on port {self._bridge_port}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
||||
|
||||
def _close_bridge_log(self) -> None:
|
||||
"""Close the bridge log file handle if open."""
|
||||
if self._bridge_log_fh:
|
||||
@@ -303,7 +296,6 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
# Kill the entire process group so child node processes die too
|
||||
import signal
|
||||
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
self._bridge_process.terminate()
|
||||
@@ -322,25 +314,29 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._bridge_process.kill()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error stopping bridge: {e}")
|
||||
|
||||
|
||||
# Also kill any orphaned bridge processes on our port
|
||||
_kill_port_process(self._bridge_port)
|
||||
|
||||
|
||||
self._running = False
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
|
||||
async def send(
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> SendResult:
|
||||
"""Send a message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
payload = {
|
||||
"chatId": chat_id,
|
||||
@@ -348,19 +344,28 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
}
|
||||
if reply_to:
|
||||
payload["replyTo"] = reply_to
|
||||
|
||||
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send", json=payload, timeout=aiohttp.ClientTimeout(total=30)
|
||||
f"http://localhost:{self._bridge_port}/send",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(success=True, message_id=data.get("messageId"), raw_response=data)
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
|
||||
except ImportError:
|
||||
return SendResult(success=False, error="aiohttp not installed. Run: pip install aiohttp")
|
||||
return SendResult(
|
||||
success=False,
|
||||
error="aiohttp not installed. Run: pip install aiohttp"
|
||||
)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -375,24 +380,21 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/edit",
|
||||
json={
|
||||
"chatId": chat_id,
|
||||
"messageId": message_id,
|
||||
"message": content,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
timeout=aiohttp.ClientTimeout(total=15)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -401,8 +403,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
media_type: str,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send any media file via bridge /send-media endpoint."""
|
||||
if not self._running:
|
||||
@@ -413,7 +415,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
if not os.path.exists(file_path):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
payload: Dict[str, Any] = {
|
||||
"chatId": chat_id,
|
||||
"filePath": file_path,
|
||||
"mediaType": media_type,
|
||||
@@ -423,24 +425,22 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
if file_name:
|
||||
payload["fileName"] = file_name
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send-media",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=120),
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data,
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data,
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
@@ -449,8 +449,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Download image URL to cache, send natively via bridge."""
|
||||
try:
|
||||
@@ -463,8 +463,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively via bridge."""
|
||||
return await self._send_media_to_bridge(chat_id, image_path, "image", caption)
|
||||
@@ -473,8 +473,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a video natively via bridge — plays inline in WhatsApp."""
|
||||
return await self._send_media_to_bridge(chat_id, video_path, "video", caption)
|
||||
@@ -483,16 +483,13 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a document/file as a downloadable attachment via bridge."""
|
||||
return await self._send_media_to_bridge(
|
||||
chat_id,
|
||||
file_path,
|
||||
"document",
|
||||
caption,
|
||||
chat_id, file_path, "document", caption,
|
||||
file_name or os.path.basename(file_path),
|
||||
)
|
||||
|
||||
@@ -500,45 +497,44 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send typing indicator via bridge."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(
|
||||
f"http://localhost:{self._bridge_port}/typing",
|
||||
json={"chatId": chat_id},
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
timeout=aiohttp.ClientTimeout(total=5)
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a WhatsApp chat."""
|
||||
if not self._running:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}", timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug("Could not get WhatsApp chat info for %s: %s", chat_id, e)
|
||||
|
||||
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
|
||||
async def _poll_messages(self) -> None:
|
||||
"""Poll the bridge for incoming messages."""
|
||||
try:
|
||||
@@ -546,30 +542,29 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, message polling disabled")
|
||||
return
|
||||
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages", timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages",
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Poll error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
await asyncio.sleep(1) # Poll interval
|
||||
|
||||
async def _build_message_event(self, data: dict[str, Any]) -> MessageEvent | None:
|
||||
|
||||
async def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
|
||||
"""Build a MessageEvent from bridge message data, downloading images to cache."""
|
||||
try:
|
||||
# Determine message type
|
||||
@@ -584,11 +579,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
msg_type = MessageType.VOICE
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
|
||||
# Determine chat type
|
||||
is_group = data.get("isGroup", False)
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=data.get("chatId", ""),
|
||||
@@ -597,7 +592,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
user_id=data.get("senderId"),
|
||||
user_name=data.get("senderName"),
|
||||
)
|
||||
|
||||
|
||||
# Download image media URLs to the local cache so the vision tool
|
||||
# can access them reliably regardless of URL expiration.
|
||||
raw_urls = data.get("mediaUrls", [])
|
||||
@@ -627,7 +622,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
cached_urls.append(url)
|
||||
media_types.append("unknown")
|
||||
|
||||
|
||||
return MessageEvent(
|
||||
text=data.get("body", ""),
|
||||
message_type=msg_type,
|
||||
@@ -640,3 +635,4 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error building event: {e}")
|
||||
return None
|
||||
|
||||
|
||||
854
gateway/run.py
854
gateway/run.py
File diff suppressed because it is too large
Load Diff
@@ -8,20 +8,22 @@ Handles:
|
||||
- Dynamic system prompt injection (agent knows its context)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .config import (
|
||||
GatewayConfig,
|
||||
HomeChannel,
|
||||
Platform,
|
||||
GatewayConfig,
|
||||
SessionResetPolicy,
|
||||
HomeChannel,
|
||||
)
|
||||
|
||||
|
||||
@@ -29,30 +31,29 @@ from .config import (
|
||||
class SessionSource:
|
||||
"""
|
||||
Describes where a message originated from.
|
||||
|
||||
|
||||
This information is used to:
|
||||
1. Route responses back to the right place
|
||||
2. Inject context into the system prompt
|
||||
3. Track origin for cron job delivery
|
||||
"""
|
||||
|
||||
platform: Platform
|
||||
chat_id: str
|
||||
chat_name: str | None = None
|
||||
chat_name: Optional[str] = None
|
||||
chat_type: str = "dm" # "dm", "group", "channel", "thread"
|
||||
user_id: str | None = None
|
||||
user_name: str | None = None
|
||||
thread_id: str | None = None # For forum topics, Discord threads, etc.
|
||||
chat_topic: str | None = None # Channel topic/description (Discord, Slack)
|
||||
user_id_alt: str | None = None # Signal UUID (alternative to phone number)
|
||||
chat_id_alt: str | None = None # Signal group internal ID
|
||||
|
||||
user_id: Optional[str] = None
|
||||
user_name: Optional[str] = None
|
||||
thread_id: Optional[str] = None # For forum topics, Discord threads, etc.
|
||||
chat_topic: Optional[str] = None # Channel topic/description (Discord, Slack)
|
||||
user_id_alt: Optional[str] = None # Signal UUID (alternative to phone number)
|
||||
chat_id_alt: Optional[str] = None # Signal group internal ID
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Human-readable description of the source."""
|
||||
if self.platform == Platform.LOCAL:
|
||||
return "CLI terminal"
|
||||
|
||||
|
||||
parts = []
|
||||
if self.chat_type == "dm":
|
||||
parts.append(f"DM with {self.user_name or self.user_id or 'user'}")
|
||||
@@ -62,13 +63,13 @@ class SessionSource:
|
||||
parts.append(f"channel: {self.chat_name or self.chat_id}")
|
||||
else:
|
||||
parts.append(self.chat_name or self.chat_id)
|
||||
|
||||
|
||||
if self.thread_id:
|
||||
parts.append(f"thread: {self.thread_id}")
|
||||
|
||||
|
||||
return ", ".join(parts)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = {
|
||||
"platform": self.platform.value,
|
||||
"chat_id": self.chat_id,
|
||||
@@ -84,9 +85,9 @@ class SessionSource:
|
||||
if self.chat_id_alt:
|
||||
d["chat_id_alt"] = self.chat_id_alt
|
||||
return d
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionSource":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionSource":
|
||||
return cls(
|
||||
platform=Platform(data["platform"]),
|
||||
chat_id=str(data["chat_id"]),
|
||||
@@ -99,7 +100,7 @@ class SessionSource:
|
||||
user_id_alt=data.get("user_id_alt"),
|
||||
chat_id_alt=data.get("chat_id_alt"),
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def local_cli(cls) -> "SessionSource":
|
||||
"""Create a source representing the local CLI."""
|
||||
@@ -115,28 +116,29 @@ class SessionSource:
|
||||
class SessionContext:
|
||||
"""
|
||||
Full context for a session, used for dynamic system prompt injection.
|
||||
|
||||
|
||||
The agent receives this information to understand:
|
||||
- Where messages are coming from
|
||||
- What platforms are available
|
||||
- Where it can deliver scheduled task outputs
|
||||
"""
|
||||
|
||||
source: SessionSource
|
||||
connected_platforms: list[Platform]
|
||||
home_channels: dict[Platform, HomeChannel]
|
||||
|
||||
connected_platforms: List[Platform]
|
||||
home_channels: Dict[Platform, HomeChannel]
|
||||
|
||||
# Session metadata
|
||||
session_key: str = ""
|
||||
session_id: str = ""
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"source": self.source.to_dict(),
|
||||
"connected_platforms": [p.value for p in self.connected_platforms],
|
||||
"home_channels": {p.value: hc.to_dict() for p, hc in self.home_channels.items()},
|
||||
"home_channels": {
|
||||
p.value: hc.to_dict() for p, hc in self.home_channels.items()
|
||||
},
|
||||
"session_key": self.session_key,
|
||||
"session_id": self.session_id,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
@@ -147,7 +149,7 @@ class SessionContext:
|
||||
def build_session_context_prompt(context: SessionContext) -> str:
|
||||
"""
|
||||
Build the dynamic system prompt section that tells the agent about its context.
|
||||
|
||||
|
||||
This is injected into the system prompt so the agent knows:
|
||||
- Where messages are coming from
|
||||
- What platforms are connected
|
||||
@@ -157,14 +159,14 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
"## Current Session Context",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
# Source info
|
||||
platform_name = context.source.platform.value.title()
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append(f"**Source:** {platform_name} (the machine running this agent)")
|
||||
else:
|
||||
lines.append(f"**Source:** {platform_name} ({context.source.description})")
|
||||
|
||||
|
||||
# Channel topic (if available - provides context about the channel's purpose)
|
||||
if context.source.chat_topic:
|
||||
lines.append(f"**Channel Topic:** {context.source.chat_topic}")
|
||||
@@ -174,43 +176,43 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
lines.append(f"**User:** {context.source.user_name}")
|
||||
elif context.source.user_id:
|
||||
lines.append(f"**User ID:** {context.source.user_id}")
|
||||
|
||||
|
||||
# Connected platforms
|
||||
platforms_list = ["local (files on this machine)"]
|
||||
for p in context.connected_platforms:
|
||||
if p != Platform.LOCAL:
|
||||
platforms_list.append(f"{p.value}: Connected ✓")
|
||||
|
||||
|
||||
lines.append(f"**Connected Platforms:** {', '.join(platforms_list)}")
|
||||
|
||||
|
||||
# Home channels
|
||||
if context.home_channels:
|
||||
lines.append("")
|
||||
lines.append("**Home Channels (default destinations):**")
|
||||
for platform, home in context.home_channels.items():
|
||||
lines.append(f" - {platform.value}: {home.name} (ID: {home.chat_id})")
|
||||
|
||||
|
||||
# Delivery options for scheduled tasks
|
||||
lines.append("")
|
||||
lines.append("**Delivery options for scheduled tasks:**")
|
||||
|
||||
|
||||
# Origin delivery
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append('- `"origin"` → Local output (saved to files)')
|
||||
lines.append("- `\"origin\"` → Local output (saved to files)")
|
||||
else:
|
||||
lines.append(f'- `"origin"` → Back to this chat ({context.source.chat_name or context.source.chat_id})')
|
||||
|
||||
lines.append(f"- `\"origin\"` → Back to this chat ({context.source.chat_name or context.source.chat_id})")
|
||||
|
||||
# Local always available
|
||||
lines.append('- `"local"` → Save to local files only (~/.hermes/cron/output/)')
|
||||
|
||||
lines.append("- `\"local\"` → Save to local files only (~/.hermes/cron/output/)")
|
||||
|
||||
# Platform home channels
|
||||
for platform, home in context.home_channels.items():
|
||||
lines.append(f'- `"{platform.value}"` → Home channel ({home.name})')
|
||||
|
||||
lines.append(f"- `\"{platform.value}\"` → Home channel ({home.name})")
|
||||
|
||||
# Note about explicit targeting
|
||||
lines.append("")
|
||||
lines.append('*For explicit targeting, use `"platform:chat_id"` format if the user provides a specific chat ID.*')
|
||||
|
||||
lines.append("*For explicit targeting, use `\"platform:chat_id\"` format if the user provides a specific chat ID.*")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -218,33 +220,32 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
class SessionEntry:
|
||||
"""
|
||||
Entry in the session store.
|
||||
|
||||
|
||||
Maps a session key to its current session ID and metadata.
|
||||
"""
|
||||
|
||||
session_key: str
|
||||
session_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# Origin metadata for delivery routing
|
||||
origin: SessionSource | None = None
|
||||
|
||||
origin: Optional[SessionSource] = None
|
||||
|
||||
# Display metadata
|
||||
display_name: str | None = None
|
||||
platform: Platform | None = None
|
||||
display_name: Optional[str] = None
|
||||
platform: Optional[Platform] = None
|
||||
chat_type: str = "dm"
|
||||
|
||||
|
||||
# Token tracking
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
# Set when a session was created because the previous one expired;
|
||||
# consumed once by the message handler to inject a notice into context
|
||||
was_auto_reset: bool = False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"session_key": self.session_key,
|
||||
"session_id": self.session_id,
|
||||
@@ -260,20 +261,20 @@ class SessionEntry:
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
return result
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionEntry":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionEntry":
|
||||
origin = None
|
||||
if "origin" in data and data["origin"]:
|
||||
origin = SessionSource.from_dict(data["origin"])
|
||||
|
||||
|
||||
platform = None
|
||||
if data.get("platform"):
|
||||
try:
|
||||
platform = Platform(data["platform"])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
return cls(
|
||||
session_key=data["session_key"],
|
||||
session_id=data["session_id"],
|
||||
@@ -306,65 +307,66 @@ def build_session_key(source: SessionSource) -> str:
|
||||
class SessionStore:
|
||||
"""
|
||||
Manages session storage and retrieval.
|
||||
|
||||
|
||||
Uses SQLite (via SessionDB) for session metadata and message transcripts.
|
||||
Falls back to legacy JSONL files if SQLite is unavailable.
|
||||
"""
|
||||
|
||||
def __init__(self, sessions_dir: Path, config: GatewayConfig, has_active_processes_fn=None, on_auto_reset=None):
|
||||
|
||||
def __init__(self, sessions_dir: Path, config: GatewayConfig,
|
||||
has_active_processes_fn=None,
|
||||
on_auto_reset=None):
|
||||
self.sessions_dir = sessions_dir
|
||||
self.config = config
|
||||
self._entries: dict[str, SessionEntry] = {}
|
||||
self._entries: Dict[str, SessionEntry] = {}
|
||||
self._loaded = False
|
||||
self._has_active_processes_fn = has_active_processes_fn
|
||||
# on_auto_reset is deprecated — memory flush now runs proactively
|
||||
# via the background session expiry watcher in GatewayRunner.
|
||||
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
|
||||
|
||||
|
||||
# Initialize SQLite session database
|
||||
self._db = None
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
|
||||
self._db = SessionDB()
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: SQLite session store unavailable, falling back to JSONL: {e}")
|
||||
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
"""Load sessions index from disk if not already loaded."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
|
||||
if sessions_file.exists():
|
||||
try:
|
||||
with open(sessions_file, encoding="utf-8") as f:
|
||||
with open(sessions_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
for key, entry_data in data.items():
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load sessions: {e}")
|
||||
|
||||
|
||||
self._loaded = True
|
||||
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Save sessions index to disk (kept for session key -> ID mapping)."""
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
|
||||
data = {key: entry.to_dict() for key, entry in self._entries.items()}
|
||||
with open(sessions_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def _generate_session_key(self, source: SessionSource) -> str:
|
||||
"""Generate a session key from a source."""
|
||||
return build_session_key(source)
|
||||
|
||||
|
||||
def _is_session_expired(self, entry: SessionEntry) -> bool:
|
||||
"""Check if a session has expired based on its reset policy.
|
||||
|
||||
|
||||
Works from the entry alone — no SessionSource needed.
|
||||
Used by the background expiry watcher to proactively flush memories.
|
||||
Sessions with active background processes are never considered expired.
|
||||
@@ -391,9 +393,7 @@ class SessionStore:
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(
|
||||
hour=policy.at_hour,
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
minute=0, second=0, microsecond=0,
|
||||
)
|
||||
if now.hour < policy.at_hour:
|
||||
today_reset -= timedelta(days=1)
|
||||
@@ -405,7 +405,7 @@ class SessionStore:
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
||||
"""
|
||||
Check if a session should be reset based on policy.
|
||||
|
||||
|
||||
Sessions with active background processes are never reset.
|
||||
"""
|
||||
if self._has_active_processes_fn:
|
||||
@@ -413,28 +413,36 @@ class SessionStore:
|
||||
if self._has_active_processes_fn(session_key):
|
||||
return False
|
||||
|
||||
policy = self.config.get_reset_policy(platform=source.platform, session_type=source.chat_type)
|
||||
|
||||
policy = self.config.get_reset_policy(
|
||||
platform=source.platform,
|
||||
session_type=source.chat_type
|
||||
)
|
||||
|
||||
if policy.mode == "none":
|
||||
return False
|
||||
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
if now > idle_deadline:
|
||||
return True
|
||||
|
||||
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(hour=policy.at_hour, minute=0, second=0, microsecond=0)
|
||||
today_reset = now.replace(
|
||||
hour=policy.at_hour,
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=0
|
||||
)
|
||||
if now.hour < policy.at_hour:
|
||||
today_reset -= timedelta(days=1)
|
||||
|
||||
|
||||
if entry.updated_at < today_reset:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def has_any_sessions(self) -> bool:
|
||||
"""Check if any sessions have ever been created (across all platforms).
|
||||
|
||||
@@ -455,22 +463,26 @@ class SessionStore:
|
||||
# This covers the rare case where the DB is unavailable.
|
||||
self._ensure_loaded()
|
||||
return len(self._entries) > 1
|
||||
|
||||
def get_or_create_session(self, source: SessionSource, force_new: bool = False) -> SessionEntry:
|
||||
|
||||
def get_or_create_session(
|
||||
self,
|
||||
source: SessionSource,
|
||||
force_new: bool = False
|
||||
) -> SessionEntry:
|
||||
"""
|
||||
Get an existing session or create a new one.
|
||||
|
||||
|
||||
Evaluates reset policy to determine if the existing session is stale.
|
||||
Creates a session record in SQLite when a new session starts.
|
||||
"""
|
||||
self._ensure_loaded()
|
||||
|
||||
|
||||
session_key = self._generate_session_key(source)
|
||||
now = datetime.now()
|
||||
|
||||
|
||||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
|
||||
if not self._should_reset(entry, source):
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
@@ -488,10 +500,10 @@ class SessionStore:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
else:
|
||||
was_auto_reset = False
|
||||
|
||||
|
||||
# Create new session
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
@@ -503,10 +515,10 @@ class SessionStore:
|
||||
chat_type=source.chat_type,
|
||||
was_auto_reset=was_auto_reset,
|
||||
)
|
||||
|
||||
|
||||
self._entries[session_key] = entry
|
||||
self._save()
|
||||
|
||||
|
||||
# Create session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
@@ -517,13 +529,18 @@ class SessionStore:
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to create SQLite session: {e}")
|
||||
|
||||
|
||||
return entry
|
||||
|
||||
def update_session(self, session_key: str, input_tokens: int = 0, output_tokens: int = 0) -> None:
|
||||
|
||||
def update_session(
|
||||
self,
|
||||
session_key: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0
|
||||
) -> None:
|
||||
"""Update a session's metadata after an interaction."""
|
||||
self._ensure_loaded()
|
||||
|
||||
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = datetime.now()
|
||||
@@ -531,32 +548,34 @@ class SessionStore:
|
||||
entry.output_tokens += output_tokens
|
||||
entry.total_tokens = entry.input_tokens + entry.output_tokens
|
||||
self._save()
|
||||
|
||||
|
||||
if self._db:
|
||||
try:
|
||||
self._db.update_token_counts(entry.session_id, input_tokens, output_tokens)
|
||||
self._db.update_token_counts(
|
||||
entry.session_id, input_tokens, output_tokens
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
def reset_session(self, session_key: str) -> SessionEntry | None:
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
self._ensure_loaded()
|
||||
|
||||
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
|
||||
|
||||
old_entry = self._entries[session_key]
|
||||
|
||||
|
||||
# End old session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
self._db.end_session(old_entry.session_id, "session_reset")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
|
||||
now = datetime.now()
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
@@ -567,10 +586,10 @@ class SessionStore:
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
|
||||
|
||||
# Create new session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
@@ -581,10 +600,10 @@ class SessionStore:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
|
||||
return new_entry
|
||||
|
||||
def switch_session(self, session_key: str, target_session_id: str) -> SessionEntry | None:
|
||||
def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]:
|
||||
"""Switch a session key to point at an existing session ID.
|
||||
|
||||
Used by ``/resume`` to restore a previously-named session.
|
||||
@@ -626,25 +645,25 @@ class SessionStore:
|
||||
self._save()
|
||||
return new_entry
|
||||
|
||||
def list_sessions(self, active_minutes: int | None = None) -> list[SessionEntry]:
|
||||
def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
|
||||
"""List all sessions, optionally filtered by activity."""
|
||||
self._ensure_loaded()
|
||||
|
||||
|
||||
entries = list(self._entries.values())
|
||||
|
||||
|
||||
if active_minutes is not None:
|
||||
cutoff = datetime.now() - timedelta(minutes=active_minutes)
|
||||
entries = [e for e in entries if e.updated_at >= cutoff]
|
||||
|
||||
|
||||
entries.sort(key=lambda e: e.updated_at, reverse=True)
|
||||
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def get_transcript_path(self, session_id: str) -> Path:
|
||||
"""Get the path to a session's legacy transcript file."""
|
||||
return self.sessions_dir / f"{session_id}.jsonl"
|
||||
|
||||
def append_to_transcript(self, session_id: str, message: dict[str, Any]) -> None:
|
||||
|
||||
def append_to_transcript(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
"""Append a message to a session's transcript (SQLite + legacy JSONL)."""
|
||||
# Write to SQLite
|
||||
if self._db:
|
||||
@@ -659,15 +678,15 @@ class SessionStore:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
|
||||
# Also write legacy JSONL (keeps existing tooling working during transition)
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
|
||||
def rewrite_transcript(self, session_id: str, messages: list[dict[str, Any]]) -> None:
|
||||
|
||||
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Replace the entire transcript for a session with new messages.
|
||||
|
||||
|
||||
Used by /retry, /undo, and /compress to persist modified conversation history.
|
||||
Rewrites both SQLite and legacy JSONL storage.
|
||||
"""
|
||||
@@ -686,14 +705,14 @@ class SessionStore:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
|
||||
|
||||
# JSONL: overwrite the file
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "w", encoding="utf-8") as f:
|
||||
for msg in messages:
|
||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||
|
||||
def load_transcript(self, session_id: str) -> list[dict[str, Any]]:
|
||||
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""Load all messages from a session's transcript."""
|
||||
# Try SQLite first
|
||||
if self._db:
|
||||
@@ -703,49 +722,51 @@ class SessionStore:
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.debug("Could not load messages from DB: %s", e)
|
||||
|
||||
|
||||
# Fall back to legacy JSONL
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
|
||||
|
||||
if not transcript_path.exists():
|
||||
return []
|
||||
|
||||
|
||||
messages = []
|
||||
with open(transcript_path, encoding="utf-8") as f:
|
||||
with open(transcript_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
messages.append(json.loads(line))
|
||||
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def build_session_context(
|
||||
source: SessionSource, config: GatewayConfig, session_entry: SessionEntry | None = None
|
||||
source: SessionSource,
|
||||
config: GatewayConfig,
|
||||
session_entry: Optional[SessionEntry] = None
|
||||
) -> SessionContext:
|
||||
"""
|
||||
Build a full session context from a source and config.
|
||||
|
||||
|
||||
This is used to inject context into the agent's system prompt.
|
||||
"""
|
||||
connected = config.get_connected_platforms()
|
||||
|
||||
|
||||
home_channels = {}
|
||||
for platform in connected:
|
||||
home = config.get_home_channel(platform)
|
||||
if home:
|
||||
home_channels[platform] = home
|
||||
|
||||
|
||||
context = SessionContext(
|
||||
source=source,
|
||||
connected_platforms=connected,
|
||||
home_channels=home_channels,
|
||||
)
|
||||
|
||||
|
||||
if session_entry:
|
||||
context.session_key = session_entry.session_key
|
||||
context.session_id = session_entry.session_id
|
||||
context.created_at = session_entry.created_at
|
||||
context.updated_at = session_entry.updated_at
|
||||
|
||||
|
||||
return context
|
||||
|
||||
@@ -13,6 +13,7 @@ concurrently under distinct configurations).
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _get_pid_path() -> Path:
|
||||
@@ -36,7 +37,7 @@ def remove_pid_file() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def get_running_pid() -> int | None:
|
||||
def get_running_pid() -> Optional[int]:
|
||||
"""Return the PID of a running gateway instance, or ``None``.
|
||||
|
||||
Checks the PID file and verifies the process is actually alive.
|
||||
|
||||
@@ -12,6 +12,8 @@ import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
CACHE_PATH = Path(os.path.expanduser("~/.hermes/sticker_cache.json"))
|
||||
|
||||
@@ -41,7 +43,7 @@ def _save_cache(cache: dict) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_cached_description(file_unique_id: str) -> dict | None:
|
||||
def get_cached_description(file_unique_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Look up a cached sticker description.
|
||||
|
||||
@@ -90,11 +92,11 @@ def build_sticker_injection(
|
||||
"""
|
||||
context = ""
|
||||
if set_name and emoji:
|
||||
context = f' {emoji} from "{set_name}"'
|
||||
context = f" {emoji} from \"{set_name}\""
|
||||
elif emoji:
|
||||
context = f" {emoji}"
|
||||
|
||||
return f'[The user sent a sticker{context}~ It shows: "{description}" (=^.w.^=)]'
|
||||
return f"[The user sent a sticker{context}~ It shows: \"{description}\" (=^.w.^=)]"
|
||||
|
||||
|
||||
def build_animated_sticker_injection(emoji: str = "") -> str:
|
||||
|
||||
@@ -5,7 +5,7 @@ Provides subcommands for:
|
||||
- hermes chat - Interactive chat (same as ./hermes)
|
||||
- hermes gateway - Run gateway in foreground
|
||||
- hermes gateway start - Start gateway service
|
||||
- hermes gateway stop - Stop gateway service
|
||||
- hermes gateway stop - Stop gateway service
|
||||
- hermes setup - Interactive setup wizard
|
||||
- hermes status - Show status of all components
|
||||
- hermes cron - Manage cron jobs
|
||||
|
||||
@@ -15,25 +15,27 @@ Architecture:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
import base64
|
||||
import hashlib
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
import webbrowser
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
from hermes_cli.config import get_config_path, get_hermes_home
|
||||
from hermes_cli.config import get_hermes_home, get_config_path
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -56,8 +58,8 @@ DEFAULT_NOUS_INFERENCE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
DEFAULT_NOUS_CLIENT_ID = "hermes-cli"
|
||||
DEFAULT_NOUS_SCOPE = "inference:mint_agent_key"
|
||||
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
|
||||
DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s
|
||||
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
|
||||
DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s
|
||||
DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
@@ -68,11 +70,9 @@ CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
# Provider Registry
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderConfig:
|
||||
"""Describes a known inference provider."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
auth_type: str # "oauth_device_code", "oauth_external", or "api_key"
|
||||
@@ -80,14 +80,14 @@ class ProviderConfig:
|
||||
inference_base_url: str = ""
|
||||
client_id: str = ""
|
||||
scope: str = ""
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
# For API-key providers: env vars to check (in priority order)
|
||||
api_key_env_vars: tuple = ()
|
||||
# Optional env var for base URL override
|
||||
base_url_env_var: str = ""
|
||||
|
||||
|
||||
PROVIDER_REGISTRY: dict[str, ProviderConfig] = {
|
||||
PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
"nous": ProviderConfig(
|
||||
id="nous",
|
||||
name="Nous Portal",
|
||||
@@ -172,14 +172,14 @@ def _resolve_kimi_base_url(api_key: str, default_url: str, env_override: str) ->
|
||||
|
||||
ZAI_ENDPOINTS = [
|
||||
# (id, base_url, default_model, label)
|
||||
("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
|
||||
("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
|
||||
("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
|
||||
("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
|
||||
("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
|
||||
("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
|
||||
("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
|
||||
("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
|
||||
]
|
||||
|
||||
|
||||
def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> dict[str, str] | None:
|
||||
def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str, str]]:
|
||||
"""Probe z.ai endpoints to find one that accepts this API key.
|
||||
|
||||
Returns {"id": ..., "base_url": ..., "model": ..., "label": ...} for the
|
||||
@@ -219,7 +219,6 @@ def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> dict[str, str] |
|
||||
# Error Types
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AuthError(RuntimeError):
|
||||
"""Structured auth error with UX mapping hints."""
|
||||
|
||||
@@ -228,7 +227,7 @@ class AuthError(RuntimeError):
|
||||
message: str,
|
||||
*,
|
||||
provider: str = "",
|
||||
code: str | None = None,
|
||||
code: Optional[str] = None,
|
||||
relogin_required: bool = False,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
@@ -246,10 +245,16 @@ def format_auth_error(error: Exception) -> str:
|
||||
return f"{error} Run `hermes model` to re-authenticate."
|
||||
|
||||
if error.code == "subscription_required":
|
||||
return "No active paid subscription found on Nous Portal. Please purchase/activate a subscription, then retry."
|
||||
return (
|
||||
"No active paid subscription found on Nous Portal. "
|
||||
"Please purchase/activate a subscription, then retry."
|
||||
)
|
||||
|
||||
if error.code == "insufficient_credits":
|
||||
return "Subscription credits are exhausted. Top up/renew credits in Nous Portal, then retry."
|
||||
return (
|
||||
"Subscription credits are exhausted. "
|
||||
"Top up/renew credits in Nous Portal, then retry."
|
||||
)
|
||||
|
||||
if error.code == "temporarily_unavailable":
|
||||
return f"{error} Please retry in a few seconds."
|
||||
@@ -257,7 +262,7 @@ def format_auth_error(error: Exception) -> str:
|
||||
return str(error)
|
||||
|
||||
|
||||
def _token_fingerprint(token: Any) -> str | None:
|
||||
def _token_fingerprint(token: Any) -> Optional[str]:
|
||||
"""Return a short hash fingerprint for telemetry without leaking token bytes."""
|
||||
if not isinstance(token, str):
|
||||
return None
|
||||
@@ -272,10 +277,10 @@ def _oauth_trace_enabled() -> bool:
|
||||
return raw in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _oauth_trace(event: str, *, sequence_id: str | None = None, **fields: Any) -> None:
|
||||
def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any) -> None:
|
||||
if not _oauth_trace_enabled():
|
||||
return
|
||||
payload: dict[str, Any] = {"event": event}
|
||||
payload: Dict[str, Any] = {"event": event}
|
||||
if sequence_id:
|
||||
payload["sequence_id"] = sequence_id
|
||||
payload.update(fields)
|
||||
@@ -286,7 +291,6 @@ def _oauth_trace(event: str, *, sequence_id: str | None = None, **fields: Any) -
|
||||
# Auth Store — persistence layer for ~/.hermes/auth.json
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _auth_file_path() -> Path:
|
||||
return get_hermes_home() / "auth.json"
|
||||
|
||||
@@ -322,7 +326,7 @@ def _auth_store_lock(timeout_seconds: float = AUTH_LOCK_TIMEOUT_SECONDS):
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
|
||||
def _load_auth_store(auth_file: Path | None = None) -> dict[str, Any]:
|
||||
def _load_auth_store(auth_file: Optional[Path] = None) -> Dict[str, Any]:
|
||||
auth_file = auth_file or _auth_file_path()
|
||||
if not auth_file.exists():
|
||||
return {"version": AUTH_STORE_VERSION, "providers": {}}
|
||||
@@ -341,16 +345,17 @@ def _load_auth_store(auth_file: Path | None = None) -> dict[str, Any]:
|
||||
providers = {}
|
||||
if "nous_portal" in systems:
|
||||
providers["nous"] = systems["nous_portal"]
|
||||
return {"version": AUTH_STORE_VERSION, "providers": providers, "active_provider": "nous" if providers else None}
|
||||
return {"version": AUTH_STORE_VERSION, "providers": providers,
|
||||
"active_provider": "nous" if providers else None}
|
||||
|
||||
return {"version": AUTH_STORE_VERSION, "providers": {}}
|
||||
|
||||
|
||||
def _save_auth_store(auth_store: dict[str, Any]) -> Path:
|
||||
def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
|
||||
auth_file = _auth_file_path()
|
||||
auth_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
auth_store["version"] = AUTH_STORE_VERSION
|
||||
auth_store["updated_at"] = datetime.now(UTC).isoformat()
|
||||
auth_store["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
payload = json.dumps(auth_store, indent=2) + "\n"
|
||||
tmp_path = auth_file.with_name(f"{auth_file.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
|
||||
try:
|
||||
@@ -382,7 +387,7 @@ def _save_auth_store(auth_store: dict[str, Any]) -> Path:
|
||||
return auth_file
|
||||
|
||||
|
||||
def _load_provider_state(auth_store: dict[str, Any], provider_id: str) -> dict[str, Any] | None:
|
||||
def _load_provider_state(auth_store: Dict[str, Any], provider_id: str) -> Optional[Dict[str, Any]]:
|
||||
providers = auth_store.get("providers")
|
||||
if not isinstance(providers, dict):
|
||||
return None
|
||||
@@ -390,7 +395,7 @@ def _load_provider_state(auth_store: dict[str, Any], provider_id: str) -> dict[s
|
||||
return dict(state) if isinstance(state, dict) else None
|
||||
|
||||
|
||||
def _save_provider_state(auth_store: dict[str, Any], provider_id: str, state: dict[str, Any]) -> None:
|
||||
def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Dict[str, Any]) -> None:
|
||||
providers = auth_store.setdefault("providers", {})
|
||||
if not isinstance(providers, dict):
|
||||
auth_store["providers"] = {}
|
||||
@@ -399,19 +404,19 @@ def _save_provider_state(auth_store: dict[str, Any], provider_id: str, state: di
|
||||
auth_store["active_provider"] = provider_id
|
||||
|
||||
|
||||
def get_provider_auth_state(provider_id: str) -> dict[str, Any] | None:
|
||||
def get_provider_auth_state(provider_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return persisted auth state for a provider, or None."""
|
||||
auth_store = _load_auth_store()
|
||||
return _load_provider_state(auth_store, provider_id)
|
||||
|
||||
|
||||
def get_active_provider() -> str | None:
|
||||
def get_active_provider() -> Optional[str]:
|
||||
"""Return the currently active provider ID from auth store."""
|
||||
auth_store = _load_auth_store()
|
||||
return auth_store.get("active_provider")
|
||||
|
||||
|
||||
def clear_provider_auth(provider_id: str | None = None) -> bool:
|
||||
def clear_provider_auth(provider_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Clear auth state for a provider. Used by `hermes logout`.
|
||||
If provider_id is None, clears the active provider.
|
||||
@@ -450,12 +455,11 @@ def deactivate_provider() -> None:
|
||||
# Provider Resolution — picks which provider to use
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def resolve_provider(
|
||||
requested: str | None = None,
|
||||
requested: Optional[str] = None,
|
||||
*,
|
||||
explicit_api_key: str | None = None,
|
||||
explicit_base_url: str | None = None,
|
||||
explicit_api_key: Optional[str] = None,
|
||||
explicit_base_url: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Determine which inference provider to use.
|
||||
@@ -471,14 +475,9 @@ def resolve_provider(
|
||||
|
||||
# Normalize provider aliases
|
||||
_PROVIDER_ALIASES = {
|
||||
"glm": "zai",
|
||||
"z-ai": "zai",
|
||||
"z.ai": "zai",
|
||||
"zhipu": "zai",
|
||||
"kimi": "kimi-coding",
|
||||
"moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn",
|
||||
"minimax_cn": "minimax-cn",
|
||||
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
|
||||
"kimi": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
|
||||
}
|
||||
normalized = _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
@@ -525,8 +524,7 @@ def resolve_provider(
|
||||
# Timestamp / TTL helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _parse_iso_timestamp(value: Any) -> float | None:
|
||||
def _parse_iso_timestamp(value: Any) -> Optional[float]:
|
||||
if not isinstance(value, str) or not value:
|
||||
return None
|
||||
text = value.strip()
|
||||
@@ -539,7 +537,7 @@ def _parse_iso_timestamp(value: Any) -> float | None:
|
||||
except Exception:
|
||||
return None
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=UTC)
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed.timestamp()
|
||||
|
||||
|
||||
@@ -558,14 +556,14 @@ def _coerce_ttl_seconds(expires_in: Any) -> int:
|
||||
return max(0, ttl)
|
||||
|
||||
|
||||
def _optional_base_url(value: Any) -> str | None:
|
||||
def _optional_base_url(value: Any) -> Optional[str]:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
cleaned = value.strip().rstrip("/")
|
||||
return cleaned if cleaned else None
|
||||
|
||||
|
||||
def _decode_jwt_claims(token: Any) -> dict[str, Any]:
|
||||
def _decode_jwt_claims(token: Any) -> Dict[str, Any]:
|
||||
if not isinstance(token, str) or token.count(".") != 2:
|
||||
return {}
|
||||
payload = token.split(".")[1]
|
||||
@@ -590,7 +588,6 @@ def _codex_access_token_is_expiring(access_token: Any, skew_seconds: int) -> boo
|
||||
# SSH / remote session detection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _is_remote_session() -> bool:
|
||||
"""Detect if running in an SSH session where webbrowser.open() won't work."""
|
||||
return bool(os.getenv("SSH_CLIENT") or os.getenv("SSH_TTY"))
|
||||
@@ -604,10 +601,9 @@ def _is_remote_session() -> bool:
|
||||
# where one app's refresh invalidates the other's session.
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _read_codex_tokens(*, _lock: bool = True) -> dict[str, Any]:
|
||||
def _read_codex_tokens(*, _lock: bool = True) -> Dict[str, Any]:
|
||||
"""Read Codex OAuth tokens from Hermes auth store (~/.hermes/auth.json).
|
||||
|
||||
|
||||
Returns dict with 'tokens' (access_token, refresh_token) and 'last_refresh'.
|
||||
Raises AuthError if no Codex tokens are stored.
|
||||
"""
|
||||
@@ -654,10 +650,10 @@ def _read_codex_tokens(*, _lock: bool = True) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _save_codex_tokens(tokens: dict[str, str], last_refresh: str = None) -> None:
|
||||
def _save_codex_tokens(tokens: Dict[str, str], last_refresh: str = None) -> None:
|
||||
"""Save Codex OAuth tokens to Hermes auth store (~/.hermes/auth.json)."""
|
||||
if last_refresh is None:
|
||||
last_refresh = datetime.now(UTC).isoformat().replace("+00:00", "Z")
|
||||
last_refresh = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
state = _load_provider_state(auth_store, "openai-codex") or {}
|
||||
@@ -669,11 +665,11 @@ def _save_codex_tokens(tokens: dict[str, str], last_refresh: str = None) -> None
|
||||
|
||||
|
||||
def _refresh_codex_auth_tokens(
|
||||
tokens: dict[str, str],
|
||||
tokens: Dict[str, str],
|
||||
timeout_seconds: float,
|
||||
) -> dict[str, str]:
|
||||
) -> Dict[str, str]:
|
||||
"""Refresh Codex access token using the refresh token.
|
||||
|
||||
|
||||
Saves the new tokens to Hermes auth store automatically.
|
||||
"""
|
||||
refresh_token = tokens.get("refresh_token")
|
||||
@@ -750,9 +746,9 @@ def _refresh_codex_auth_tokens(
|
||||
return updated_tokens
|
||||
|
||||
|
||||
def _import_codex_cli_tokens() -> dict[str, str] | None:
|
||||
def _import_codex_cli_tokens() -> Optional[Dict[str, str]]:
|
||||
"""Try to read tokens from ~/.codex/auth.json (Codex CLI shared file).
|
||||
|
||||
|
||||
Returns tokens dict if valid, None otherwise. Does NOT write to the shared file.
|
||||
"""
|
||||
codex_home = os.getenv("CODEX_HOME", "").strip()
|
||||
@@ -778,7 +774,7 @@ def resolve_codex_runtime_credentials(
|
||||
force_refresh: bool = False,
|
||||
refresh_if_expiring: bool = True,
|
||||
refresh_skew_seconds: int = CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
"""Resolve runtime credentials from Hermes's own Codex token store."""
|
||||
try:
|
||||
data = _read_codex_tokens()
|
||||
@@ -821,7 +817,10 @@ def resolve_codex_runtime_credentials(
|
||||
tokens = _refresh_codex_auth_tokens(tokens, refresh_timeout_seconds)
|
||||
access_token = str(tokens.get("access_token", "") or "").strip()
|
||||
|
||||
base_url = os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") or DEFAULT_CODEX_BASE_URL
|
||||
base_url = (
|
||||
os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/")
|
||||
or DEFAULT_CODEX_BASE_URL
|
||||
)
|
||||
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
@@ -837,19 +836,24 @@ def resolve_codex_runtime_credentials(
|
||||
# TLS verification helper
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _resolve_verify(
|
||||
*,
|
||||
insecure: bool | None = None,
|
||||
ca_bundle: str | None = None,
|
||||
auth_state: dict[str, Any] | None = None,
|
||||
insecure: Optional[bool] = None,
|
||||
ca_bundle: Optional[str] = None,
|
||||
auth_state: Optional[Dict[str, Any]] = None,
|
||||
) -> bool | str:
|
||||
tls_state = auth_state.get("tls") if isinstance(auth_state, dict) else {}
|
||||
tls_state = tls_state if isinstance(tls_state, dict) else {}
|
||||
|
||||
effective_insecure = bool(insecure) if insecure is not None else bool(tls_state.get("insecure", False))
|
||||
effective_insecure = (
|
||||
bool(insecure) if insecure is not None
|
||||
else bool(tls_state.get("insecure", False))
|
||||
)
|
||||
effective_ca = (
|
||||
ca_bundle or tls_state.get("ca_bundle") or os.getenv("HERMES_CA_BUNDLE") or os.getenv("SSL_CERT_FILE")
|
||||
ca_bundle
|
||||
or tls_state.get("ca_bundle")
|
||||
or os.getenv("HERMES_CA_BUNDLE")
|
||||
or os.getenv("SSL_CERT_FILE")
|
||||
)
|
||||
|
||||
if effective_insecure:
|
||||
@@ -863,13 +867,12 @@ def _resolve_verify(
|
||||
# OAuth Device Code Flow — generic, parameterized by provider
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _request_device_code(
|
||||
client: httpx.Client,
|
||||
portal_base_url: str,
|
||||
client_id: str,
|
||||
scope: str | None,
|
||||
) -> dict[str, Any]:
|
||||
scope: Optional[str],
|
||||
) -> Dict[str, Any]:
|
||||
"""POST to the device code endpoint. Returns device_code, user_code, etc."""
|
||||
response = client.post(
|
||||
f"{portal_base_url}/api/oauth/device/code",
|
||||
@@ -882,12 +885,8 @@ def _request_device_code(
|
||||
data = response.json()
|
||||
|
||||
required_fields = [
|
||||
"device_code",
|
||||
"user_code",
|
||||
"verification_uri",
|
||||
"verification_uri_complete",
|
||||
"expires_in",
|
||||
"interval",
|
||||
"device_code", "user_code", "verification_uri",
|
||||
"verification_uri_complete", "expires_in", "interval",
|
||||
]
|
||||
missing = [f for f in required_fields if f not in data]
|
||||
if missing:
|
||||
@@ -902,7 +901,7 @@ def _poll_for_token(
|
||||
device_code: str,
|
||||
expires_in: int,
|
||||
poll_interval: int,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
"""Poll the token endpoint until the user approves or the code expires."""
|
||||
deadline = time.time() + max(1, expires_in)
|
||||
current_interval = max(1, min(poll_interval, DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS))
|
||||
@@ -948,14 +947,13 @@ def _poll_for_token(
|
||||
# Nous Portal — token refresh, agent key minting, model discovery
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _refresh_access_token(
|
||||
*,
|
||||
client: httpx.Client,
|
||||
portal_base_url: str,
|
||||
client_id: str,
|
||||
refresh_token: str,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
response = client.post(
|
||||
f"{portal_base_url}/api/oauth/token",
|
||||
data={
|
||||
@@ -968,15 +966,15 @@ def _refresh_access_token(
|
||||
if response.status_code == 200:
|
||||
payload = response.json()
|
||||
if "access_token" not in payload:
|
||||
raise AuthError(
|
||||
"Refresh response missing access_token", provider="nous", code="invalid_token", relogin_required=True
|
||||
)
|
||||
raise AuthError("Refresh response missing access_token",
|
||||
provider="nous", code="invalid_token", relogin_required=True)
|
||||
return payload
|
||||
|
||||
try:
|
||||
error_payload = response.json()
|
||||
except Exception as exc:
|
||||
raise AuthError("Refresh token exchange failed", provider="nous", relogin_required=True) from exc
|
||||
raise AuthError("Refresh token exchange failed",
|
||||
provider="nous", relogin_required=True) from exc
|
||||
|
||||
code = str(error_payload.get("error", "invalid_grant"))
|
||||
description = str(error_payload.get("error_description") or "Refresh token exchange failed")
|
||||
@@ -990,7 +988,7 @@ def _mint_agent_key(
|
||||
portal_base_url: str,
|
||||
access_token: str,
|
||||
min_ttl_seconds: int,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
"""Mint (or reuse) a short-lived inference API key."""
|
||||
response = client.post(
|
||||
f"{portal_base_url}/api/oauth/agent-key",
|
||||
@@ -1001,13 +999,15 @@ def _mint_agent_key(
|
||||
if response.status_code == 200:
|
||||
payload = response.json()
|
||||
if "api_key" not in payload:
|
||||
raise AuthError("Mint response missing api_key", provider="nous", code="server_error")
|
||||
raise AuthError("Mint response missing api_key",
|
||||
provider="nous", code="server_error")
|
||||
return payload
|
||||
|
||||
try:
|
||||
error_payload = response.json()
|
||||
except Exception as exc:
|
||||
raise AuthError("Agent key mint request failed", provider="nous", code="server_error") from exc
|
||||
raise AuthError("Agent key mint request failed",
|
||||
provider="nous", code="server_error") from exc
|
||||
|
||||
code = str(error_payload.get("error", "server_error"))
|
||||
description = str(error_payload.get("error_description") or "Agent key mint request failed")
|
||||
@@ -1021,7 +1021,7 @@ def fetch_nous_models(
|
||||
api_key: str,
|
||||
timeout_seconds: float = 15.0,
|
||||
verify: bool | str = True,
|
||||
) -> list[str]:
|
||||
) -> List[str]:
|
||||
"""Fetch available model IDs from the Nous inference API."""
|
||||
timeout = httpx.Timeout(timeout_seconds)
|
||||
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
|
||||
@@ -1044,7 +1044,7 @@ def fetch_nous_models(
|
||||
if not isinstance(data, list):
|
||||
return []
|
||||
|
||||
model_ids: list[str] = []
|
||||
model_ids: List[str] = []
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
@@ -1059,7 +1059,7 @@ def fetch_nous_models(
|
||||
return list(dict.fromkeys(model_ids))
|
||||
|
||||
|
||||
def _agent_key_is_usable(state: dict[str, Any], min_ttl_seconds: int) -> bool:
|
||||
def _agent_key_is_usable(state: Dict[str, Any], min_ttl_seconds: int) -> bool:
|
||||
key = state.get("agent_key")
|
||||
if not isinstance(key, str) or not key.strip():
|
||||
return False
|
||||
@@ -1070,10 +1070,10 @@ def resolve_nous_runtime_credentials(
|
||||
*,
|
||||
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
timeout_seconds: float = 15.0,
|
||||
insecure: bool | None = None,
|
||||
ca_bundle: str | None = None,
|
||||
insecure: Optional[bool] = None,
|
||||
ca_bundle: Optional[str] = None,
|
||||
force_mint: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Resolve Nous inference credentials for runtime use.
|
||||
|
||||
@@ -1092,7 +1092,8 @@ def resolve_nous_runtime_credentials(
|
||||
state = _load_provider_state(auth_store, "nous")
|
||||
|
||||
if not state:
|
||||
raise AuthError("Hermes is not logged into Nous Portal.", provider="nous", relogin_required=True)
|
||||
raise AuthError("Hermes is not logged into Nous Portal.",
|
||||
provider="nous", relogin_required=True)
|
||||
|
||||
portal_base_url = (
|
||||
_optional_base_url(state.get("portal_base_url"))
|
||||
@@ -1142,14 +1143,14 @@ def resolve_nous_runtime_credentials(
|
||||
refresh_token = state.get("refresh_token")
|
||||
|
||||
if not isinstance(access_token, str) or not access_token:
|
||||
raise AuthError("No access token found for Nous Portal login.", provider="nous", relogin_required=True)
|
||||
raise AuthError("No access token found for Nous Portal login.",
|
||||
provider="nous", relogin_required=True)
|
||||
|
||||
# Step 1: refresh access token if expiring
|
||||
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
|
||||
if not isinstance(refresh_token, str) or not refresh_token:
|
||||
raise AuthError(
|
||||
"Session expired and no refresh token is available.", provider="nous", relogin_required=True
|
||||
)
|
||||
raise AuthError("Session expired and no refresh token is available.",
|
||||
provider="nous", relogin_required=True)
|
||||
|
||||
_oauth_trace(
|
||||
"refresh_start",
|
||||
@@ -1158,12 +1159,10 @@ def resolve_nous_runtime_credentials(
|
||||
refresh_token_fp=_token_fingerprint(refresh_token),
|
||||
)
|
||||
refreshed = _refresh_access_token(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
refresh_token=refresh_token,
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, refresh_token=refresh_token,
|
||||
)
|
||||
now = datetime.now(UTC)
|
||||
now = datetime.now(timezone.utc)
|
||||
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
|
||||
previous_refresh_token = refresh_token
|
||||
state["access_token"] = refreshed["access_token"]
|
||||
@@ -1175,7 +1174,9 @@ def resolve_nous_runtime_credentials(
|
||||
inference_base_url = refreshed_url
|
||||
state["obtained_at"] = now.isoformat()
|
||||
state["expires_in"] = access_ttl
|
||||
state["expires_at"] = datetime.fromtimestamp(now.timestamp() + access_ttl, tz=UTC).isoformat()
|
||||
state["expires_at"] = datetime.fromtimestamp(
|
||||
now.timestamp() + access_ttl, tz=timezone.utc
|
||||
).isoformat()
|
||||
access_token = state["access_token"]
|
||||
refresh_token = state["refresh_token"]
|
||||
_oauth_trace(
|
||||
@@ -1190,7 +1191,7 @@ def resolve_nous_runtime_credentials(
|
||||
|
||||
# Step 2: mint agent key if missing/expiring
|
||||
used_cached_key = False
|
||||
mint_payload: dict[str, Any] | None = None
|
||||
mint_payload: Optional[Dict[str, Any]] = None
|
||||
|
||||
if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
|
||||
used_cached_key = True
|
||||
@@ -1203,10 +1204,8 @@ def resolve_nous_runtime_credentials(
|
||||
access_token_fp=_token_fingerprint(access_token),
|
||||
)
|
||||
mint_payload = _mint_agent_key(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
access_token=access_token,
|
||||
min_ttl_seconds=min_key_ttl_seconds,
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
|
||||
)
|
||||
except AuthError as exc:
|
||||
_oauth_trace(
|
||||
@@ -1228,12 +1227,10 @@ def resolve_nous_runtime_credentials(
|
||||
refresh_token_fp=_token_fingerprint(latest_refresh_token),
|
||||
)
|
||||
refreshed = _refresh_access_token(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
refresh_token=latest_refresh_token,
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, refresh_token=latest_refresh_token,
|
||||
)
|
||||
now = datetime.now(UTC)
|
||||
now = datetime.now(timezone.utc)
|
||||
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
|
||||
state["access_token"] = refreshed["access_token"]
|
||||
state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
|
||||
@@ -1244,7 +1241,9 @@ def resolve_nous_runtime_credentials(
|
||||
inference_base_url = refreshed_url
|
||||
state["obtained_at"] = now.isoformat()
|
||||
state["expires_in"] = access_ttl
|
||||
state["expires_at"] = datetime.fromtimestamp(now.timestamp() + access_ttl, tz=UTC).isoformat()
|
||||
state["expires_at"] = datetime.fromtimestamp(
|
||||
now.timestamp() + access_ttl, tz=timezone.utc
|
||||
).isoformat()
|
||||
access_token = state["access_token"]
|
||||
refresh_token = state["refresh_token"]
|
||||
_oauth_trace(
|
||||
@@ -1258,16 +1257,14 @@ def resolve_nous_runtime_credentials(
|
||||
_persist_state("post_refresh_mint_retry")
|
||||
|
||||
mint_payload = _mint_agent_key(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
access_token=access_token,
|
||||
min_ttl_seconds=min_key_ttl_seconds,
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
if mint_payload is not None:
|
||||
now = datetime.now(UTC)
|
||||
now = datetime.now(timezone.utc)
|
||||
state["agent_key"] = mint_payload.get("api_key")
|
||||
state["agent_key_id"] = mint_payload.get("key_id")
|
||||
state["agent_key_expires_at"] = mint_payload.get("expires_at")
|
||||
@@ -1296,7 +1293,8 @@ def resolve_nous_runtime_credentials(
|
||||
|
||||
api_key = state.get("agent_key")
|
||||
if not isinstance(api_key, str) or not api_key:
|
||||
raise AuthError("Failed to resolve a Nous inference API key", provider="nous", code="server_error")
|
||||
raise AuthError("Failed to resolve a Nous inference API key",
|
||||
provider="nous", code="server_error")
|
||||
|
||||
expires_at = state.get("agent_key_expires_at")
|
||||
expires_epoch = _parse_iso_timestamp(expires_at)
|
||||
@@ -1321,8 +1319,7 @@ def resolve_nous_runtime_credentials(
|
||||
# Status helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_nous_auth_status() -> dict[str, Any]:
|
||||
def get_nous_auth_status() -> Dict[str, Any]:
|
||||
"""Status snapshot for `hermes status` output."""
|
||||
state = get_provider_auth_state("nous")
|
||||
if not state:
|
||||
@@ -1344,7 +1341,7 @@ def get_nous_auth_status() -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_codex_auth_status() -> dict[str, Any]:
|
||||
def get_codex_auth_status() -> Dict[str, Any]:
|
||||
"""Status snapshot for Codex auth."""
|
||||
try:
|
||||
creds = resolve_codex_runtime_credentials()
|
||||
@@ -1363,7 +1360,7 @@ def get_codex_auth_status() -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_api_key_provider_status(provider_id: str) -> dict[str, Any]:
|
||||
def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
|
||||
"""Status snapshot for API-key providers (z.ai, Kimi, MiniMax)."""
|
||||
pconfig = PROVIDER_REGISTRY.get(provider_id)
|
||||
if not pconfig or pconfig.auth_type != "api_key":
|
||||
@@ -1399,7 +1396,7 @@ def get_api_key_provider_status(provider_id: str) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_auth_status(provider_id: str | None = None) -> dict[str, Any]:
|
||||
def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Generic auth status dispatcher."""
|
||||
target = provider_id or get_active_provider()
|
||||
if target == "nous":
|
||||
@@ -1413,7 +1410,7 @@ def get_auth_status(provider_id: str | None = None) -> dict[str, Any]:
|
||||
return {"logged_in": False}
|
||||
|
||||
|
||||
def resolve_api_key_provider_credentials(provider_id: str) -> dict[str, Any]:
|
||||
def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
|
||||
"""Resolve API key and base URL for an API-key provider.
|
||||
|
||||
Returns dict with: provider, api_key, base_url, source.
|
||||
@@ -1458,8 +1455,7 @@ def resolve_api_key_provider_credentials(provider_id: str) -> dict[str, Any]:
|
||||
# External credential detection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def detect_external_credentials() -> list[dict[str, Any]]:
|
||||
def detect_external_credentials() -> List[Dict[str, Any]]:
|
||||
"""Scan for credentials from other CLI tools that Hermes can reuse.
|
||||
|
||||
Returns a list of dicts, each with:
|
||||
@@ -1467,19 +1463,17 @@ def detect_external_credentials() -> list[dict[str, Any]]:
|
||||
- path: str -- filesystem path where creds were found
|
||||
- label: str -- human-friendly description for the setup UI
|
||||
"""
|
||||
found: list[dict[str, Any]] = []
|
||||
found: List[Dict[str, Any]] = []
|
||||
|
||||
# Codex CLI: ~/.codex/auth.json (importable, not shared)
|
||||
cli_tokens = _import_codex_cli_tokens()
|
||||
if cli_tokens:
|
||||
codex_path = Path.home() / ".codex" / "auth.json"
|
||||
found.append(
|
||||
{
|
||||
"provider": "openai-codex",
|
||||
"path": str(codex_path),
|
||||
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes login` to create a separate session",
|
||||
}
|
||||
)
|
||||
found.append({
|
||||
"provider": "openai-codex",
|
||||
"path": str(codex_path),
|
||||
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes login` to create a separate session",
|
||||
})
|
||||
|
||||
return found
|
||||
|
||||
@@ -1488,7 +1482,6 @@ def detect_external_credentials() -> list[dict[str, Any]]:
|
||||
# CLI Commands — login / logout
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _update_config_for_provider(provider_id: str, inference_base_url: str) -> Path:
|
||||
"""Update config.yaml and auth.json to reflect the active provider."""
|
||||
# Set active_provider in auth.json so auto-resolution picks this provider
|
||||
@@ -1501,7 +1494,7 @@ def _update_config_for_provider(provider_id: str, inference_base_url: str) -> Pa
|
||||
config_path = get_config_path()
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
config: Dict[str, Any] = {}
|
||||
if config_path.exists():
|
||||
try:
|
||||
loaded = yaml.safe_load(config_path.read_text()) or {}
|
||||
@@ -1549,7 +1542,7 @@ def _reset_config_provider() -> Path:
|
||||
return config_path
|
||||
|
||||
|
||||
def _prompt_model_selection(model_ids: list[str], current_model: str = "") -> str | None:
|
||||
def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Optional[str]:
|
||||
"""Interactive model selection. Puts current_model first with a marker. Returns chosen model ID or None."""
|
||||
# Reorder: current model first, then the rest (deduplicated)
|
||||
ordered = []
|
||||
@@ -1571,7 +1564,6 @@ def _prompt_model_selection(model_ids: list[str], current_model: str = "") -> st
|
||||
# Try arrow-key menu first, fall back to number input
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
|
||||
choices = [f" {_label(mid)}" for mid in ordered]
|
||||
choices.append(" Enter custom model name")
|
||||
choices.append(" Skip (keep current)")
|
||||
@@ -1629,7 +1621,7 @@ def _prompt_model_selection(model_ids: list[str], current_model: str = "") -> st
|
||||
|
||||
def _save_model_choice(model_id: str) -> None:
|
||||
"""Save the selected model to config.yaml and .env."""
|
||||
from hermes_cli.config import load_config, save_config, save_env_value
|
||||
from hermes_cli.config import save_config, load_config, save_env_value
|
||||
|
||||
config = load_config()
|
||||
# Handle both string and dict model formats
|
||||
@@ -1701,11 +1693,11 @@ def _login_openai_codex(args, pconfig: ProviderConfig) -> None:
|
||||
config_path = _update_config_for_provider("openai-codex", creds.get("base_url", DEFAULT_CODEX_BASE_URL))
|
||||
print()
|
||||
print("Login successful!")
|
||||
print(" Auth state: ~/.hermes/auth.json")
|
||||
print(f" Auth state: ~/.hermes/auth.json")
|
||||
print(f" Config updated: {config_path} (model.provider=openai-codex)")
|
||||
|
||||
|
||||
def _codex_device_code_login() -> dict[str, Any]:
|
||||
def _codex_device_code_login() -> Dict[str, Any]:
|
||||
"""Run the OpenAI device code login flow and return credentials dict."""
|
||||
import time as _time
|
||||
|
||||
@@ -1723,15 +1715,13 @@ def _codex_device_code_login() -> dict[str, Any]:
|
||||
except Exception as exc:
|
||||
raise AuthError(
|
||||
f"Failed to request device code: {exc}",
|
||||
provider="openai-codex",
|
||||
code="device_code_request_failed",
|
||||
provider="openai-codex", code="device_code_request_failed",
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise AuthError(
|
||||
f"Device code request returned status {resp.status_code}.",
|
||||
provider="openai-codex",
|
||||
code="device_code_request_error",
|
||||
provider="openai-codex", code="device_code_request_error",
|
||||
)
|
||||
|
||||
device_data = resp.json()
|
||||
@@ -1742,15 +1732,14 @@ def _codex_device_code_login() -> dict[str, Any]:
|
||||
if not user_code or not device_auth_id:
|
||||
raise AuthError(
|
||||
"Device code response missing required fields.",
|
||||
provider="openai-codex",
|
||||
code="device_code_incomplete",
|
||||
provider="openai-codex", code="device_code_incomplete",
|
||||
)
|
||||
|
||||
# Step 2: Show user the code
|
||||
print("To continue, follow these steps:\n")
|
||||
print(" 1. Open this URL in your browser:")
|
||||
print(f" 1. Open this URL in your browser:")
|
||||
print(f" \033[94m{issuer}/codex/device\033[0m\n")
|
||||
print(" 2. Enter this code:")
|
||||
print(f" 2. Enter this code:")
|
||||
print(f" \033[94m{user_code}\033[0m\n")
|
||||
print("Waiting for sign-in... (press Ctrl+C to cancel)")
|
||||
|
||||
@@ -1777,8 +1766,7 @@ def _codex_device_code_login() -> dict[str, Any]:
|
||||
else:
|
||||
raise AuthError(
|
||||
f"Device auth polling returned status {poll_resp.status_code}.",
|
||||
provider="openai-codex",
|
||||
code="device_code_poll_error",
|
||||
provider="openai-codex", code="device_code_poll_error",
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("\nLogin cancelled.")
|
||||
@@ -1787,8 +1775,7 @@ def _codex_device_code_login() -> dict[str, Any]:
|
||||
if code_resp is None:
|
||||
raise AuthError(
|
||||
"Login timed out after 15 minutes.",
|
||||
provider="openai-codex",
|
||||
code="device_code_timeout",
|
||||
provider="openai-codex", code="device_code_timeout",
|
||||
)
|
||||
|
||||
# Step 4: Exchange authorization code for tokens
|
||||
@@ -1799,8 +1786,7 @@ def _codex_device_code_login() -> dict[str, Any]:
|
||||
if not authorization_code or not code_verifier:
|
||||
raise AuthError(
|
||||
"Device auth response missing authorization_code or code_verifier.",
|
||||
provider="openai-codex",
|
||||
code="device_code_incomplete_exchange",
|
||||
provider="openai-codex", code="device_code_incomplete_exchange",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -1819,15 +1805,13 @@ def _codex_device_code_login() -> dict[str, Any]:
|
||||
except Exception as exc:
|
||||
raise AuthError(
|
||||
f"Token exchange failed: {exc}",
|
||||
provider="openai-codex",
|
||||
code="token_exchange_failed",
|
||||
provider="openai-codex", code="token_exchange_failed",
|
||||
)
|
||||
|
||||
if token_resp.status_code != 200:
|
||||
raise AuthError(
|
||||
f"Token exchange returned status {token_resp.status_code}.",
|
||||
provider="openai-codex",
|
||||
code="token_exchange_error",
|
||||
provider="openai-codex", code="token_exchange_error",
|
||||
)
|
||||
|
||||
tokens = token_resp.json()
|
||||
@@ -1837,12 +1821,14 @@ def _codex_device_code_login() -> dict[str, Any]:
|
||||
if not access_token:
|
||||
raise AuthError(
|
||||
"Token exchange did not return an access_token.",
|
||||
provider="openai-codex",
|
||||
code="token_exchange_no_access_token",
|
||||
provider="openai-codex", code="token_exchange_no_access_token",
|
||||
)
|
||||
|
||||
# Return tokens for the caller to persist (no longer writes to ~/.codex/)
|
||||
base_url = os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") or DEFAULT_CODEX_BASE_URL
|
||||
base_url = (
|
||||
os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/")
|
||||
or DEFAULT_CODEX_BASE_URL
|
||||
)
|
||||
|
||||
return {
|
||||
"tokens": {
|
||||
@@ -1850,7 +1836,7 @@ def _codex_device_code_login() -> dict[str, Any]:
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
"base_url": base_url,
|
||||
"last_refresh": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
|
||||
"last_refresh": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
|
||||
"auth_mode": "chatgpt",
|
||||
"source": "device-code",
|
||||
}
|
||||
@@ -1865,7 +1851,9 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
or pconfig.portal_base_url
|
||||
).rstrip("/")
|
||||
requested_inference_url = (
|
||||
getattr(args, "inference_url", None) or os.getenv("NOUS_INFERENCE_BASE_URL") or pconfig.inference_base_url
|
||||
getattr(args, "inference_url", None)
|
||||
or os.getenv("NOUS_INFERENCE_BASE_URL")
|
||||
or pconfig.inference_base_url
|
||||
).rstrip("/")
|
||||
client_id = getattr(args, "client_id", None) or pconfig.client_id
|
||||
scope = getattr(args, "scope", None) or pconfig.scope
|
||||
@@ -1874,7 +1862,11 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
timeout = httpx.Timeout(timeout_seconds)
|
||||
|
||||
insecure = bool(getattr(args, "insecure", False))
|
||||
ca_bundle = getattr(args, "ca_bundle", None) or os.getenv("HERMES_CA_BUNDLE") or os.getenv("SSL_CERT_FILE")
|
||||
ca_bundle = (
|
||||
getattr(args, "ca_bundle", None)
|
||||
or os.getenv("HERMES_CA_BUNDLE")
|
||||
or os.getenv("SSL_CERT_FILE")
|
||||
)
|
||||
verify: bool | str = False if insecure else (ca_bundle if ca_bundle else True)
|
||||
|
||||
# Skip browser open in SSH sessions
|
||||
@@ -1891,10 +1883,8 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
try:
|
||||
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
|
||||
device_data = _request_device_code(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
scope=scope,
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, scope=scope,
|
||||
)
|
||||
|
||||
verification_url = str(device_data["verification_uri_complete"])
|
||||
@@ -1918,19 +1908,19 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
print(f"Waiting for approval (polling every {effective_interval}s)...")
|
||||
|
||||
token_data = _poll_for_token(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
device_code=str(device_data["device_code"]),
|
||||
expires_in=expires_in,
|
||||
poll_interval=interval,
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, device_code=str(device_data["device_code"]),
|
||||
expires_in=expires_in, poll_interval=interval,
|
||||
)
|
||||
|
||||
# Process token response
|
||||
now = datetime.now(UTC)
|
||||
now = datetime.now(timezone.utc)
|
||||
token_expires_in = _coerce_ttl_seconds(token_data.get("expires_in", 0))
|
||||
expires_at = now.timestamp() + token_expires_in
|
||||
inference_base_url = _optional_base_url(token_data.get("inference_base_url")) or requested_inference_url
|
||||
inference_base_url = (
|
||||
_optional_base_url(token_data.get("inference_base_url"))
|
||||
or requested_inference_url
|
||||
)
|
||||
if inference_base_url != requested_inference_url:
|
||||
print(f"Using portal-provided inference URL: {inference_base_url}")
|
||||
|
||||
@@ -1943,7 +1933,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
"access_token": token_data["access_token"],
|
||||
"refresh_token": token_data.get("refresh_token"),
|
||||
"obtained_at": now.isoformat(),
|
||||
"expires_at": datetime.fromtimestamp(expires_at, tz=UTC).isoformat(),
|
||||
"expires_at": datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
|
||||
"expires_in": token_expires_in,
|
||||
"tls": {
|
||||
"insecure": verify is False,
|
||||
@@ -1974,13 +1964,13 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
runtime_creds = resolve_nous_runtime_credentials(
|
||||
min_key_ttl_seconds=5 * 60,
|
||||
timeout_seconds=timeout_seconds,
|
||||
insecure=insecure,
|
||||
ca_bundle=ca_bundle,
|
||||
insecure=insecure, ca_bundle=ca_bundle,
|
||||
)
|
||||
runtime_key = runtime_creds.get("api_key")
|
||||
runtime_base_url = runtime_creds.get("base_url") or inference_base_url
|
||||
if not isinstance(runtime_key, str) or not runtime_key:
|
||||
raise AuthError("No runtime API key available to fetch models", provider="nous", code="invalid_token")
|
||||
raise AuthError("No runtime API key available to fetch models",
|
||||
provider="nous", code="invalid_token")
|
||||
|
||||
model_ids = fetch_nous_models(
|
||||
inference_base_url=runtime_base_url,
|
||||
|
||||
@@ -9,13 +9,15 @@ import os
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from prompt_toolkit import print_formatted_text as _pt_print
|
||||
from prompt_toolkit.formatted_text import ANSI as _PT_ANSI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from prompt_toolkit import print_formatted_text as _pt_print
|
||||
from prompt_toolkit.formatted_text import ANSI as _PT_ANSI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -75,8 +77,7 @@ COMPACT_BANNER = """
|
||||
# Skills scanning
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def get_available_skills() -> dict[str, list[str]]:
|
||||
def get_available_skills() -> Dict[str, List[str]]:
|
||||
"""Scan ~/.hermes/skills/ and return skills grouped by category."""
|
||||
import os
|
||||
|
||||
@@ -109,7 +110,7 @@ def get_available_skills() -> dict[str, list[str]]:
|
||||
_UPDATE_CHECK_CACHE_SECONDS = 6 * 3600
|
||||
|
||||
|
||||
def check_for_updates() -> int | None:
|
||||
def check_for_updates() -> Optional[int]:
|
||||
"""Check how many commits behind origin/main the local repo is.
|
||||
|
||||
Does a ``git fetch`` at most once every 6 hours (cached to
|
||||
@@ -138,8 +139,7 @@ def check_for_updates() -> int | None:
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "fetch", "origin", "--quiet"],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
capture_output=True, timeout=10,
|
||||
cwd=str(repo_dir),
|
||||
)
|
||||
except Exception:
|
||||
@@ -149,9 +149,7 @@ def check_for_updates() -> int | None:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-list", "--count", "HEAD..origin/main"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
capture_output=True, text=True, timeout=5,
|
||||
cwd=str(repo_dir),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
@@ -174,7 +172,6 @@ def check_for_updates() -> int | None:
|
||||
# Welcome banner
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _format_context_length(tokens: int) -> str:
|
||||
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
|
||||
if tokens >= 1_000_000:
|
||||
@@ -186,16 +183,12 @@ def _format_context_length(tokens: int) -> str:
|
||||
return str(tokens)
|
||||
|
||||
|
||||
def build_welcome_banner(
|
||||
console: Console,
|
||||
model: str,
|
||||
cwd: str,
|
||||
tools: list[dict] = None,
|
||||
enabled_toolsets: list[str] = None,
|
||||
session_id: str = None,
|
||||
get_toolset_for_tool=None,
|
||||
context_length: int = None,
|
||||
):
|
||||
def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
tools: List[dict] = None,
|
||||
enabled_toolsets: List[str] = None,
|
||||
session_id: str = None,
|
||||
get_toolset_for_tool=None,
|
||||
context_length: int = None):
|
||||
"""Build and print a welcome banner with caduceus on left and info on right.
|
||||
|
||||
Args:
|
||||
@@ -208,8 +201,7 @@ def build_welcome_banner(
|
||||
get_toolset_for_tool: Callable to map tool name -> toolset name.
|
||||
context_length: Model's context window size in tokens.
|
||||
"""
|
||||
from model_tools import check_tool_availability
|
||||
|
||||
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
|
||||
if get_toolset_for_tool is None:
|
||||
from model_tools import get_toolset_for_tool
|
||||
|
||||
@@ -229,9 +221,7 @@ def build_welcome_banner(
|
||||
model_short = model.split("/")[-1] if "/" in model else model
|
||||
if len(model_short) > 28:
|
||||
model_short = model_short[:25] + "..."
|
||||
ctx_str = (
|
||||
f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
|
||||
)
|
||||
ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
|
||||
left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
|
||||
left_lines.append(f"[dim #B8860B]{cwd}[/]")
|
||||
if session_id:
|
||||
@@ -239,7 +229,7 @@ def build_welcome_banner(
|
||||
left_content = "\n".join(left_lines)
|
||||
|
||||
right_lines = ["[bold #FFBF00]Available Tools[/]"]
|
||||
toolsets_dict: dict[str, list] = {}
|
||||
toolsets_dict: Dict[str, list] = {}
|
||||
|
||||
for tool in tools:
|
||||
tool_name = tool["function"]["name"]
|
||||
@@ -296,7 +286,6 @@ def build_welcome_banner(
|
||||
# MCP Servers section (only if configured)
|
||||
try:
|
||||
from tools.mcp_tool import get_mcp_status
|
||||
|
||||
mcp_status = get_mcp_status()
|
||||
except Exception:
|
||||
mcp_status = []
|
||||
@@ -311,7 +300,10 @@ def build_welcome_banner(
|
||||
f"[dim #B8860B]—[/] [#FFF8DC]{srv['tools']} tool(s)[/]"
|
||||
)
|
||||
else:
|
||||
right_lines.append(f"[red]{srv['name']}[/] [dim]({srv['transport']})[/] [red]— failed[/]")
|
||||
right_lines.append(
|
||||
f"[red]{srv['name']}[/] [dim]({srv['transport']})[/] "
|
||||
f"[red]— failed[/]"
|
||||
)
|
||||
|
||||
right_lines.append("")
|
||||
right_lines.append("[bold #FFBF00]Available Skills[/]")
|
||||
|
||||
@@ -9,7 +9,7 @@ with the TUI.
|
||||
import queue
|
||||
import time as _time
|
||||
|
||||
from hermes_cli.banner import _DIM, _RST, cprint
|
||||
from hermes_cli.banner import cprint, _DIM, _RST
|
||||
|
||||
|
||||
def clarify_callback(cli, question, choices):
|
||||
@@ -33,7 +33,7 @@ def clarify_callback(cli, question, choices):
|
||||
cli._clarify_deadline = _time.monotonic() + timeout
|
||||
cli._clarify_freetext = is_open_ended
|
||||
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
while True:
|
||||
@@ -45,13 +45,13 @@ def clarify_callback(cli, question, choices):
|
||||
remaining = cli._clarify_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
cli._clarify_state = None
|
||||
cli._clarify_freetext = False
|
||||
cli._clarify_deadline = 0
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
cprint(f"\n{_DIM}(clarify timed out after {timeout}s — agent will decide){_RST}")
|
||||
return (
|
||||
@@ -71,7 +71,7 @@ def sudo_password_callback(cli) -> str:
|
||||
cli._sudo_state = {"response_queue": response_queue}
|
||||
cli._sudo_deadline = _time.monotonic() + timeout
|
||||
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
while True:
|
||||
@@ -79,7 +79,7 @@ def sudo_password_callback(cli) -> str:
|
||||
result = response_queue.get(timeout=1)
|
||||
cli._sudo_state = None
|
||||
cli._sudo_deadline = 0
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
if result:
|
||||
cprint(f"\n{_DIM} ✓ Password received (cached for session){_RST}")
|
||||
@@ -90,12 +90,12 @@ def sudo_password_callback(cli) -> str:
|
||||
remaining = cli._sudo_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
cli._sudo_state = None
|
||||
cli._sudo_deadline = 0
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
cprint(f"\n{_DIM} ⏱ Timeout — continuing without sudo{_RST}")
|
||||
return ""
|
||||
@@ -119,7 +119,7 @@ def approval_callback(cli, command: str, description: str) -> str:
|
||||
}
|
||||
cli._approval_deadline = _time.monotonic() + timeout
|
||||
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
while True:
|
||||
@@ -127,19 +127,19 @@ def approval_callback(cli, command: str, description: str) -> str:
|
||||
result = response_queue.get(timeout=1)
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
return result
|
||||
except queue.Empty:
|
||||
remaining = cli._approval_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
cli._app.invalidate()
|
||||
cprint(f"\n{_DIM} ⏱ Timeout — denying command{_RST}")
|
||||
return "deny"
|
||||
|
||||
@@ -51,7 +51,6 @@ def has_clipboard_image() -> bool:
|
||||
|
||||
# ── macOS ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _macos_save(dest: Path) -> bool:
|
||||
"""Try pngpaste first (fast, handles more formats), fall back to osascript."""
|
||||
return _macos_pngpaste(dest) or _macos_osascript(dest)
|
||||
@@ -62,9 +61,7 @@ def _macos_has_image() -> bool:
|
||||
try:
|
||||
info = subprocess.run(
|
||||
["osascript", "-e", "clipboard info"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
capture_output=True, text=True, timeout=3,
|
||||
)
|
||||
return "«class PNGf»" in info.stdout or "«class TIFF»" in info.stdout
|
||||
except Exception:
|
||||
@@ -76,8 +73,7 @@ def _macos_pngpaste(dest: Path) -> bool:
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["pngpaste", str(dest)],
|
||||
capture_output=True,
|
||||
timeout=3,
|
||||
capture_output=True, timeout=3,
|
||||
)
|
||||
if r.returncode == 0 and dest.exists() and dest.stat().st_size > 0:
|
||||
return True
|
||||
@@ -95,21 +91,19 @@ def _macos_osascript(dest: Path) -> bool:
|
||||
|
||||
# Extract as PNG
|
||||
script = (
|
||||
"try\n"
|
||||
" set imgData to the clipboard as «class PNGf»\n"
|
||||
'try\n'
|
||||
' set imgData to the clipboard as «class PNGf»\n'
|
||||
f' set f to open for access POSIX file "{dest}" with write permission\n'
|
||||
" write imgData to f\n"
|
||||
" close access f\n"
|
||||
"on error\n"
|
||||
' write imgData to f\n'
|
||||
' close access f\n'
|
||||
'on error\n'
|
||||
' return "fail"\n'
|
||||
"end try\n"
|
||||
'end try\n'
|
||||
)
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["osascript", "-e", script],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
if r.returncode == 0 and "fail" not in r.stdout and dest.exists() and dest.stat().st_size > 0:
|
||||
return True
|
||||
@@ -120,14 +114,13 @@ def _macos_osascript(dest: Path) -> bool:
|
||||
|
||||
# ── Linux ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _is_wsl() -> bool:
|
||||
"""Detect if running inside WSL (1 or 2)."""
|
||||
global _wsl_detected
|
||||
if _wsl_detected is not None:
|
||||
return _wsl_detected
|
||||
try:
|
||||
with open("/proc/version") as f:
|
||||
with open("/proc/version", "r") as f:
|
||||
_wsl_detected = "microsoft" in f.read().lower()
|
||||
except Exception:
|
||||
_wsl_detected = False
|
||||
@@ -152,7 +145,10 @@ def _linux_save(dest: Path) -> bool:
|
||||
|
||||
# PowerShell script: get clipboard image as base64-encoded PNG on stdout.
|
||||
# Using .NET System.Windows.Forms.Clipboard — always available on Windows.
|
||||
_PS_CHECK_IMAGE = "Add-Type -AssemblyName System.Windows.Forms;[System.Windows.Forms.Clipboard]::ContainsImage()"
|
||||
_PS_CHECK_IMAGE = (
|
||||
"Add-Type -AssemblyName System.Windows.Forms;"
|
||||
"[System.Windows.Forms.Clipboard]::ContainsImage()"
|
||||
)
|
||||
|
||||
_PS_EXTRACT_IMAGE = (
|
||||
"Add-Type -AssemblyName System.Windows.Forms;"
|
||||
@@ -169,10 +165,9 @@ def _wsl_has_image() -> bool:
|
||||
"""Check if Windows clipboard has an image (via powershell.exe)."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", _PS_CHECK_IMAGE],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=8,
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_CHECK_IMAGE],
|
||||
capture_output=True, text=True, timeout=8,
|
||||
)
|
||||
return r.returncode == 0 and "True" in r.stdout
|
||||
except FileNotFoundError:
|
||||
@@ -186,10 +181,9 @@ def _wsl_save(dest: Path) -> bool:
|
||||
"""Extract clipboard image via powershell.exe → base64 → decode to PNG."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", _PS_EXTRACT_IMAGE],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_EXTRACT_IMAGE],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if r.returncode != 0:
|
||||
return False
|
||||
@@ -212,17 +206,16 @@ def _wsl_save(dest: Path) -> bool:
|
||||
|
||||
# ── Wayland (wl-paste) ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _wayland_has_image() -> bool:
|
||||
"""Check if Wayland clipboard has image content."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["wl-paste", "--list-types"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
capture_output=True, text=True, timeout=3,
|
||||
)
|
||||
return r.returncode == 0 and any(
|
||||
t.startswith("image/") for t in r.stdout.splitlines()
|
||||
)
|
||||
return r.returncode == 0 and any(t.startswith("image/") for t in r.stdout.splitlines())
|
||||
except FileNotFoundError:
|
||||
logger.debug("wl-paste not installed — Wayland clipboard unavailable")
|
||||
except Exception:
|
||||
@@ -236,9 +229,7 @@ def _wayland_save(dest: Path) -> bool:
|
||||
# Check available MIME types
|
||||
types_r = subprocess.run(
|
||||
["wl-paste", "--list-types"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
capture_output=True, text=True, timeout=3,
|
||||
)
|
||||
if types_r.returncode != 0:
|
||||
return False
|
||||
@@ -246,7 +237,8 @@ def _wayland_save(dest: Path) -> bool:
|
||||
|
||||
# Prefer PNG, fall back to other image formats
|
||||
mime = None
|
||||
for preferred in ("image/png", "image/jpeg", "image/bmp", "image/gif", "image/webp"):
|
||||
for preferred in ("image/png", "image/jpeg", "image/bmp",
|
||||
"image/gif", "image/webp"):
|
||||
if preferred in types:
|
||||
mime = preferred
|
||||
break
|
||||
@@ -258,10 +250,7 @@ def _wayland_save(dest: Path) -> bool:
|
||||
with open(dest, "wb") as f:
|
||||
subprocess.run(
|
||||
["wl-paste", "--type", mime],
|
||||
stdout=f,
|
||||
stderr=subprocess.DEVNULL,
|
||||
timeout=5,
|
||||
check=True,
|
||||
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
|
||||
)
|
||||
|
||||
if not dest.exists() or dest.stat().st_size == 0:
|
||||
@@ -287,7 +276,6 @@ def _convert_to_png(path: Path) -> bool:
|
||||
# Try Pillow first (likely installed in the venv)
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(path)
|
||||
img.save(path, "PNG")
|
||||
return True
|
||||
@@ -302,8 +290,7 @@ def _convert_to_png(path: Path) -> bool:
|
||||
path.rename(tmp)
|
||||
r = subprocess.run(
|
||||
["convert", str(tmp), "png:" + str(path)],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
capture_output=True, timeout=5,
|
||||
)
|
||||
tmp.unlink(missing_ok=True)
|
||||
if r.returncode == 0 and path.exists() and path.stat().st_size > 0:
|
||||
@@ -323,15 +310,12 @@ def _convert_to_png(path: Path) -> bool:
|
||||
|
||||
# ── X11 (xclip) ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _xclip_has_image() -> bool:
|
||||
"""Check if X11 clipboard has image content."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
capture_output=True, text=True, timeout=3,
|
||||
)
|
||||
return r.returncode == 0 and "image/png" in r.stdout
|
||||
except FileNotFoundError:
|
||||
@@ -347,9 +331,7 @@ def _xclip_save(dest: Path) -> bool:
|
||||
try:
|
||||
targets = subprocess.run(
|
||||
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
capture_output=True, text=True, timeout=3,
|
||||
)
|
||||
if "image/png" not in targets.stdout:
|
||||
return False
|
||||
@@ -364,10 +346,7 @@ def _xclip_save(dest: Path) -> bool:
|
||||
with open(dest, "wb") as f:
|
||||
subprocess.run(
|
||||
["xclip", "-selection", "clipboard", "-t", "image/png", "-o"],
|
||||
stdout=f,
|
||||
stderr=subprocess.DEVNULL,
|
||||
timeout=5,
|
||||
check=True,
|
||||
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
|
||||
)
|
||||
if dest.exists() and dest.stat().st_size > 0:
|
||||
return True
|
||||
|
||||
@@ -4,12 +4,14 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CODEX_MODELS: list[str] = [
|
||||
DEFAULT_CODEX_MODELS: List[str] = [
|
||||
"gpt-5.3-codex",
|
||||
"gpt-5.2-codex",
|
||||
"gpt-5.1-codex-max",
|
||||
@@ -17,11 +19,10 @@ DEFAULT_CODEX_MODELS: list[str] = [
|
||||
]
|
||||
|
||||
|
||||
def _fetch_models_from_api(access_token: str) -> list[str]:
|
||||
def _fetch_models_from_api(access_token: str) -> List[str]:
|
||||
"""Fetch available models from the Codex API. Returns visible models sorted by priority."""
|
||||
try:
|
||||
import httpx
|
||||
|
||||
resp = httpx.get(
|
||||
"https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
@@ -46,7 +47,7 @@ def _fetch_models_from_api(access_token: str) -> list[str]:
|
||||
if item.get("supported_in_api") is False:
|
||||
continue
|
||||
visibility = item.get("visibility", "")
|
||||
if isinstance(visibility, str) and visibility.strip().lower() == "hidden":
|
||||
if isinstance(visibility, str) and visibility.strip().lower() == "hide":
|
||||
continue
|
||||
priority = item.get("priority")
|
||||
rank = int(priority) if isinstance(priority, (int, float)) else 10_000
|
||||
@@ -56,7 +57,7 @@ def _fetch_models_from_api(access_token: str) -> list[str]:
|
||||
return [slug for _, slug in sortable]
|
||||
|
||||
|
||||
def _read_default_model(codex_home: Path) -> str | None:
|
||||
def _read_default_model(codex_home: Path) -> Optional[str]:
|
||||
config_path = codex_home / "config.toml"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
@@ -74,7 +75,7 @@ def _read_default_model(codex_home: Path) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _read_cache_models(codex_home: Path) -> list[str]:
|
||||
def _read_cache_models(codex_home: Path) -> List[str]:
|
||||
cache_path = codex_home / "models_cache.json"
|
||||
if not cache_path.exists():
|
||||
return []
|
||||
@@ -103,22 +104,22 @@ def _read_cache_models(codex_home: Path) -> list[str]:
|
||||
sortable.append((rank, slug))
|
||||
|
||||
sortable.sort(key=lambda item: (item[0], item[1]))
|
||||
deduped: list[str] = []
|
||||
deduped: List[str] = []
|
||||
for _, slug in sortable:
|
||||
if slug not in deduped:
|
||||
deduped.append(slug)
|
||||
return deduped
|
||||
|
||||
|
||||
def get_codex_model_ids(access_token: str | None = None) -> list[str]:
|
||||
def get_codex_model_ids(access_token: Optional[str] = None) -> List[str]:
|
||||
"""Return available Codex model IDs, trying API first, then local sources.
|
||||
|
||||
|
||||
Resolution order: API (live, if token provided) > config.toml default >
|
||||
local cache > hardcoded defaults.
|
||||
"""
|
||||
codex_home_str = os.getenv("CODEX_HOME", "").strip() or str(Path.home() / ".codex")
|
||||
codex_home = Path(codex_home_str).expanduser()
|
||||
ordered: list[str] = []
|
||||
ordered: List[str] = []
|
||||
|
||||
# Try live API if we have a token
|
||||
if access_token:
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Any
|
||||
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
|
||||
|
||||
COMMANDS = {
|
||||
"/help": "Show this help message",
|
||||
"/tools": "List available tools",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,46 +20,46 @@ from hermes_cli.colors import Colors, color
|
||||
def cron_list(show_all: bool = False):
|
||||
"""List all scheduled jobs."""
|
||||
from cron.jobs import list_jobs
|
||||
|
||||
|
||||
jobs = list_jobs(include_disabled=show_all)
|
||||
|
||||
|
||||
if not jobs:
|
||||
print(color("No scheduled jobs.", Colors.DIM))
|
||||
print(color("Create one with the /cron add command in chat, or via Telegram.", Colors.DIM))
|
||||
return
|
||||
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ Scheduled Jobs │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
print()
|
||||
|
||||
|
||||
for job in jobs:
|
||||
job_id = job.get("id", "?")[:8]
|
||||
name = job.get("name", "(unnamed)")
|
||||
schedule = job.get("schedule_display", job.get("schedule", {}).get("value", "?"))
|
||||
enabled = job.get("enabled", True)
|
||||
next_run = job.get("next_run_at", "?")
|
||||
|
||||
|
||||
repeat_info = job.get("repeat", {})
|
||||
repeat_times = repeat_info.get("times")
|
||||
repeat_completed = repeat_info.get("completed", 0)
|
||||
|
||||
|
||||
if repeat_times:
|
||||
repeat_str = f"{repeat_completed}/{repeat_times}"
|
||||
else:
|
||||
repeat_str = "∞"
|
||||
|
||||
|
||||
deliver = job.get("deliver", ["local"])
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
deliver_str = ", ".join(deliver)
|
||||
|
||||
|
||||
if not enabled:
|
||||
status = color("[disabled]", Colors.RED)
|
||||
else:
|
||||
status = color("[active]", Colors.GREEN)
|
||||
|
||||
|
||||
print(f" {color(job_id, Colors.YELLOW)} {status}")
|
||||
print(f" Name: {name}")
|
||||
print(f" Schedule: {schedule}")
|
||||
@@ -67,10 +67,9 @@ def cron_list(show_all: bool = False):
|
||||
print(f" Next run: {next_run}")
|
||||
print(f" Deliver: {deliver_str}")
|
||||
print()
|
||||
|
||||
|
||||
# Warn if gateway isn't running
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
|
||||
if not find_gateway_pids():
|
||||
print(color(" ⚠ Gateway is not running — jobs won't fire automatically.", Colors.YELLOW))
|
||||
print(color(" Start it with: hermes gateway install", Colors.DIM))
|
||||
@@ -80,7 +79,6 @@ def cron_list(show_all: bool = False):
|
||||
def cron_tick():
|
||||
"""Run due jobs once and exit."""
|
||||
from cron.scheduler import tick
|
||||
|
||||
tick(verbose=True)
|
||||
|
||||
|
||||
@@ -88,9 +86,9 @@ def cron_status():
|
||||
"""Show cron execution status."""
|
||||
from cron.jobs import list_jobs
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
|
||||
|
||||
print()
|
||||
|
||||
|
||||
pids = find_gateway_pids()
|
||||
if pids:
|
||||
print(color("✓ Gateway is running — cron jobs will fire automatically", Colors.GREEN))
|
||||
@@ -101,9 +99,9 @@ def cron_status():
|
||||
print(" To enable automatic execution:")
|
||||
print(" hermes gateway install # Install as system service (recommended)")
|
||||
print(" hermes gateway # Or run in foreground")
|
||||
|
||||
|
||||
print()
|
||||
|
||||
|
||||
jobs = list_jobs(include_disabled=False)
|
||||
if jobs:
|
||||
next_runs = [j.get("next_run_at") for j in jobs if j.get("next_run_at")]
|
||||
@@ -112,24 +110,24 @@ def cron_status():
|
||||
print(f" Next run: {min(next_runs)}")
|
||||
else:
|
||||
print(" No active jobs")
|
||||
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def cron_command(args):
|
||||
"""Handle cron subcommands."""
|
||||
subcmd = getattr(args, "cron_command", None)
|
||||
|
||||
subcmd = getattr(args, 'cron_command', None)
|
||||
|
||||
if subcmd is None or subcmd == "list":
|
||||
show_all = getattr(args, "all", False)
|
||||
show_all = getattr(args, 'all', False)
|
||||
cron_list(show_all)
|
||||
|
||||
|
||||
elif subcmd == "tick":
|
||||
cron_tick()
|
||||
|
||||
|
||||
elif subcmd == "status":
|
||||
cron_status()
|
||||
|
||||
|
||||
else:
|
||||
print(f"Unknown cron command: {subcmd}")
|
||||
print("Usage: hermes cron [list|status|tick]")
|
||||
|
||||
@@ -5,18 +5,18 @@ Diagnoses issues with Hermes Agent setup.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.config import get_env_path, get_hermes_home, get_project_root
|
||||
from hermes_cli.config import get_project_root, get_hermes_home, get_env_path
|
||||
|
||||
PROJECT_ROOT = get_project_root()
|
||||
HERMES_HOME = get_hermes_home()
|
||||
|
||||
# Load environment variables from ~/.hermes/.env so API key checks work
|
||||
from dotenv import load_dotenv
|
||||
|
||||
_env_path = get_env_path()
|
||||
if _env_path.exists():
|
||||
try:
|
||||
@@ -33,6 +33,7 @@ os.environ.setdefault("MSWEA_SILENT_STARTUP", "1")
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
|
||||
_PROVIDER_ENV_HINTS = (
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
@@ -55,38 +56,35 @@ def _has_provider_env_config(content: str) -> bool:
|
||||
def check_ok(text: str, detail: str = ""):
|
||||
print(f" {color('✓', Colors.GREEN)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
|
||||
def check_warn(text: str, detail: str = ""):
|
||||
print(f" {color('⚠', Colors.YELLOW)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
|
||||
def check_fail(text: str, detail: str = ""):
|
||||
print(f" {color('✗', Colors.RED)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
|
||||
def check_info(text: str):
|
||||
print(f" {color('→', Colors.CYAN)} {text}")
|
||||
|
||||
|
||||
def run_doctor(args):
|
||||
"""Run diagnostic checks."""
|
||||
should_fix = getattr(args, "fix", False)
|
||||
|
||||
should_fix = getattr(args, 'fix', False)
|
||||
|
||||
issues = []
|
||||
manual_issues = [] # issues that can't be auto-fixed
|
||||
fixed_count = 0
|
||||
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ 🩺 Hermes Doctor │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Python version
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Python Environment", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
py_version = sys.version_info
|
||||
if py_version >= (3, 11):
|
||||
check_ok(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}")
|
||||
@@ -98,20 +96,20 @@ def run_doctor(args):
|
||||
else:
|
||||
check_fail(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}", "(3.10+ required)")
|
||||
issues.append("Upgrade Python to 3.10+")
|
||||
|
||||
|
||||
# Check if in virtual environment
|
||||
in_venv = sys.prefix != sys.base_prefix
|
||||
if in_venv:
|
||||
check_ok("Virtual environment active")
|
||||
else:
|
||||
check_warn("Not in virtual environment", "(recommended)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Required packages
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Required Packages", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
required_packages = [
|
||||
("openai", "OpenAI SDK"),
|
||||
("rich", "Rich (terminal UI)"),
|
||||
@@ -119,13 +117,13 @@ def run_doctor(args):
|
||||
("yaml", "PyYAML"),
|
||||
("httpx", "HTTPX"),
|
||||
]
|
||||
|
||||
|
||||
optional_packages = [
|
||||
("croniter", "Croniter (cron expressions)"),
|
||||
("telegram", "python-telegram-bot"),
|
||||
("discord", "discord.py"),
|
||||
]
|
||||
|
||||
|
||||
for module, name in required_packages:
|
||||
try:
|
||||
__import__(module)
|
||||
@@ -133,25 +131,25 @@ def run_doctor(args):
|
||||
except ImportError:
|
||||
check_fail(name, "(missing)")
|
||||
issues.append(f"Install {name}: uv pip install {module}")
|
||||
|
||||
|
||||
for module, name in optional_packages:
|
||||
try:
|
||||
__import__(module)
|
||||
check_ok(name, "(optional)")
|
||||
except ImportError:
|
||||
check_warn(name, "(optional, not installed)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Configuration files
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Configuration Files", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
# Check ~/.hermes/.env (primary location for user config)
|
||||
env_path = HERMES_HOME / ".env"
|
||||
env_path = HERMES_HOME / '.env'
|
||||
if env_path.exists():
|
||||
check_ok("~/.hermes/.env file exists")
|
||||
|
||||
|
||||
# Check for common issues
|
||||
content = env_path.read_text()
|
||||
if _has_provider_env_config(content):
|
||||
@@ -161,7 +159,7 @@ def run_doctor(args):
|
||||
issues.append("Run 'hermes setup' to configure API keys")
|
||||
else:
|
||||
# Also check project root as fallback
|
||||
fallback_env = PROJECT_ROOT / ".env"
|
||||
fallback_env = PROJECT_ROOT / '.env'
|
||||
if fallback_env.exists():
|
||||
check_ok(".env file exists (in project directory)")
|
||||
else:
|
||||
@@ -175,17 +173,17 @@ def run_doctor(args):
|
||||
else:
|
||||
check_info("Run 'hermes setup' to create one")
|
||||
issues.append("Run 'hermes setup' to create .env")
|
||||
|
||||
|
||||
# Check ~/.hermes/config.yaml (primary) or project cli-config.yaml (fallback)
|
||||
config_path = HERMES_HOME / "config.yaml"
|
||||
config_path = HERMES_HOME / 'config.yaml'
|
||||
if config_path.exists():
|
||||
check_ok("~/.hermes/config.yaml exists")
|
||||
else:
|
||||
fallback_config = PROJECT_ROOT / "cli-config.yaml"
|
||||
fallback_config = PROJECT_ROOT / 'cli-config.yaml'
|
||||
if fallback_config.exists():
|
||||
check_ok("cli-config.yaml exists (in project directory)")
|
||||
else:
|
||||
example_config = PROJECT_ROOT / "cli-config.yaml.example"
|
||||
example_config = PROJECT_ROOT / 'cli-config.yaml.example'
|
||||
if should_fix and example_config.exists():
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(str(example_config), str(config_path))
|
||||
@@ -196,7 +194,7 @@ def run_doctor(args):
|
||||
manual_issues.append("Create ~/.hermes/config.yaml manually")
|
||||
else:
|
||||
check_warn("config.yaml not found", "(using defaults)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Auth providers
|
||||
# =========================================================================
|
||||
@@ -204,7 +202,7 @@ def run_doctor(args):
|
||||
print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import get_codex_auth_status, get_nous_auth_status
|
||||
from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status
|
||||
|
||||
nous_status = get_nous_auth_status()
|
||||
if nous_status.get("logged_in"):
|
||||
@@ -232,7 +230,7 @@ def run_doctor(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Directory Structure", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
hermes_home = HERMES_HOME
|
||||
if hermes_home.exists():
|
||||
check_ok("~/.hermes directory exists")
|
||||
@@ -243,7 +241,7 @@ def run_doctor(args):
|
||||
fixed_count += 1
|
||||
else:
|
||||
check_warn("~/.hermes not found", "(will be created on first use)")
|
||||
|
||||
|
||||
# Check expected subdirectories
|
||||
expected_subdirs = ["cron", "sessions", "logs", "skills", "memories"]
|
||||
for subdir_name in expected_subdirs:
|
||||
@@ -257,7 +255,7 @@ def run_doctor(args):
|
||||
fixed_count += 1
|
||||
else:
|
||||
check_warn(f"~/.hermes/{subdir_name}/ not found", "(will be created on first use)")
|
||||
|
||||
|
||||
# Check for SOUL.md persona file
|
||||
soul_path = hermes_home / "SOUL.md"
|
||||
if soul_path.exists():
|
||||
@@ -280,7 +278,7 @@ def run_doctor(args):
|
||||
)
|
||||
check_ok("Created ~/.hermes/SOUL.md with basic template")
|
||||
fixed_count += 1
|
||||
|
||||
|
||||
# Check memory directory
|
||||
memories_dir = hermes_home / "memories"
|
||||
if memories_dir.exists():
|
||||
@@ -303,13 +301,12 @@ def run_doctor(args):
|
||||
memories_dir.mkdir(parents=True, exist_ok=True)
|
||||
check_ok("Created ~/.hermes/memories/")
|
||||
fixed_count += 1
|
||||
|
||||
|
||||
# Check SQLite session store
|
||||
state_db_path = hermes_home / "state.db"
|
||||
if state_db_path.exists():
|
||||
try:
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect(str(state_db_path))
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
count = cursor.fetchone()[0]
|
||||
@@ -319,26 +316,26 @@ def run_doctor(args):
|
||||
check_warn(f"~/.hermes/state.db exists but has issues: {e}")
|
||||
else:
|
||||
check_info("~/.hermes/state.db not created yet (will be created on first session)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: External tools
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ External Tools", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
# Git
|
||||
if shutil.which("git"):
|
||||
check_ok("git")
|
||||
else:
|
||||
check_warn("git not found", "(optional)")
|
||||
|
||||
|
||||
# ripgrep (optional, for faster file search)
|
||||
if shutil.which("rg"):
|
||||
check_ok("ripgrep (rg)", "(faster file search)")
|
||||
else:
|
||||
check_warn("ripgrep (rg) not found", "(file search uses grep fallback)")
|
||||
check_info("Install for faster search: sudo apt install ripgrep")
|
||||
|
||||
|
||||
# Docker (optional)
|
||||
terminal_env = os.getenv("TERMINAL_ENV", "local")
|
||||
if terminal_env == "docker":
|
||||
@@ -358,7 +355,7 @@ def run_doctor(args):
|
||||
check_ok("docker", "(optional)")
|
||||
else:
|
||||
check_warn("docker not found", "(optional)")
|
||||
|
||||
|
||||
# SSH (if using ssh backend)
|
||||
if terminal_env == "ssh":
|
||||
ssh_host = os.getenv("TERMINAL_SSH_HOST")
|
||||
@@ -367,7 +364,7 @@ def run_doctor(args):
|
||||
result = subprocess.run(
|
||||
["ssh", "-o", "ConnectTimeout=5", "-o", "BatchMode=yes", ssh_host, "echo ok"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
text=True
|
||||
)
|
||||
if result.returncode == 0:
|
||||
check_ok(f"SSH connection to {ssh_host}")
|
||||
@@ -377,7 +374,7 @@ def run_doctor(args):
|
||||
else:
|
||||
check_fail("TERMINAL_SSH_HOST not set", "(required for TERMINAL_ENV=ssh)")
|
||||
issues.append("Set TERMINAL_SSH_HOST in .env")
|
||||
|
||||
|
||||
# Daytona (if using daytona backend)
|
||||
if terminal_env == "daytona":
|
||||
daytona_key = os.getenv("DAYTONA_API_KEY")
|
||||
@@ -388,7 +385,6 @@ def run_doctor(args):
|
||||
issues.append("Set DAYTONA_API_KEY environment variable")
|
||||
try:
|
||||
from daytona import Daytona
|
||||
|
||||
check_ok("daytona SDK", "(installed)")
|
||||
except ImportError:
|
||||
check_fail("daytona SDK not installed", "(pip install daytona)")
|
||||
@@ -405,7 +401,7 @@ def run_doctor(args):
|
||||
check_warn("agent-browser not installed", "(run: npm install)")
|
||||
else:
|
||||
check_warn("Node.js not found", "(optional, needed for browser tools)")
|
||||
|
||||
|
||||
# npm audit for all Node.js packages
|
||||
if shutil.which("npm"):
|
||||
npm_dirs = [
|
||||
@@ -419,12 +415,9 @@ def run_doctor(args):
|
||||
audit_result = subprocess.run(
|
||||
["npm", "audit", "--json"],
|
||||
cwd=str(npm_dir),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
capture_output=True, text=True, timeout=30,
|
||||
)
|
||||
import json as _json
|
||||
|
||||
audit_data = _json.loads(audit_result.stdout) if audit_result.stdout.strip() else {}
|
||||
vuln_count = audit_data.get("metadata", {}).get("vulnerabilities", {})
|
||||
critical = vuln_count.get("critical", 0)
|
||||
@@ -436,7 +429,7 @@ def run_doctor(args):
|
||||
elif critical > 0 or high > 0:
|
||||
check_warn(
|
||||
f"{label} deps",
|
||||
f"({critical} critical, {high} high, {moderate} moderate — run: cd {npm_dir} && npm audit fix)",
|
||||
f"({critical} critical, {high} high, {moderate} moderate — run: cd {npm_dir} && npm audit fix)"
|
||||
)
|
||||
issues.append(f"{label} has {total} npm vulnerability(ies)")
|
||||
else:
|
||||
@@ -449,50 +442,47 @@ def run_doctor(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ API Connectivity", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
openrouter_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if openrouter_key:
|
||||
print(" Checking OpenRouter API...", end="", flush=True)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
response = httpx.get(
|
||||
OPENROUTER_MODELS_URL, headers={"Authorization": f"Bearer {openrouter_key}"}, timeout=10
|
||||
OPENROUTER_MODELS_URL,
|
||||
headers={"Authorization": f"Bearer {openrouter_key}"},
|
||||
timeout=10
|
||||
)
|
||||
if response.status_code == 200:
|
||||
print(f"\r {color('✓', Colors.GREEN)} OpenRouter API ")
|
||||
elif response.status_code == 401:
|
||||
print(
|
||||
f"\r {color('✗', Colors.RED)} OpenRouter API {color('(invalid API key)', Colors.DIM)} "
|
||||
)
|
||||
print(f"\r {color('✗', Colors.RED)} OpenRouter API {color('(invalid API key)', Colors.DIM)} ")
|
||||
issues.append("Check OPENROUTER_API_KEY in .env")
|
||||
else:
|
||||
print(
|
||||
f"\r {color('✗', Colors.RED)} OpenRouter API {color(f'(HTTP {response.status_code})', Colors.DIM)} "
|
||||
)
|
||||
print(f"\r {color('✗', Colors.RED)} OpenRouter API {color(f'(HTTP {response.status_code})', Colors.DIM)} ")
|
||||
except Exception as e:
|
||||
print(f"\r {color('✗', Colors.RED)} OpenRouter API {color(f'({e})', Colors.DIM)} ")
|
||||
issues.append("Check network connectivity")
|
||||
else:
|
||||
check_warn("OpenRouter API", "(not configured)")
|
||||
|
||||
|
||||
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if anthropic_key:
|
||||
print(" Checking Anthropic API...", end="", flush=True)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
response = httpx.get(
|
||||
"https://api.anthropic.com/v1/models",
|
||||
headers={"x-api-key": anthropic_key, "anthropic-version": "2023-06-01"},
|
||||
timeout=10,
|
||||
headers={
|
||||
"x-api-key": anthropic_key,
|
||||
"anthropic-version": "2023-06-01"
|
||||
},
|
||||
timeout=10
|
||||
)
|
||||
if response.status_code == 200:
|
||||
print(f"\r {color('✓', Colors.GREEN)} Anthropic API ")
|
||||
elif response.status_code == 401:
|
||||
print(
|
||||
f"\r {color('✗', Colors.RED)} Anthropic API {color('(invalid API key)', Colors.DIM)} "
|
||||
)
|
||||
print(f"\r {color('✗', Colors.RED)} Anthropic API {color('(invalid API key)', Colors.DIM)} ")
|
||||
else:
|
||||
msg = "(couldn't verify)"
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} Anthropic API {color(msg, Colors.DIM)} ")
|
||||
@@ -501,15 +491,10 @@ def run_doctor(args):
|
||||
|
||||
# -- API-key providers (Z.AI/GLM, Kimi, MiniMax, MiniMax-CN) --
|
||||
_apikey_providers = [
|
||||
(
|
||||
"Z.AI / GLM",
|
||||
("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
|
||||
"https://api.z.ai/api/paas/v4/models",
|
||||
"GLM_BASE_URL",
|
||||
),
|
||||
("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
|
||||
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
|
||||
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
|
||||
("Z.AI / GLM", ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), "https://api.z.ai/api/paas/v4/models", "GLM_BASE_URL"),
|
||||
("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
|
||||
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
|
||||
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
|
||||
]
|
||||
for _pname, _env_vars, _default_url, _base_env in _apikey_providers:
|
||||
_key = ""
|
||||
@@ -522,7 +507,6 @@ def run_doctor(args):
|
||||
print(f" Checking {_pname} API...", end="", flush=True)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
_base = os.getenv(_base_env, "")
|
||||
# Auto-detect Kimi Code keys (sk-kimi-) → api.kimi.com
|
||||
if not _base and _key.startswith("sk-kimi-"):
|
||||
@@ -542,9 +526,7 @@ def run_doctor(args):
|
||||
print(f"\r {color('✗', Colors.RED)} {_label} {color('(invalid API key)', Colors.DIM)} ")
|
||||
issues.append(f"Check {_env_vars[0]} in .env")
|
||||
else:
|
||||
print(
|
||||
f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} "
|
||||
)
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} ")
|
||||
except Exception as _e:
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'({_e})', Colors.DIM)} ")
|
||||
|
||||
@@ -553,7 +535,7 @@ def run_doctor(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Submodules", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
# mini-swe-agent (terminal tool backend)
|
||||
mini_swe_dir = PROJECT_ROOT / "mini-swe-agent"
|
||||
if mini_swe_dir.exists() and (mini_swe_dir / "pyproject.toml").exists():
|
||||
@@ -565,7 +547,7 @@ def run_doctor(args):
|
||||
issues.append("Install mini-swe-agent: uv pip install -e ./mini-swe-agent")
|
||||
else:
|
||||
check_warn("mini-swe-agent not found", "(run: git submodule update --init --recursive)")
|
||||
|
||||
|
||||
# tinker-atropos (RL training backend)
|
||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
@@ -580,24 +562,24 @@ def run_doctor(args):
|
||||
check_warn("tinker-atropos requires Python 3.11+", f"(current: {py_version.major}.{py_version.minor})")
|
||||
else:
|
||||
check_warn("tinker-atropos not found", "(run: git submodule update --init --recursive)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Tool Availability
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Tool Availability", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
try:
|
||||
# Add project root to path for imports
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
from model_tools import TOOLSET_REQUIREMENTS, check_tool_availability
|
||||
|
||||
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
|
||||
|
||||
available, unavailable = check_tool_availability()
|
||||
|
||||
|
||||
for tid in available:
|
||||
info = TOOLSET_REQUIREMENTS.get(tid, {})
|
||||
check_ok(info.get("name", tid))
|
||||
|
||||
|
||||
for item in unavailable:
|
||||
env_vars = item.get("missing_vars") or item.get("env_vars") or []
|
||||
if env_vars:
|
||||
@@ -612,7 +594,7 @@ def run_doctor(args):
|
||||
issues.append("Run 'hermes setup' to configure missing API keys for full tool access")
|
||||
except Exception as e:
|
||||
check_warn("Could not check tool availability", f"({e})")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Skills Hub
|
||||
# =========================================================================
|
||||
@@ -626,7 +608,6 @@ def run_doctor(args):
|
||||
if lock_file.exists():
|
||||
try:
|
||||
import json
|
||||
|
||||
lock_data = json.loads(lock_file.read_text())
|
||||
count = len(lock_data.get("installed", {}))
|
||||
check_ok(f"Lock file OK ({count} hub-installed skill(s))")
|
||||
@@ -640,7 +621,6 @@ def run_doctor(args):
|
||||
check_warn("Skills Hub directory not initialized", "(run: hermes skills list)")
|
||||
|
||||
from hermes_cli.config import get_env_value
|
||||
|
||||
github_token = get_env_value("GITHUB_TOKEN") or get_env_value("GH_TOKEN")
|
||||
if github_token:
|
||||
check_ok("GitHub token configured (authenticated API access)")
|
||||
@@ -676,5 +656,5 @@ def run_doctor(args):
|
||||
else:
|
||||
print(color("─" * 60, Colors.GREEN))
|
||||
print(color(" All checks passed! 🎉", Colors.GREEN, Colors.BOLD))
|
||||
|
||||
|
||||
print()
|
||||
|
||||
@@ -13,24 +13,18 @@ from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.config import get_env_value, save_env_value
|
||||
from hermes_cli.setup import (
|
||||
print_error,
|
||||
print_header,
|
||||
print_info,
|
||||
print_success,
|
||||
print_warning,
|
||||
prompt,
|
||||
prompt_choice,
|
||||
prompt_yes_no,
|
||||
print_header, print_info, print_success, print_warning, print_error,
|
||||
prompt, prompt_choice, prompt_yes_no,
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Process Management (for manual gateway runs)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def find_gateway_pids() -> list:
|
||||
"""Find PIDs of running gateway processes."""
|
||||
pids = []
|
||||
@@ -44,16 +38,17 @@ def find_gateway_pids() -> list:
|
||||
if is_windows():
|
||||
# Windows: use wmic to search command lines
|
||||
result = subprocess.run(
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], capture_output=True, text=True
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
# Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n"
|
||||
current_cmd = ""
|
||||
for line in result.stdout.split("\n"):
|
||||
for line in result.stdout.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith("CommandLine="):
|
||||
current_cmd = line[len("CommandLine=") :]
|
||||
current_cmd = line[len("CommandLine="):]
|
||||
elif line.startswith("ProcessId="):
|
||||
pid_str = line[len("ProcessId=") :]
|
||||
pid_str = line[len("ProcessId="):]
|
||||
if any(p in current_cmd for p in patterns):
|
||||
try:
|
||||
pid = int(pid_str)
|
||||
@@ -63,10 +58,14 @@ def find_gateway_pids() -> list:
|
||||
pass
|
||||
current_cmd = ""
|
||||
else:
|
||||
result = subprocess.run(["ps", "aux"], capture_output=True, text=True)
|
||||
for line in result.stdout.split("\n"):
|
||||
result = subprocess.run(
|
||||
["ps", "aux"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
# Skip grep and current process
|
||||
if "grep" in line or str(os.getpid()) in line:
|
||||
if 'grep' in line or str(os.getpid()) in line:
|
||||
continue
|
||||
for pattern in patterns:
|
||||
if pattern in line:
|
||||
@@ -89,7 +88,7 @@ def kill_gateway_processes(force: bool = False) -> int:
|
||||
"""Kill any running gateway processes. Returns count killed."""
|
||||
pids = find_gateway_pids()
|
||||
killed = 0
|
||||
|
||||
|
||||
for pid in pids:
|
||||
try:
|
||||
if force and not is_windows():
|
||||
@@ -102,20 +101,18 @@ def kill_gateway_processes(force: bool = False) -> int:
|
||||
pass
|
||||
except PermissionError:
|
||||
print(f"⚠ Permission denied to kill PID {pid}")
|
||||
|
||||
|
||||
return killed
|
||||
|
||||
|
||||
def is_linux() -> bool:
    """Return True when the interpreter reports a Linux platform.

    ``sys.platform`` is matched by prefix, so any value starting with
    ``"linux"`` is accepted.
    """
    return sys.platform.startswith("linux")
|
||||
|
||||
def is_macos() -> bool:
    """Return True when ``sys.platform`` is exactly ``"darwin"``."""
    return sys.platform == "darwin"
|
||||
|
||||
def is_windows() -> bool:
    """Return True when ``sys.platform`` is exactly ``"win32"``."""
    return sys.platform == "win32"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -125,15 +122,12 @@ def is_windows() -> bool:
|
||||
SERVICE_NAME = "hermes-gateway"
|
||||
SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
|
||||
|
||||
|
||||
def get_systemd_unit_path() -> Path:
    """Return the per-user systemd unit file path for the gateway service.

    The path is ``~/.config/systemd/user/<SERVICE_NAME>.service``; the file is
    not created or checked for existence here.
    """
    return Path.home() / ".config" / "systemd" / "user" / f"{SERVICE_NAME}.service"
|
||||
|
||||
|
||||
def get_launchd_plist_path() -> Path:
    """Return the launchd plist path for the gateway service.

    The path is ``~/Library/LaunchAgents/ai.hermes.gateway.plist``; the file
    is not created or checked for existence here.
    """
    return Path.home() / "Library" / "LaunchAgents" / "ai.hermes.gateway.plist"
|
||||
|
||||
|
||||
def get_python_path() -> str:
|
||||
if is_windows():
|
||||
venv_python = PROJECT_ROOT / "venv" / "Scripts" / "python.exe"
|
||||
@@ -143,16 +137,14 @@ def get_python_path() -> str:
|
||||
return str(venv_python)
|
||||
return sys.executable
|
||||
|
||||
|
||||
def get_hermes_cli_path() -> str:
    """Get the path to the hermes CLI.

    Returns:
        The location of a ``hermes`` executable found on PATH, or, when none
        is found, the command string ``"<python> -m hermes_cli.main"`` built
        from ``get_python_path()``.
    """
    import shutil

    # Check if installed via pip
    hermes_bin = shutil.which("hermes")
    if hermes_bin:
        return hermes_bin

    # Fallback to direct module execution
    return f"{get_python_path()} -m hermes_cli.main"
|
||||
|
||||
@@ -161,10 +153,8 @@ def get_hermes_cli_path() -> str:
|
||||
# Systemd (Linux)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def generate_systemd_unit() -> str:
|
||||
import shutil
|
||||
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
venv_dir = str(PROJECT_ROOT / "venv")
|
||||
@@ -173,7 +163,7 @@ def generate_systemd_unit() -> str:
|
||||
|
||||
# Build a PATH that includes the venv, node_modules, and standard system dirs
|
||||
sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
|
||||
|
||||
hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
@@ -198,62 +188,56 @@ StandardError=journal
|
||||
WantedBy=default.target
|
||||
"""
|
||||
|
||||
|
||||
def systemd_install(force: bool = False) -> None:
    """Write the gateway's systemd user unit, reload systemd and enable it.

    Args:
        force: When False and the unit file already exists, print a hint and
            return without touching anything. When True, overwrite it.

    Raises:
        subprocess.CalledProcessError: If ``systemctl --user daemon-reload``
            or ``systemctl --user enable`` exits non-zero (both run with
            ``check=True``).
    """
    unit_path = get_systemd_unit_path()

    if unit_path.exists() and not force:
        print(f"Service already installed at: {unit_path}")
        print("Use --force to reinstall")
        return

    unit_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Installing systemd service to: {unit_path}")
    unit_path.write_text(generate_systemd_unit())

    # Reload first so systemd sees the freshly written unit before enabling.
    subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
    subprocess.run(["systemctl", "--user", "enable", SERVICE_NAME], check=True)

    print()
    print("✓ Service installed and enabled!")
    print()
    print("Next steps:")
    print(" hermes gateway start # Start the service")
    print(" hermes gateway status # Check status")
    print(f" journalctl --user -u {SERVICE_NAME} -f # View logs")
    print()
    print("To enable lingering (keeps running after logout):")
    print(" sudo loginctl enable-linger $USER")
|
||||
|
||||
|
||||
def systemd_uninstall() -> None:
    """Stop and disable the gateway unit, delete its file and reload systemd.

    Raises:
        subprocess.CalledProcessError: If the final
            ``systemctl --user daemon-reload`` exits non-zero.
    """
    # Best effort: stop/disable run with check=False, so a non-zero exit from
    # either does not abort the uninstall.
    subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=False)
    subprocess.run(["systemctl", "--user", "disable", SERVICE_NAME], check=False)

    unit_path = get_systemd_unit_path()
    if unit_path.exists():
        unit_path.unlink()
        print(f"✓ Removed {unit_path}")

    subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
    print("✓ Service uninstalled")
|
||||
|
||||
|
||||
def systemd_start() -> None:
    """Start the gateway via ``systemctl --user start``.

    Raises:
        subprocess.CalledProcessError: If systemctl exits non-zero.
    """
    subprocess.run(["systemctl", "--user", "start", SERVICE_NAME], check=True)
    print("✓ Service started")
|
||||
|
||||
|
||||
def systemd_stop() -> None:
    """Stop the gateway via ``systemctl --user stop``.

    Raises:
        subprocess.CalledProcessError: If systemctl exits non-zero.
    """
    subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=True)
    print("✓ Service stopped")
|
||||
|
||||
|
||||
def systemd_restart() -> None:
    """Restart the gateway via ``systemctl --user restart``.

    Raises:
        subprocess.CalledProcessError: If systemctl exits non-zero.
    """
    subprocess.run(["systemctl", "--user", "restart", SERVICE_NAME], check=True)
    print("✓ Service restarted")
|
||||
|
||||
|
||||
def systemd_status(deep: bool = False):
|
||||
# Check if service unit file exists
|
||||
unit_path = get_systemd_unit_path()
|
||||
@@ -261,45 +245,54 @@ def systemd_status(deep: bool = False):
|
||||
print("✗ Gateway service is not installed")
|
||||
print(" Run: hermes gateway install")
|
||||
return
|
||||
|
||||
|
||||
# Show detailed status first
|
||||
subprocess.run(["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"], capture_output=False)
|
||||
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"],
|
||||
capture_output=False
|
||||
)
|
||||
|
||||
# Check if service is active
|
||||
result = subprocess.run(["systemctl", "--user", "is-active", SERVICE_NAME], capture_output=True, text=True)
|
||||
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", SERVICE_NAME],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
status = result.stdout.strip()
|
||||
|
||||
|
||||
if status == "active":
|
||||
print("✓ Gateway service is running")
|
||||
else:
|
||||
print("✗ Gateway service is stopped")
|
||||
print(" Run: hermes gateway start")
|
||||
|
||||
|
||||
if deep:
|
||||
print()
|
||||
print("Recent logs:")
|
||||
subprocess.run(["journalctl", "--user", "-u", SERVICE_NAME, "-n", "20", "--no-pager"])
|
||||
subprocess.run([
|
||||
"journalctl", "--user", "-u", SERVICE_NAME,
|
||||
"-n", "20", "--no-pager"
|
||||
])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Launchd (macOS)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def generate_launchd_plist() -> str:
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
log_dir = Path.home() / ".hermes" / "logs"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
return f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>ai.hermes.gateway</string>
|
||||
|
||||
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>{python_path}</string>
|
||||
@@ -308,43 +301,42 @@ def generate_launchd_plist() -> str:
|
||||
<string>gateway</string>
|
||||
<string>run</string>
|
||||
</array>
|
||||
|
||||
|
||||
<key>WorkingDirectory</key>
|
||||
<string>{working_dir}</string>
|
||||
|
||||
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
|
||||
|
||||
<key>KeepAlive</key>
|
||||
<dict>
|
||||
<key>SuccessfulExit</key>
|
||||
<false/>
|
||||
</dict>
|
||||
|
||||
|
||||
<key>StandardOutPath</key>
|
||||
<string>{log_dir}/gateway.log</string>
|
||||
|
||||
|
||||
<key>StandardErrorPath</key>
|
||||
<string>{log_dir}/gateway.error.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
"""
|
||||
|
||||
|
||||
def launchd_install(force: bool = False):
|
||||
plist_path = get_launchd_plist_path()
|
||||
|
||||
|
||||
if plist_path.exists() and not force:
|
||||
print(f"Service already installed at: {plist_path}")
|
||||
print("Use --force to reinstall")
|
||||
return
|
||||
|
||||
|
||||
plist_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Installing launchd service to: {plist_path}")
|
||||
plist_path.write_text(generate_launchd_plist())
|
||||
|
||||
|
||||
subprocess.run(["launchctl", "load", str(plist_path)], check=True)
|
||||
|
||||
|
||||
print()
|
||||
print("✓ Service installed and loaded!")
|
||||
print()
|
||||
@@ -352,42 +344,41 @@ def launchd_install(force: bool = False):
|
||||
print(" hermes gateway status # Check status")
|
||||
print(" tail -f ~/.hermes/logs/gateway.log # View logs")
|
||||
|
||||
|
||||
def launchd_uninstall() -> None:
    """Unload the gateway's launchd job and delete its plist file."""
    plist_path = get_launchd_plist_path()

    # Best effort: check=False, so a non-zero exit from unload is ignored.
    subprocess.run(["launchctl", "unload", str(plist_path)], check=False)

    if plist_path.exists():
        plist_path.unlink()
        print(f"✓ Removed {plist_path}")

    print("✓ Service uninstalled")
|
||||
|
||||
|
||||
def launchd_start() -> None:
    """Start the ``ai.hermes.gateway`` job via ``launchctl start``.

    Raises:
        subprocess.CalledProcessError: If launchctl exits non-zero.
    """
    subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True)
    print("✓ Service started")
|
||||
|
||||
|
||||
def launchd_stop() -> None:
    """Stop the ``ai.hermes.gateway`` job via ``launchctl stop``.

    Raises:
        subprocess.CalledProcessError: If launchctl exits non-zero.
    """
    subprocess.run(["launchctl", "stop", "ai.hermes.gateway"], check=True)
    print("✓ Service stopped")
|
||||
|
||||
|
||||
def launchd_restart() -> None:
    """Restart the launchd job by stopping it and then starting it.

    Any exception raised by ``launchd_stop()`` propagates, in which case
    ``launchd_start()`` is not called.
    """
    launchd_stop()
    launchd_start()
|
||||
|
||||
|
||||
def launchd_status(deep: bool = False):
|
||||
result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True)
|
||||
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("✓ Gateway service is loaded")
|
||||
print(result.stdout)
|
||||
else:
|
||||
print("✗ Gateway service is not loaded")
|
||||
|
||||
|
||||
if deep:
|
||||
log_file = Path.home() / ".hermes" / "logs" / "gateway.log"
|
||||
if log_file.exists():
|
||||
@@ -400,10 +391,9 @@ def launchd_status(deep: bool = False):
|
||||
# Gateway Runner
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def run_gateway(verbose: bool = False, replace: bool = False):
|
||||
"""Run the gateway in foreground.
|
||||
|
||||
|
||||
Args:
|
||||
verbose: Enable verbose logging output.
|
||||
replace: If True, kill any existing gateway instance before starting.
|
||||
@@ -411,9 +401,9 @@ def run_gateway(verbose: bool = False, replace: bool = False):
|
||||
hasn't fully exited yet.
|
||||
"""
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
from gateway.run import start_gateway
|
||||
|
||||
|
||||
print("┌─────────────────────────────────────────────────────────┐")
|
||||
print("│ ⚕ Hermes Gateway Starting... │")
|
||||
print("├─────────────────────────────────────────────────────────┤")
|
||||
@@ -421,7 +411,7 @@ def run_gateway(verbose: bool = False, replace: bool = False):
|
||||
print("│ Press Ctrl+C to stop │")
|
||||
print("└─────────────────────────────────────────────────────────┘")
|
||||
print()
|
||||
|
||||
|
||||
# Exit with code 1 if gateway fails to connect any platform,
|
||||
# so systemd Restart=on-failure will retry on transient errors
|
||||
success = asyncio.run(start_gateway(replace=replace))
|
||||
@@ -448,25 +438,13 @@ _PLATFORMS = [
|
||||
"4. To find your user ID: message @userinfobot — it replies with your numeric ID",
|
||||
],
|
||||
"vars": [
|
||||
{
|
||||
"name": "TELEGRAM_BOT_TOKEN",
|
||||
"prompt": "Bot token",
|
||||
"password": True,
|
||||
"help": "Paste the token from @BotFather (step 3 above).",
|
||||
},
|
||||
{
|
||||
"name": "TELEGRAM_ALLOWED_USERS",
|
||||
"prompt": "Allowed user IDs (comma-separated)",
|
||||
"password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your user ID from step 4 above.",
|
||||
},
|
||||
{
|
||||
"name": "TELEGRAM_HOME_CHANNEL",
|
||||
"prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)",
|
||||
"password": False,
|
||||
"help": "For DMs, this is your user ID. You can set it later by typing /set-home in chat.",
|
||||
},
|
||||
{"name": "TELEGRAM_BOT_TOKEN", "prompt": "Bot token", "password": True,
|
||||
"help": "Paste the token from @BotFather (step 3 above)."},
|
||||
{"name": "TELEGRAM_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your user ID from step 4 above."},
|
||||
{"name": "TELEGRAM_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
|
||||
"help": "For DMs, this is your user ID. You can set it later by typing /set-home in chat."},
|
||||
],
|
||||
},
|
||||
{
|
||||
@@ -488,25 +466,13 @@ _PLATFORMS = [
|
||||
" then right-click your name → Copy ID",
|
||||
],
|
||||
"vars": [
|
||||
{
|
||||
"name": "DISCORD_BOT_TOKEN",
|
||||
"prompt": "Bot token",
|
||||
"password": True,
|
||||
"help": "Paste the token from step 2 above.",
|
||||
},
|
||||
{
|
||||
"name": "DISCORD_ALLOWED_USERS",
|
||||
"prompt": "Allowed user IDs or usernames (comma-separated)",
|
||||
"password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your user ID from step 5 above.",
|
||||
},
|
||||
{
|
||||
"name": "DISCORD_HOME_CHANNEL",
|
||||
"prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)",
|
||||
"password": False,
|
||||
"help": "Right-click a channel → Copy Channel ID (requires Developer Mode).",
|
||||
},
|
||||
{"name": "DISCORD_BOT_TOKEN", "prompt": "Bot token", "password": True,
|
||||
"help": "Paste the token from step 2 above."},
|
||||
{"name": "DISCORD_ALLOWED_USERS", "prompt": "Allowed user IDs or usernames (comma-separated)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your user ID from step 5 above."},
|
||||
{"name": "DISCORD_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
|
||||
"help": "Right-click a channel → Copy Channel ID (requires Developer Mode)."},
|
||||
],
|
||||
},
|
||||
{
|
||||
@@ -516,40 +482,23 @@ _PLATFORMS = [
|
||||
"token_var": "SLACK_BOT_TOKEN",
|
||||
"setup_instructions": [
|
||||
"1. Go to https://api.slack.com/apps → Create New App → From Scratch",
|
||||
"2. Enable Socket Mode: Settings → Socket Mode → Enable",
|
||||
" Create an App-Level Token with scope: connections:write → copy xapp-... token",
|
||||
"3. Add Bot Token Scopes: Features → OAuth & Permissions → Scopes",
|
||||
" Required: chat:write, app_mentions:read, channels:history, channels:read,",
|
||||
" groups:history, im:history, im:read, im:write, users:read, files:write",
|
||||
"4. Subscribe to Events: Features → Event Subscriptions → Enable",
|
||||
" Required events: message.im, message.channels, app_mention",
|
||||
" Optional: message.groups (for private channels)",
|
||||
" ⚠ Without message.channels the bot will ONLY work in DMs!",
|
||||
"5. Install to Workspace: Settings → Install App → copy xoxb-... token",
|
||||
"6. Reinstall the app after any scope or event changes",
|
||||
"2. Enable Socket Mode: App Settings → Socket Mode → Enable",
|
||||
"3. Get Bot Token: OAuth & Permissions → Install to Workspace → copy xoxb-... token",
|
||||
"4. Get App Token: Basic Information → App-Level Tokens → Generate",
|
||||
" Name it anything, add scope: connections:write → copy xapp-... token",
|
||||
"5. Add bot scopes: OAuth & Permissions → Scopes → chat:write, im:history,",
|
||||
" im:read, im:write, channels:history, channels:read",
|
||||
"6. Reinstall the app to your workspace after adding scopes",
|
||||
"7. Find your user ID: click your profile → three dots → Copy member ID",
|
||||
"8. Invite the bot to channels: /invite @YourBot",
|
||||
],
|
||||
"vars": [
|
||||
{
|
||||
"name": "SLACK_BOT_TOKEN",
|
||||
"prompt": "Bot Token (xoxb-...)",
|
||||
"password": True,
|
||||
"help": "Paste the bot token from step 3 above.",
|
||||
},
|
||||
{
|
||||
"name": "SLACK_APP_TOKEN",
|
||||
"prompt": "App Token (xapp-...)",
|
||||
"password": True,
|
||||
"help": "Paste the app-level token from step 4 above.",
|
||||
},
|
||||
{
|
||||
"name": "SLACK_ALLOWED_USERS",
|
||||
"prompt": "Allowed user IDs (comma-separated)",
|
||||
"password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your member ID from step 7 above.",
|
||||
},
|
||||
{"name": "SLACK_BOT_TOKEN", "prompt": "Bot Token (xoxb-...)", "password": True,
|
||||
"help": "Paste the bot token from step 3 above."},
|
||||
{"name": "SLACK_APP_TOKEN", "prompt": "App Token (xapp-...)", "password": True,
|
||||
"help": "Paste the app-level token from step 4 above."},
|
||||
{"name": "SLACK_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your member ID from step 7 above."},
|
||||
],
|
||||
},
|
||||
{
|
||||
@@ -628,14 +577,14 @@ def _setup_standard_platform(platform: dict):
|
||||
|
||||
# Allowlist fields get special handling for the deny-by-default security model
|
||||
if var.get("is_allowlist"):
|
||||
print_info(" The gateway DENIES all users by default for security.")
|
||||
print_info(" Enter user IDs to create an allowlist, or leave empty")
|
||||
print_info(" and you'll be asked about open access next.")
|
||||
print_info(f" The gateway DENIES all users by default for security.")
|
||||
print_info(f" Enter user IDs to create an allowlist, or leave empty")
|
||||
print_info(f" and you'll be asked about open access next.")
|
||||
value = prompt(f" {var['prompt']}", password=False)
|
||||
if value:
|
||||
cleaned = value.replace(" ", "")
|
||||
save_env_value(var["name"], cleaned)
|
||||
print_success(" Saved — only these users can interact with the bot.")
|
||||
print_success(f" Saved — only these users can interact with the bot.")
|
||||
allowed_val_set = cleaned
|
||||
else:
|
||||
# No allowlist — ask about open access vs DM pairing
|
||||
@@ -664,7 +613,7 @@ def _setup_standard_platform(platform: dict):
|
||||
print_warning(f" Skipped — {label} won't work without this.")
|
||||
return
|
||||
else:
|
||||
print_info(" Skipped (can configure later)")
|
||||
print_info(f" Skipped (can configure later)")
|
||||
|
||||
# If an allowlist was set and home channel wasn't, offer to reuse
|
||||
# the first user ID (common for Telegram DMs).
|
||||
@@ -682,10 +631,8 @@ def _setup_standard_platform(platform: dict):
|
||||
|
||||
def _setup_whatsapp():
    """Delegate to the existing WhatsApp setup flow."""
    import argparse

    # Imported inside the function as in the original; presumably this avoids
    # a circular import with hermes_cli.main — TODO confirm.
    from hermes_cli.main import cmd_whatsapp

    # cmd_whatsapp is handed an empty argparse namespace (no attributes set).
    cmd_whatsapp(argparse.Namespace())
|
||||
|
||||
|
||||
@@ -701,10 +648,16 @@ def _is_service_installed() -> bool:
|
||||
def _is_service_running() -> bool:
    """Check if the gateway service is currently running.

    Returns:
        On Linux with a systemd unit file present: whether
        ``systemctl --user is-active`` prints ``active``.
        On macOS with a launchd plist present: whether ``launchctl list``
        for the job exits with status 0.
        Otherwise: whether ``find_gateway_pids()`` returns any PIDs.
    """
    if is_linux() and get_systemd_unit_path().exists():
        result = subprocess.run(
            ["systemctl", "--user", "is-active", SERVICE_NAME],
            capture_output=True,
            text=True,
        )
        return result.stdout.strip() == "active"

    if is_macos() and get_launchd_plist_path().exists():
        result = subprocess.run(
            ["launchctl", "list", "ai.hermes.gateway"],
            capture_output=True,
            text=True,
        )
        return result.returncode == 0

    # Check for manual processes
    return bool(find_gateway_pids())
|
||||
@@ -739,7 +692,7 @@ def _setup_signal():
|
||||
print_info(" Docker: bbernhard/signal-cli-rest-api")
|
||||
print()
|
||||
print_info(" After installing, link your account and start the daemon:")
|
||||
print_info(' signal-cli link -n "HermesAgent"')
|
||||
print_info(" signal-cli link -n \"HermesAgent\"")
|
||||
print_info(" signal-cli --account +YOURNUMBER daemon --http 127.0.0.1:8080")
|
||||
print()
|
||||
|
||||
@@ -757,7 +710,6 @@ def _setup_signal():
|
||||
print_info(" Testing connection...")
|
||||
try:
|
||||
import httpx
|
||||
|
||||
resp = httpx.get(f"{url.rstrip('/')}/api/v1/check", timeout=10.0)
|
||||
if resp.status_code == 200:
|
||||
print_success(" signal-cli daemon is reachable!")
|
||||
@@ -822,7 +774,7 @@ def _setup_signal():
|
||||
print_success("Signal configured!")
|
||||
print_info(f" URL: {url}")
|
||||
print_info(f" Account: {account}")
|
||||
print_info(" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing")
|
||||
print_info(f" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing")
|
||||
print_info(f" Groups: {'enabled' if get_env_value('SIGNAL_GROUP_ALLOWED_USERS') else 'disabled'}")
|
||||
|
||||
|
||||
@@ -884,10 +836,11 @@ def gateway_setup():
|
||||
_setup_standard_platform(platform)
|
||||
|
||||
# ── Post-setup: offer to install/restart gateway ──
|
||||
any_configured = (
|
||||
any(bool(get_env_value(p["token_var"])) for p in _PLATFORMS if p["key"] != "whatsapp")
|
||||
or (get_env_value("WHATSAPP_ENABLED") or "").lower() == "true"
|
||||
)
|
||||
any_configured = any(
|
||||
bool(get_env_value(p["token_var"]))
|
||||
for p in _PLATFORMS
|
||||
if p["key"] != "whatsapp"
|
||||
) or (get_env_value("WHATSAPP_ENABLED") or "").lower() == "true"
|
||||
|
||||
if any_configured:
|
||||
print()
|
||||
@@ -920,9 +873,7 @@ def gateway_setup():
|
||||
print()
|
||||
if is_linux() or is_macos():
|
||||
platform_name = "systemd" if is_linux() else "launchd"
|
||||
if prompt_yes_no(
|
||||
f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True
|
||||
):
|
||||
if prompt_yes_no(f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True):
|
||||
try:
|
||||
force = False
|
||||
if is_linux():
|
||||
@@ -958,15 +909,14 @@ def gateway_setup():
|
||||
# Main Command Handler
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def gateway_command(args):
|
||||
"""Handle gateway subcommands."""
|
||||
subcmd = getattr(args, "gateway_command", None)
|
||||
|
||||
subcmd = getattr(args, 'gateway_command', None)
|
||||
|
||||
# Default to run if no subcommand
|
||||
if subcmd is None or subcmd == "run":
|
||||
verbose = getattr(args, "verbose", False)
|
||||
replace = getattr(args, "replace", False)
|
||||
verbose = getattr(args, 'verbose', False)
|
||||
replace = getattr(args, 'replace', False)
|
||||
run_gateway(verbose, replace=replace)
|
||||
return
|
||||
|
||||
@@ -976,7 +926,7 @@ def gateway_command(args):
|
||||
|
||||
# Service management commands
|
||||
if subcmd == "install":
|
||||
force = getattr(args, "force", False)
|
||||
force = getattr(args, 'force', False)
|
||||
if is_linux():
|
||||
systemd_install(force)
|
||||
elif is_macos():
|
||||
@@ -985,7 +935,7 @@ def gateway_command(args):
|
||||
print("Service installation not supported on this platform.")
|
||||
print("Run manually: hermes gateway run")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
elif subcmd == "uninstall":
|
||||
if is_linux():
|
||||
systemd_uninstall()
|
||||
@@ -994,7 +944,7 @@ def gateway_command(args):
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
elif subcmd == "start":
|
||||
if is_linux():
|
||||
systemd_start()
|
||||
@@ -1003,11 +953,11 @@ def gateway_command(args):
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
elif subcmd == "stop":
|
||||
# Try service first, fall back to killing processes directly
|
||||
service_available = False
|
||||
|
||||
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
try:
|
||||
systemd_stop()
|
||||
@@ -1020,7 +970,7 @@ def gateway_command(args):
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
if not service_available:
|
||||
# Kill gateway processes directly
|
||||
killed = kill_gateway_processes()
|
||||
@@ -1028,11 +978,11 @@ def gateway_command(args):
|
||||
print(f"✓ Stopped {killed} gateway process(es)")
|
||||
else:
|
||||
print("✗ No gateway processes found")
|
||||
|
||||
|
||||
elif subcmd == "restart":
|
||||
# Try service first, fall back to killing and restarting
|
||||
service_available = False
|
||||
|
||||
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
try:
|
||||
systemd_restart()
|
||||
@@ -1045,24 +995,23 @@ def gateway_command(args):
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
if not service_available:
|
||||
# Manual restart: kill existing processes
|
||||
killed = kill_gateway_processes()
|
||||
if killed:
|
||||
print(f"✓ Stopped {killed} gateway process(es)")
|
||||
|
||||
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
# Start fresh
|
||||
print("Starting gateway...")
|
||||
run_gateway(verbose=False)
|
||||
|
||||
|
||||
elif subcmd == "status":
|
||||
deep = getattr(args, "deep", False)
|
||||
|
||||
deep = getattr(args, 'deep', False)
|
||||
|
||||
# Check for service first
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
systemd_status(deep)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,26 +8,26 @@ Add, remove, or reorder entries here — both `hermes setup` and
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from difflib import get_close_matches
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.6", "recommended"),
|
||||
("anthropic/claude-sonnet-4.5", ""),
|
||||
("openai/gpt-5.4-pro", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("stepfun/step-3.5-flash", ""),
|
||||
("z-ai/glm-5", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("minimax/minimax-m2.5", ""),
|
||||
("anthropic/claude-opus-4.6", "recommended"),
|
||||
("anthropic/claude-sonnet-4.5", ""),
|
||||
("openai/gpt-5.4-pro", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("stepfun/step-3.5-flash", ""),
|
||||
("z-ai/glm-5", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("minimax/minimax-m2.5", ""),
|
||||
]
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
@@ -93,7 +93,9 @@ def menu_labels() -> list[str]:
|
||||
|
||||
# All provider IDs and aliases that are valid for the provider:model syntax.
|
||||
_KNOWN_PROVIDER_NAMES: set[str] = (
|
||||
set(_PROVIDER_LABELS.keys()) | set(_PROVIDER_ALIASES.keys()) | {"openrouter", "custom"}
|
||||
set(_PROVIDER_LABELS.keys())
|
||||
| set(_PROVIDER_ALIASES.keys())
|
||||
| {"openrouter", "custom"}
|
||||
)
|
||||
|
||||
|
||||
@@ -105,13 +107,8 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
"""
|
||||
# Canonical providers in display order
|
||||
_PROVIDER_ORDER = [
|
||||
"openrouter",
|
||||
"nous",
|
||||
"openai-codex",
|
||||
"zai",
|
||||
"kimi-coding",
|
||||
"minimax",
|
||||
"minimax-cn",
|
||||
"openrouter", "nous", "openai-codex",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn",
|
||||
]
|
||||
# Build reverse alias map
|
||||
aliases_for: dict[str, list[str]] = {}
|
||||
@@ -126,19 +123,16 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
has_creds = False
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
|
||||
runtime = resolve_runtime_provider(requested=pid)
|
||||
has_creds = bool(runtime.get("api_key"))
|
||||
except Exception:
|
||||
pass
|
||||
result.append(
|
||||
{
|
||||
"id": pid,
|
||||
"label": label,
|
||||
"aliases": alias_list,
|
||||
"authenticated": has_creds,
|
||||
}
|
||||
)
|
||||
result.append({
|
||||
"id": pid,
|
||||
"label": label,
|
||||
"aliases": alias_list,
|
||||
"authenticated": has_creds,
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
@@ -163,13 +157,13 @@ def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]:
|
||||
colon = stripped.find(":")
|
||||
if colon > 0:
|
||||
provider_part = stripped[:colon].strip().lower()
|
||||
model_part = stripped[colon + 1 :].strip()
|
||||
model_part = stripped[colon + 1:].strip()
|
||||
if provider_part and model_part and provider_part in _KNOWN_PROVIDER_NAMES:
|
||||
return (normalize_provider(provider_part), model_part)
|
||||
return (current_provider, stripped)
|
||||
|
||||
|
||||
def curated_models_for_provider(provider: str | None) -> list[tuple[str, str]]:
|
||||
def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]:
|
||||
"""Return ``(model_id, description)`` tuples for a provider's curated list."""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
@@ -178,7 +172,7 @@ def curated_models_for_provider(provider: str | None) -> list[tuple[str, str]]:
|
||||
return [(m, "") for m in models]
|
||||
|
||||
|
||||
def normalize_provider(provider: str | None) -> str:
|
||||
def normalize_provider(provider: Optional[str]) -> str:
|
||||
"""Normalize provider aliases to Hermes' canonical provider ids.
|
||||
|
||||
Note: ``"auto"`` passes through unchanged — use
|
||||
@@ -189,7 +183,7 @@ def normalize_provider(provider: str | None) -> str:
|
||||
return _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
|
||||
def provider_model_ids(provider: str | None) -> list[str]:
|
||||
def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
"""Return the best known model catalog for a provider."""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
@@ -202,10 +196,10 @@ def provider_model_ids(provider: str | None) -> list[str]:
|
||||
|
||||
|
||||
def fetch_api_models(
|
||||
api_key: str | None,
|
||||
base_url: str | None,
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str],
|
||||
timeout: float = 5.0,
|
||||
) -> list[str] | None:
|
||||
) -> Optional[list[str]]:
|
||||
"""Fetch the list of available model IDs from the provider's ``/models`` endpoint.
|
||||
|
||||
Returns a list of model ID strings, or ``None`` if the endpoint could not
|
||||
@@ -231,10 +225,10 @@ def fetch_api_models(
|
||||
|
||||
def validate_requested_model(
|
||||
model_name: str,
|
||||
provider: str | None,
|
||||
provider: Optional[str],
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Validate a ``/model`` value for the active provider.
|
||||
@@ -292,7 +286,10 @@ def validate_requested_model(
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (f"Error: `{requested}` is not a valid model for this provider.{suggestion_text}"),
|
||||
"message": (
|
||||
f"Error: `{requested}` is not a valid model for this provider."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
|
||||
# api_models is None — couldn't reach API, fall back to catalog check
|
||||
|
||||
@@ -8,7 +8,6 @@ Usage:
|
||||
hermes pairing clear-pending # Clear all expired/pending codes
|
||||
"""
|
||||
|
||||
|
||||
def pairing_command(args):
|
||||
"""Handle hermes pairing subcommands."""
|
||||
from gateway.pairing import PairingStore
|
||||
@@ -73,10 +72,10 @@ def _cmd_approve(store, platform: str, code: str):
|
||||
name = result.get("user_name", "")
|
||||
display = f"{name} ({uid})" if name else uid
|
||||
print(f"\n Approved! User {display} on {platform} can now use the bot~")
|
||||
print(" They'll be recognized automatically on their next message.\n")
|
||||
print(f" They'll be recognized automatically on their next message.\n")
|
||||
else:
|
||||
print(f"\n Code '{code}' not found or expired for platform '{platform}'.")
|
||||
print(" Run 'hermes pairing list' to see pending codes.\n")
|
||||
print(f" Run 'hermes pairing list' to see pending codes.\n")
|
||||
|
||||
|
||||
def _cmd_revoke(store, platform: str, user_id: str):
|
||||
|
||||
@@ -3,22 +3,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY,
|
||||
AuthError,
|
||||
PROVIDER_REGISTRY,
|
||||
format_auth_error,
|
||||
resolve_api_key_provider_credentials,
|
||||
resolve_codex_runtime_credentials,
|
||||
resolve_nous_runtime_credentials,
|
||||
resolve_provider,
|
||||
resolve_nous_runtime_credentials,
|
||||
resolve_codex_runtime_credentials,
|
||||
resolve_api_key_provider_credentials,
|
||||
)
|
||||
from hermes_cli.config import load_config
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
|
||||
def _get_model_config() -> dict[str, Any]:
|
||||
def _get_model_config() -> Dict[str, Any]:
|
||||
config = load_config()
|
||||
model_cfg = config.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
@@ -28,7 +28,7 @@ def _get_model_config() -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def resolve_requested_provider(requested: str | None = None) -> str:
|
||||
def resolve_requested_provider(requested: Optional[str] = None) -> str:
|
||||
"""Resolve provider request from explicit arg, env, then config."""
|
||||
if requested and requested.strip():
|
||||
return requested.strip().lower()
|
||||
@@ -48,9 +48,9 @@ def resolve_requested_provider(requested: str | None = None) -> str:
|
||||
def _resolve_openrouter_runtime(
|
||||
*,
|
||||
requested_provider: str,
|
||||
explicit_api_key: str | None = None,
|
||||
explicit_base_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
explicit_api_key: Optional[str] = None,
|
||||
explicit_base_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
model_cfg = _get_model_config()
|
||||
cfg_base_url = model_cfg.get("base_url") if isinstance(model_cfg.get("base_url"), str) else ""
|
||||
cfg_provider = model_cfg.get("provider") if isinstance(model_cfg.get("provider"), str) else ""
|
||||
@@ -81,9 +81,19 @@ def _resolve_openrouter_runtime(
|
||||
# provider (issues #420, #560).
|
||||
_is_openrouter_url = "openrouter.ai" in base_url
|
||||
if _is_openrouter_url:
|
||||
api_key = explicit_api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
else:
|
||||
api_key = explicit_api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY") or ""
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or ""
|
||||
)
|
||||
|
||||
source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
|
||||
|
||||
@@ -98,10 +108,10 @@ def _resolve_openrouter_runtime(
|
||||
|
||||
def resolve_runtime_provider(
|
||||
*,
|
||||
requested: str | None = None,
|
||||
explicit_api_key: str | None = None,
|
||||
explicit_base_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
requested: Optional[str] = None,
|
||||
explicit_api_key: Optional[str] = None,
|
||||
explicit_base_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Resolve runtime provider credentials for agent execution."""
|
||||
requested_provider = resolve_requested_provider(requested)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@ handler are thin wrappers that parse args and delegate.
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -28,7 +29,6 @@ _console = Console()
|
||||
# Shared do_* functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_short_name(name: str, sources, console: Console) -> str:
|
||||
"""
|
||||
Resolve a short skill name (e.g. 'pptx') to a full identifier by searching
|
||||
@@ -57,9 +57,7 @@ def _resolve_short_name(name: str, sources, console: Console) -> str:
|
||||
table.add_column("Trust", style="dim")
|
||||
table.add_column("Identifier", style="bold cyan")
|
||||
for r in exact:
|
||||
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(
|
||||
r.trust_level, "dim"
|
||||
)
|
||||
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
|
||||
trust_label = "official" if r.source == "official" else r.trust_level
|
||||
table.add_row(r.source, f"[{trust_style}]{trust_label}[/]", r.identifier)
|
||||
c.print(table)
|
||||
@@ -78,7 +76,8 @@ def _resolve_short_name(name: str, sources, console: Console) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def do_search(query: str, source: str = "all", limit: int = 10, console: Console | None = None) -> None:
|
||||
def do_search(query: str, source: str = "all", limit: int = 10,
|
||||
console: Optional[Console] = None) -> None:
|
||||
"""Search registries and display results as a Rich table."""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router, unified_search
|
||||
|
||||
@@ -112,19 +111,18 @@ def do_search(query: str, source: str = "all", limit: int = 10, console: Console
|
||||
)
|
||||
|
||||
c.print(table)
|
||||
c.print(
|
||||
"[dim]Use: hermes skills inspect <identifier> to preview, hermes skills install <identifier> to install[/]\n"
|
||||
)
|
||||
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
|
||||
"hermes skills install <identifier> to install[/]\n")
|
||||
|
||||
|
||||
def do_browse(page: int = 1, page_size: int = 20, source: str = "all", console: Console | None = None) -> None:
|
||||
def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
console: Optional[Console] = None) -> None:
|
||||
"""Browse all available skills across registries, paginated.
|
||||
|
||||
Official skills are always shown first, regardless of source filter.
|
||||
"""
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth,
|
||||
create_source_router,
|
||||
GitHubAuth, create_source_router, OptionalSkillSource, SkillMeta,
|
||||
)
|
||||
|
||||
# Clamp page_size to safe range
|
||||
@@ -138,7 +136,8 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", console:
|
||||
# Collect results from all (or filtered) sources
|
||||
# Use empty query to get everything; per-source limits prevent overload
|
||||
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50, "claude-marketplace": 50, "lobehub": 50}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50,
|
||||
"claude-marketplace": 50, "lobehub": 50}
|
||||
|
||||
all_results: list = []
|
||||
source_counts: dict = {}
|
||||
@@ -169,13 +168,11 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", console:
|
||||
deduped = list(seen.values())
|
||||
|
||||
# Sort: official first, then by trust level (desc), then alphabetically
|
||||
deduped.sort(
|
||||
key=lambda r: (
|
||||
-_TRUST_RANK.get(r.trust_level, 0),
|
||||
r.source != "official",
|
||||
r.name.lower(),
|
||||
)
|
||||
)
|
||||
deduped.sort(key=lambda r: (
|
||||
-_TRUST_RANK.get(r.trust_level, 0),
|
||||
r.source != "official",
|
||||
r.name.lower(),
|
||||
))
|
||||
|
||||
# Paginate
|
||||
total = len(deduped)
|
||||
@@ -190,7 +187,8 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", console:
|
||||
|
||||
# Build header
|
||||
source_label = f"— {source}" if source != "all" else "— all sources"
|
||||
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/] [dim]({total} skills, page {page}/{total_pages})[/]")
|
||||
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
|
||||
f" [dim]({total} skills, page {page}/{total_pages})[/]")
|
||||
if official_count > 0 and page == 1:
|
||||
c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
|
||||
c.print()
|
||||
@@ -204,7 +202,8 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", console:
|
||||
table.add_column("Trust", width=10)
|
||||
|
||||
for i, r in enumerate(page_items, start=start + 1):
|
||||
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
|
||||
trust_style = {"builtin": "bright_cyan", "trusted": "green",
|
||||
"community": "yellow"}.get(r.trust_level, "dim")
|
||||
trust_label = "★ official" if r.source == "official" else r.trust_level
|
||||
|
||||
desc = r.description[:50]
|
||||
@@ -236,22 +235,18 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", console:
|
||||
parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
|
||||
c.print(f" [dim]Sources: {', '.join(parts)}[/]")
|
||||
|
||||
c.print(
|
||||
"[dim]Use: hermes skills inspect <identifier> to preview, hermes skills install <identifier> to install[/]\n"
|
||||
)
|
||||
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
|
||||
"hermes skills install <identifier> to install[/]\n")
|
||||
|
||||
|
||||
def do_install(identifier: str, category: str = "", force: bool = False, console: Console | None = None) -> None:
|
||||
def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
console: Optional[Console] = None) -> None:
|
||||
"""Fetch, quarantine, scan, confirm, and install a skill."""
|
||||
from tools.skills_guard import format_scan_report, scan_skill, should_allow_install
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth,
|
||||
HubLockFile,
|
||||
create_source_router,
|
||||
ensure_hub_dirs,
|
||||
install_from_quarantine,
|
||||
quarantine_bundle,
|
||||
GitHubAuth, create_source_router, ensure_hub_dirs,
|
||||
quarantine_bundle, install_from_quarantine, HubLockFile,
|
||||
)
|
||||
from tools.skills_guard import scan_skill, should_allow_install, format_scan_report
|
||||
|
||||
c = console or _console
|
||||
ensure_hub_dirs()
|
||||
@@ -309,43 +304,33 @@ def do_install(identifier: str, category: str = "", force: bool = False, console
|
||||
# Clean up quarantine
|
||||
shutil.rmtree(q_path, ignore_errors=True)
|
||||
from tools.skills_hub import append_audit_log
|
||||
|
||||
append_audit_log(
|
||||
"BLOCKED",
|
||||
bundle.name,
|
||||
bundle.source,
|
||||
bundle.trust_level,
|
||||
result.verdict,
|
||||
f"{len(result.findings)}_findings",
|
||||
)
|
||||
append_audit_log("BLOCKED", bundle.name, bundle.source,
|
||||
bundle.trust_level, result.verdict,
|
||||
f"{len(result.findings)}_findings")
|
||||
return
|
||||
|
||||
# Confirm with user — show appropriate warning based on source
|
||||
if not force:
|
||||
c.print()
|
||||
if bundle.source == "official":
|
||||
c.print(
|
||||
Panel(
|
||||
"[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
|
||||
"It ships with hermes-agent but is not activated by default.\n"
|
||||
"Installing will copy it to your skills directory where the agent can use it.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Official Skill",
|
||||
border_style="bright_cyan",
|
||||
)
|
||||
)
|
||||
c.print(Panel(
|
||||
"[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
|
||||
"It ships with hermes-agent but is not activated by default.\n"
|
||||
"Installing will copy it to your skills directory where the agent can use it.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Official Skill",
|
||||
border_style="bright_cyan",
|
||||
))
|
||||
else:
|
||||
c.print(
|
||||
Panel(
|
||||
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
|
||||
"External skills can contain instructions that influence agent behavior,\n"
|
||||
"shell commands, and scripts. Even after automated scanning, you should\n"
|
||||
"review the installed files before use.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Disclaimer",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
c.print(Panel(
|
||||
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
|
||||
"External skills can contain instructions that influence agent behavior,\n"
|
||||
"shell commands, and scripts. Even after automated scanning, you should\n"
|
||||
"review the installed files before use.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Disclaimer",
|
||||
border_style="yellow",
|
||||
))
|
||||
c.print(f"[bold]Install '{bundle.name}'?[/]")
|
||||
try:
|
||||
answer = input("Confirm [y/N]: ").strip().lower()
|
||||
@@ -359,12 +344,11 @@ def do_install(identifier: str, category: str = "", force: bool = False, console
|
||||
# Install
|
||||
install_dir = install_from_quarantine(q_path, bundle.name, category, bundle, result)
|
||||
from tools.skills_hub import SKILLS_DIR
|
||||
|
||||
c.print(f"[bold green]Installed:[/] {install_dir.relative_to(SKILLS_DIR)}")
|
||||
c.print(f"[dim]Files: {', '.join(bundle.files.keys())}[/]\n")
|
||||
|
||||
|
||||
def do_inspect(identifier: str, console: Console | None = None) -> None:
|
||||
def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
"""Preview a skill's SKILL.md content without installing."""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router
|
||||
|
||||
@@ -422,7 +406,7 @@ def do_inspect(identifier: str, console: Console | None = None) -> None:
|
||||
c.print()
|
||||
|
||||
|
||||
def do_list(source_filter: str = "all", console: Console | None = None) -> None:
|
||||
def do_list(source_filter: str = "all", console: Optional[Console] = None) -> None:
|
||||
"""List installed skills, distinguishing builtins from hub-installed."""
|
||||
from tools.skills_hub import HubLockFile, ensure_hub_dirs
|
||||
from tools.skills_tool import _find_all_skills
|
||||
@@ -462,13 +446,14 @@ def do_list(source_filter: str = "all", console: Console | None = None) -> None:
|
||||
table.add_row(name, category, source_display, f"[{trust_style}]{trust_label}[/]")
|
||||
|
||||
c.print(table)
|
||||
c.print(f"[dim]{len(hub_installed)} hub-installed, {len(all_skills) - len(hub_installed)} builtin[/]\n")
|
||||
c.print(f"[dim]{len(hub_installed)} hub-installed, "
|
||||
f"{len(all_skills) - len(hub_installed)} builtin[/]\n")
|
||||
|
||||
|
||||
def do_audit(name: str | None = None, console: Console | None = None) -> None:
|
||||
def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> None:
|
||||
"""Re-run security scan on installed hub skills."""
|
||||
from tools.skills_guard import format_scan_report, scan_skill
|
||||
from tools.skills_hub import SKILLS_DIR, HubLockFile
|
||||
from tools.skills_hub import HubLockFile, SKILLS_DIR
|
||||
from tools.skills_guard import scan_skill, format_scan_report
|
||||
|
||||
c = console or _console
|
||||
lock = HubLockFile()
|
||||
@@ -498,7 +483,7 @@ def do_audit(name: str | None = None, console: Console | None = None) -> None:
|
||||
c.print()
|
||||
|
||||
|
||||
def do_uninstall(name: str, console: Console | None = None) -> None:
|
||||
def do_uninstall(name: str, console: Optional[Console] = None) -> None:
|
||||
"""Remove a hub-installed skill with confirmation."""
|
||||
from tools.skills_hub import uninstall_skill
|
||||
|
||||
@@ -520,7 +505,7 @@ def do_uninstall(name: str, console: Console | None = None) -> None:
|
||||
c.print(f"[bold red]Error:[/] {msg}\n")
|
||||
|
||||
|
||||
def do_tap(action: str, repo: str = "", console: Console | None = None) -> None:
|
||||
def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> None:
|
||||
"""Manage taps (custom GitHub repo sources)."""
|
||||
from tools.skills_hub import TapsManager
|
||||
|
||||
@@ -562,10 +547,11 @@ def do_tap(action: str, repo: str = "", console: Console | None = None) -> None:
|
||||
c.print(f"[bold red]Unknown tap action:[/] {action}. Use: list, add, remove\n")
|
||||
|
||||
|
||||
def do_publish(skill_path: str, target: str = "github", repo: str = "", console: Console | None = None) -> None:
|
||||
def do_publish(skill_path: str, target: str = "github", repo: str = "",
|
||||
console: Optional[Console] = None) -> None:
|
||||
"""Publish a local skill to a registry (GitHub PR or ClawHub submission)."""
|
||||
from tools.skills_guard import format_scan_report, scan_skill
|
||||
from tools.skills_hub import SKILLS_DIR, GitHubAuth
|
||||
from tools.skills_hub import GitHubAuth, SKILLS_DIR
|
||||
from tools.skills_guard import scan_skill, format_scan_report
|
||||
|
||||
c = console or _console
|
||||
path = Path(skill_path)
|
||||
@@ -579,16 +565,14 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "", console:
|
||||
|
||||
# Validate the skill
|
||||
import yaml
|
||||
|
||||
skill_md = (path / "SKILL.md").read_text(encoding="utf-8")
|
||||
fm = {}
|
||||
if skill_md.startswith("---"):
|
||||
import re
|
||||
|
||||
match = re.search(r"\n---\s*\n", skill_md[3:])
|
||||
match = re.search(r'\n---\s*\n', skill_md[3:])
|
||||
if match:
|
||||
try:
|
||||
fm = yaml.safe_load(skill_md[3 : match.start() + 3]) or {}
|
||||
fm = yaml.safe_load(skill_md[3:match.start() + 3]) or {}
|
||||
except yaml.YAMLError:
|
||||
pass
|
||||
|
||||
@@ -608,18 +592,14 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "", console:
|
||||
|
||||
if target == "github":
|
||||
if not repo:
|
||||
c.print(
|
||||
"[bold red]Error:[/] --repo required for GitHub publish.\n"
|
||||
"Usage: hermes skills publish <path> --to github --repo owner/repo\n"
|
||||
)
|
||||
c.print("[bold red]Error:[/] --repo required for GitHub publish.\n"
|
||||
"Usage: hermes skills publish <path> --to github --repo owner/repo\n")
|
||||
return
|
||||
|
||||
auth = GitHubAuth()
|
||||
if not auth.is_authenticated():
|
||||
c.print(
|
||||
"[bold red]Error:[/] GitHub authentication required.\n"
|
||||
"Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n"
|
||||
)
|
||||
c.print("[bold red]Error:[/] GitHub authentication required.\n"
|
||||
"Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n")
|
||||
return
|
||||
|
||||
c.print(f"[bold]Publishing '{name}' to {repo}...[/]")
|
||||
@@ -630,12 +610,14 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "", console:
|
||||
c.print(f"[bold red]Error:[/] {msg}\n")
|
||||
|
||||
elif target == "clawhub":
|
||||
c.print("[yellow]ClawHub publishing is not yet supported. Submit manually at https://clawhub.ai/submit[/]\n")
|
||||
c.print("[yellow]ClawHub publishing is not yet supported. "
|
||||
"Submit manually at https://clawhub.ai/submit[/]\n")
|
||||
else:
|
||||
c.print(f"[bold red]Unknown target:[/] {target}. Use 'github' or 'clawhub'.\n")
|
||||
|
||||
|
||||
def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -> tuple:
|
||||
def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
|
||||
auth) -> tuple:
|
||||
"""Create a PR to a GitHub repo with the skill. Returns (success, message)."""
|
||||
import httpx
|
||||
|
||||
@@ -645,8 +627,7 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"https://api.github.com/repos/{target_repo}/forks",
|
||||
headers=headers,
|
||||
timeout=30,
|
||||
headers=headers, timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 202):
|
||||
fork = resp.json()
|
||||
@@ -662,8 +643,7 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"https://api.github.com/repos/{target_repo}",
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
headers=headers, timeout=15,
|
||||
)
|
||||
default_branch = resp.json().get("default_branch", "main")
|
||||
except Exception:
|
||||
@@ -673,8 +653,7 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"https://api.github.com/repos/{fork_repo}/git/refs/heads/{default_branch}",
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
headers=headers, timeout=15,
|
||||
)
|
||||
base_sha = resp.json()["object"]["sha"]
|
||||
except Exception as e:
|
||||
@@ -685,8 +664,7 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -
|
||||
try:
|
||||
httpx.post(
|
||||
f"https://api.github.com/repos/{fork_repo}/git/refs",
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
headers=headers, timeout=15,
|
||||
json={"ref": f"refs/heads/{branch_name}", "sha": base_sha},
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -700,12 +678,10 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -
|
||||
upload_path = f"skills/{skill_name}/{rel}"
|
||||
try:
|
||||
import base64
|
||||
|
||||
content_b64 = base64.b64encode(f.read_bytes()).decode()
|
||||
httpx.put(
|
||||
f"https://api.github.com/repos/{fork_repo}/contents/{upload_path}",
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
headers=headers, timeout=15,
|
||||
json={
|
||||
"message": f"Add {skill_name} skill: {rel}",
|
||||
"content": content_b64,
|
||||
@@ -719,12 +695,11 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"https://api.github.com/repos/{target_repo}/pulls",
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
headers=headers, timeout=15,
|
||||
json={
|
||||
"title": f"Add skill: {skill_name}",
|
||||
"body": f"Submitting the `{skill_name}` skill via Hermes Skills Hub.\n\n"
|
||||
f"This skill was scanned by the Hermes Skills Guard before submission.",
|
||||
f"This skill was scanned by the Hermes Skills Guard before submission.",
|
||||
"head": f"{fork_repo.split('/')[0]}:{branch_name}",
|
||||
"base": default_branch,
|
||||
},
|
||||
@@ -738,7 +713,7 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -
|
||||
return False, f"Network error creating PR: {e}"
|
||||
|
||||
|
||||
def do_snapshot_export(output_path: str, console: Console | None = None) -> None:
|
||||
def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> None:
|
||||
"""Export current hub skill configuration to a portable JSON file."""
|
||||
from tools.skills_hub import HubLockFile, TapsManager
|
||||
|
||||
@@ -751,15 +726,16 @@ def do_snapshot_export(output_path: str, console: Console | None = None) -> None
|
||||
|
||||
snapshot = {
|
||||
"hermes_version": "0.1.0",
|
||||
"exported_at": __import__("datetime").datetime.now(__import__("datetime").timezone.utc).isoformat(),
|
||||
"exported_at": __import__("datetime").datetime.now(
|
||||
__import__("datetime").timezone.utc
|
||||
).isoformat(),
|
||||
"skills": [
|
||||
{
|
||||
"name": entry["name"],
|
||||
"source": entry.get("source", ""),
|
||||
"identifier": entry.get("identifier", ""),
|
||||
"category": str(Path(entry.get("install_path", "")).parent)
|
||||
if "/" in entry.get("install_path", "")
|
||||
else "",
|
||||
if "/" in entry.get("install_path", "") else "",
|
||||
}
|
||||
for entry in installed
|
||||
],
|
||||
@@ -772,7 +748,8 @@ def do_snapshot_export(output_path: str, console: Console | None = None) -> None
|
||||
c.print(f"[dim]{len(installed)} skill(s), {len(tap_list)} tap(s)[/]\n")
|
||||
|
||||
|
||||
def do_snapshot_import(input_path: str, force: bool = False, console: Console | None = None) -> None:
|
||||
def do_snapshot_import(input_path: str, force: bool = False,
|
||||
console: Optional[Console] = None) -> None:
|
||||
"""Re-install skills from a snapshot file."""
|
||||
from tools.skills_hub import TapsManager
|
||||
|
||||
@@ -822,7 +799,6 @@ def do_snapshot_import(input_path: str, force: bool = False, console: Console |
|
||||
# CLI argparse entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def skills_command(args) -> None:
|
||||
"""Router for `hermes skills <subcommand>` — called from hermes_cli/main.py."""
|
||||
action = getattr(args, "skills_action", None)
|
||||
@@ -863,9 +839,7 @@ def skills_command(args) -> None:
|
||||
return
|
||||
do_tap(tap_action, repo=repo)
|
||||
else:
|
||||
_console.print(
|
||||
"Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n"
|
||||
)
|
||||
_console.print("Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n")
|
||||
_console.print("Run 'hermes skills <command> --help' for details.\n")
|
||||
|
||||
|
||||
@@ -873,8 +847,7 @@ def skills_command(args) -> None:
|
||||
# Slash command entry point (/skills in chat)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def handle_skills_slash(cmd: str, console: Console | None = None) -> None:
|
||||
def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
|
||||
"""
|
||||
Parse and dispatch `/skills <subcommand> [args]` from the chat interface.
|
||||
|
||||
@@ -1035,19 +1008,17 @@ def handle_skills_slash(cmd: str, console: Console | None = None) -> None:
|
||||
|
||||
def _print_skills_help(console: Console) -> None:
|
||||
"""Print help for the /skills slash command."""
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold]Skills Hub Commands:[/]\n\n"
|
||||
" [cyan]browse[/] [--source official] Browse all available skills (paginated)\n"
|
||||
" [cyan]search[/] <query> Search registries for skills\n"
|
||||
" [cyan]install[/] <identifier> Install a skill (with security scan)\n"
|
||||
" [cyan]inspect[/] <identifier> Preview a skill without installing\n"
|
||||
" [cyan]list[/] [--source hub|builtin] List installed skills\n"
|
||||
" [cyan]audit[/] [name] Re-scan hub skills for security\n"
|
||||
" [cyan]uninstall[/] <name> Remove a hub-installed skill\n"
|
||||
" [cyan]publish[/] <path> --repo <r> Publish a skill to GitHub via PR\n"
|
||||
" [cyan]snapshot[/] export|import Export/import skill configurations\n"
|
||||
" [cyan]tap[/] list|add|remove Manage skill sources\n",
|
||||
title="/skills",
|
||||
)
|
||||
)
|
||||
console.print(Panel(
|
||||
"[bold]Skills Hub Commands:[/]\n\n"
|
||||
" [cyan]browse[/] [--source official] Browse all available skills (paginated)\n"
|
||||
" [cyan]search[/] <query> Search registries for skills\n"
|
||||
" [cyan]install[/] <identifier> Install a skill (with security scan)\n"
|
||||
" [cyan]inspect[/] <identifier> Preview a skill without installing\n"
|
||||
" [cyan]list[/] [--source hub|builtin] List installed skills\n"
|
||||
" [cyan]audit[/] [name] Re-scan hub skills for security\n"
|
||||
" [cyan]uninstall[/] <name> Remove a hub-installed skill\n"
|
||||
" [cyan]publish[/] <path> --repo <r> Publish a skill to GitHub via PR\n"
|
||||
" [cyan]snapshot[/] export|import Export/import skill configurations\n"
|
||||
" [cyan]tap[/] list|add|remove Manage skill sources\n",
|
||||
title="/skills",
|
||||
))
|
||||
|
||||
@@ -5,25 +5,21 @@ Shows the status of all Hermes Agent components.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
from datetime import UTC
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.config import get_env_path, get_env_value
|
||||
from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
|
||||
def check_mark(ok: bool) -> str:
|
||||
if ok:
|
||||
return color("✓", Colors.GREEN)
|
||||
return color("✗", Colors.RED)
|
||||
|
||||
|
||||
def redact_key(key: str) -> str:
|
||||
"""Redact an API key for display."""
|
||||
if not key:
|
||||
@@ -37,8 +33,7 @@ def _format_iso_timestamp(value) -> str:
|
||||
"""Format ISO timestamps for status output, converting to local timezone."""
|
||||
if not value or not isinstance(value, str):
|
||||
return "(unknown)"
|
||||
from datetime import datetime
|
||||
|
||||
from datetime import datetime, timezone
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return "(unknown)"
|
||||
@@ -47,7 +42,7 @@ def _format_iso_timestamp(value) -> str:
|
||||
try:
|
||||
parsed = datetime.fromisoformat(text)
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=UTC)
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
except Exception:
|
||||
return value
|
||||
return parsed.astimezone().strftime("%Y-%m-%d %H:%M:%S %Z")
|
||||
@@ -55,14 +50,14 @@ def _format_iso_timestamp(value) -> str:
|
||||
|
||||
def show_status(args):
|
||||
"""Show status of all Hermes Agent components."""
|
||||
show_all = getattr(args, "all", False)
|
||||
deep = getattr(args, "deep", False)
|
||||
|
||||
show_all = getattr(args, 'all', False)
|
||||
deep = getattr(args, 'deep', False)
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ ⚕ Hermes Agent Status │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Environment
|
||||
# =========================================================================
|
||||
@@ -70,19 +65,19 @@ def show_status(args):
|
||||
print(color("◆ Environment", Colors.CYAN, Colors.BOLD))
|
||||
print(f" Project: {PROJECT_ROOT}")
|
||||
print(f" Python: {sys.version.split()[0]}")
|
||||
|
||||
|
||||
env_path = get_env_path()
|
||||
print(f" .env file: {check_mark(env_path.exists())} {'exists' if env_path.exists() else 'not found'}")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# API Keys
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ API Keys", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
keys = {
|
||||
"OpenRouter": "OPENROUTER_API_KEY",
|
||||
"Anthropic": "ANTHROPIC_API_KEY",
|
||||
"Anthropic": "ANTHROPIC_API_KEY",
|
||||
"OpenAI": "OPENAI_API_KEY",
|
||||
"Z.AI/GLM": "GLM_API_KEY",
|
||||
"Kimi": "KIMI_API_KEY",
|
||||
@@ -96,7 +91,7 @@ def show_status(args):
|
||||
"ElevenLabs": "ELEVENLABS_API_KEY",
|
||||
"GitHub": "GITHUB_TOKEN",
|
||||
}
|
||||
|
||||
|
||||
for name, env_var in keys.items():
|
||||
value = get_env_value(env_var) or ""
|
||||
has_key = bool(value)
|
||||
@@ -110,8 +105,7 @@ def show_status(args):
|
||||
print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import get_codex_auth_status, get_nous_auth_status
|
||||
|
||||
from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status
|
||||
nous_status = get_nous_auth_status()
|
||||
codex_status = get_codex_auth_status()
|
||||
except Exception:
|
||||
@@ -154,10 +148,10 @@ def show_status(args):
|
||||
print(color("◆ API-Key Providers", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
apikey_providers = {
|
||||
"Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
|
||||
"Kimi / Moonshot": ("KIMI_API_KEY",),
|
||||
"MiniMax": ("MINIMAX_API_KEY",),
|
||||
"MiniMax (China)": ("MINIMAX_CN_API_KEY",),
|
||||
"Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
|
||||
"Kimi / Moonshot": ("KIMI_API_KEY",),
|
||||
"MiniMax": ("MINIMAX_API_KEY",),
|
||||
"MiniMax (China)": ("MINIMAX_CN_API_KEY",),
|
||||
}
|
||||
for pname, env_vars in apikey_providers.items():
|
||||
key_val = ""
|
||||
@@ -174,20 +168,19 @@ def show_status(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Terminal Backend", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
terminal_env = os.getenv("TERMINAL_ENV", "")
|
||||
if not terminal_env:
|
||||
# Fall back to config file value when env var isn't set
|
||||
# (hermes status doesn't go through cli.py's config loading)
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
_cfg = load_config()
|
||||
terminal_env = _cfg.get("terminal", {}).get("backend", "local")
|
||||
except Exception:
|
||||
terminal_env = "local"
|
||||
print(f" Backend: {terminal_env}")
|
||||
|
||||
|
||||
if terminal_env == "ssh":
|
||||
ssh_host = os.getenv("TERMINAL_SSH_HOST", "")
|
||||
ssh_user = os.getenv("TERMINAL_SSH_USER", "")
|
||||
@@ -199,16 +192,16 @@ def show_status(args):
|
||||
elif terminal_env == "daytona":
|
||||
daytona_image = os.getenv("TERMINAL_DAYTONA_IMAGE", "nikolaik/python-nodejs:python3.11-nodejs20")
|
||||
print(f" Daytona Image: {daytona_image}")
|
||||
|
||||
|
||||
sudo_password = os.getenv("SUDO_PASSWORD", "")
|
||||
print(f" Sudo: {check_mark(bool(sudo_password))} {'enabled' if sudo_password else 'disabled'}")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Messaging Platforms
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Messaging Platforms", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
platforms = {
|
||||
"Telegram": ("TELEGRAM_BOT_TOKEN", "TELEGRAM_HOME_CHANNEL"),
|
||||
"Discord": ("DISCORD_BOT_TOKEN", "DISCORD_HOME_CHANNEL"),
|
||||
@@ -216,52 +209,59 @@ def show_status(args):
|
||||
"Signal": ("SIGNAL_HTTP_URL", "SIGNAL_HOME_CHANNEL"),
|
||||
"Slack": ("SLACK_BOT_TOKEN", None),
|
||||
}
|
||||
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
token = os.getenv(token_var, "")
|
||||
has_token = bool(token)
|
||||
|
||||
|
||||
home_channel = ""
|
||||
if home_var:
|
||||
home_channel = os.getenv(home_var, "")
|
||||
|
||||
|
||||
status = "configured" if has_token else "not configured"
|
||||
if home_channel:
|
||||
status += f" (home: {home_channel})"
|
||||
|
||||
|
||||
print(f" {name:<12} {check_mark(has_token)} {status}")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Gateway Status
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
if sys.platform.startswith("linux"):
|
||||
result = subprocess.run(["systemctl", "--user", "is-active", "hermes-gateway"], capture_output=True, text=True)
|
||||
|
||||
if sys.platform.startswith('linux'):
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
is_active = result.stdout.strip() == "active"
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(" Manager: systemd (user)")
|
||||
|
||||
elif sys.platform == "darwin":
|
||||
result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True)
|
||||
print(f" Manager: systemd (user)")
|
||||
|
||||
elif sys.platform == 'darwin':
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
is_loaded = result.returncode == 0
|
||||
print(f" Status: {check_mark(is_loaded)} {'loaded' if is_loaded else 'not loaded'}")
|
||||
print(" Manager: launchd")
|
||||
print(f" Manager: launchd")
|
||||
else:
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(" Manager: (not supported on this platform)")
|
||||
|
||||
print(f" Manager: (not supported on this platform)")
|
||||
|
||||
# =========================================================================
|
||||
# Cron Jobs
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Scheduled Jobs", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
jobs_file = Path.home() / ".hermes" / "cron" / "jobs.json"
|
||||
if jobs_file.exists():
|
||||
import json
|
||||
|
||||
try:
|
||||
with open(jobs_file) as f:
|
||||
data = json.load(f)
|
||||
@@ -269,57 +269,56 @@ def show_status(args):
|
||||
enabled_jobs = [j for j in jobs if j.get("enabled", True)]
|
||||
print(f" Jobs: {len(enabled_jobs)} active, {len(jobs)} total")
|
||||
except Exception:
|
||||
print(" Jobs: (error reading jobs file)")
|
||||
print(f" Jobs: (error reading jobs file)")
|
||||
else:
|
||||
print(" Jobs: 0")
|
||||
|
||||
print(f" Jobs: 0")
|
||||
|
||||
# =========================================================================
|
||||
# Sessions
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Sessions", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
sessions_file = Path.home() / ".hermes" / "sessions" / "sessions.json"
|
||||
if sessions_file.exists():
|
||||
import json
|
||||
|
||||
try:
|
||||
with open(sessions_file) as f:
|
||||
data = json.load(f)
|
||||
print(f" Active: {len(data)} session(s)")
|
||||
except Exception:
|
||||
print(" Active: (error reading sessions file)")
|
||||
print(f" Active: (error reading sessions file)")
|
||||
else:
|
||||
print(" Active: 0")
|
||||
|
||||
print(f" Active: 0")
|
||||
|
||||
# =========================================================================
|
||||
# Deep checks
|
||||
# =========================================================================
|
||||
if deep:
|
||||
print()
|
||||
print(color("◆ Deep Checks", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
# Check OpenRouter connectivity
|
||||
openrouter_key = os.getenv("OPENROUTER_API_KEY", "")
|
||||
if openrouter_key:
|
||||
try:
|
||||
import httpx
|
||||
|
||||
response = httpx.get(
|
||||
OPENROUTER_MODELS_URL, headers={"Authorization": f"Bearer {openrouter_key}"}, timeout=10
|
||||
OPENROUTER_MODELS_URL,
|
||||
headers={"Authorization": f"Bearer {openrouter_key}"},
|
||||
timeout=10
|
||||
)
|
||||
ok = response.status_code == 200
|
||||
print(f" OpenRouter: {check_mark(ok)} {'reachable' if ok else f'error ({response.status_code})'}")
|
||||
except Exception as e:
|
||||
print(f" OpenRouter: {check_mark(False)} error: {e}")
|
||||
|
||||
|
||||
# Check gateway port
|
||||
try:
|
||||
import socket
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(1)
|
||||
result = sock.connect_ex(("127.0.0.1", 18789))
|
||||
result = sock.connect_ex(('127.0.0.1', 18789))
|
||||
sock.close()
|
||||
# Port in use = gateway likely running
|
||||
port_in_use = result == 0
|
||||
@@ -327,7 +326,7 @@ def show_status(args):
|
||||
print(f" Port 18789: {'in use' if port_in_use else 'available'}")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
print()
|
||||
print(color("─" * 60, Colors.DIM))
|
||||
print(color(" Run 'hermes doctor' for detailed diagnostics", Colors.DIM))
|
||||
|
||||
@@ -11,37 +11,33 @@ the `platform_toolsets` key.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import os
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.config import (
|
||||
get_env_value,
|
||||
load_config,
|
||||
save_config,
|
||||
save_env_value,
|
||||
load_config, save_config, get_env_value, save_env_value,
|
||||
get_hermes_home,
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
|
||||
|
||||
|
||||
def _print_info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
|
||||
def _print_success(text: str):
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
|
||||
def _print_warning(text: str):
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
def _print_error(text: str):
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
|
||||
def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
if default:
|
||||
display = f"{question} [{default}]: "
|
||||
@@ -50,7 +46,6 @@ def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
try:
|
||||
if password:
|
||||
import getpass
|
||||
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
@@ -59,7 +54,6 @@ def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
print()
|
||||
return default or ""
|
||||
|
||||
|
||||
def _prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
default_str = "Y/n" if default else "y/N"
|
||||
while True:
|
||||
@@ -70,9 +64,9 @@ def _prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
return default
|
||||
if not value:
|
||||
return default
|
||||
if value in ("y", "yes"):
|
||||
if value in ('y', 'yes'):
|
||||
return True
|
||||
if value in ("n", "no"):
|
||||
if value in ('n', 'no'):
|
||||
return False
|
||||
|
||||
|
||||
@@ -82,24 +76,24 @@ def _prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
# Each entry: (toolset_name, label, description)
|
||||
# These map to keys in toolsets.py TOOLSETS dict.
|
||||
CONFIGURABLE_TOOLSETS = [
|
||||
("web", "🔍 Web Search & Scraping", "web_search, web_extract"),
|
||||
("browser", "🌐 Browser Automation", "navigate, click, type, scroll"),
|
||||
("terminal", "💻 Terminal & Processes", "terminal, process"),
|
||||
("file", "📁 File Operations", "read, write, patch, search"),
|
||||
("code_execution", "⚡ Code Execution", "execute_code"),
|
||||
("vision", "👁️ Vision / Image Analysis", "vision_analyze"),
|
||||
("image_gen", "🎨 Image Generation", "image_generate"),
|
||||
("moa", "🧠 Mixture of Agents", "mixture_of_agents"),
|
||||
("tts", "🔊 Text-to-Speech", "text_to_speech"),
|
||||
("skills", "📚 Skills", "list, view, manage"),
|
||||
("todo", "📋 Task Planning", "todo"),
|
||||
("memory", "💾 Memory", "persistent memory across sessions"),
|
||||
("session_search", "🔎 Session Search", "search past conversations"),
|
||||
("clarify", "❓ Clarifying Questions", "clarify"),
|
||||
("delegation", "👥 Task Delegation", "delegate_task"),
|
||||
("cronjob", "⏰ Cron Jobs", "schedule, list, remove"),
|
||||
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
|
||||
("homeassistant", "🏠 Home Assistant", "smart home device control"),
|
||||
("web", "🔍 Web Search & Scraping", "web_search, web_extract"),
|
||||
("browser", "🌐 Browser Automation", "navigate, click, type, scroll"),
|
||||
("terminal", "💻 Terminal & Processes", "terminal, process"),
|
||||
("file", "📁 File Operations", "read, write, patch, search"),
|
||||
("code_execution", "⚡ Code Execution", "execute_code"),
|
||||
("vision", "👁️ Vision / Image Analysis", "vision_analyze"),
|
||||
("image_gen", "🎨 Image Generation", "image_generate"),
|
||||
("moa", "🧠 Mixture of Agents", "mixture_of_agents"),
|
||||
("tts", "🔊 Text-to-Speech", "text_to_speech"),
|
||||
("skills", "📚 Skills", "list, view, manage"),
|
||||
("todo", "📋 Task Planning", "todo"),
|
||||
("memory", "💾 Memory", "persistent memory across sessions"),
|
||||
("session_search", "🔎 Session Search", "search past conversations"),
|
||||
("clarify", "❓ Clarifying Questions", "clarify"),
|
||||
("delegation", "👥 Task Delegation", "delegate_task"),
|
||||
("cronjob", "⏰ Cron Jobs", "schedule, list, remove"),
|
||||
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
|
||||
("homeassistant", "🏠 Home Assistant", "smart home device control"),
|
||||
]
|
||||
|
||||
# Toolsets that are OFF by default for new installs.
|
||||
@@ -109,11 +103,11 @@ _DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl"}
|
||||
|
||||
# Platform display config
|
||||
PLATFORMS = {
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
|
||||
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
|
||||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
|
||||
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
|
||||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
}
|
||||
|
||||
|
||||
@@ -137,11 +131,7 @@ TOOL_CATEGORIES = {
|
||||
"name": "OpenAI TTS",
|
||||
"tag": "Premium - high quality voices",
|
||||
"env_vars": [
|
||||
{
|
||||
"key": "VOICE_TOOLS_OPENAI_KEY",
|
||||
"prompt": "OpenAI API key",
|
||||
"url": "https://platform.openai.com/api-keys",
|
||||
},
|
||||
{"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"},
|
||||
],
|
||||
"tts_provider": "openai",
|
||||
},
|
||||
@@ -149,11 +139,7 @@ TOOL_CATEGORIES = {
|
||||
"name": "ElevenLabs",
|
||||
"tag": "Premium - most natural voices",
|
||||
"env_vars": [
|
||||
{
|
||||
"key": "ELEVENLABS_API_KEY",
|
||||
"prompt": "ElevenLabs API key",
|
||||
"url": "https://elevenlabs.io/app/settings/api-keys",
|
||||
},
|
||||
{"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"},
|
||||
],
|
||||
"tts_provider": "elevenlabs",
|
||||
},
|
||||
@@ -238,11 +224,7 @@ TOOL_CATEGORIES = {
|
||||
"name": "Tinker / Atropos",
|
||||
"tag": "RL training platform",
|
||||
"env_vars": [
|
||||
{
|
||||
"key": "TINKER_API_KEY",
|
||||
"prompt": "Tinker API key",
|
||||
"url": "https://tinker-console.thinkingmachines.ai/keys",
|
||||
},
|
||||
{"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
|
||||
{"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
|
||||
],
|
||||
"post_setup": "rl_training",
|
||||
@@ -254,26 +236,24 @@ TOOL_CATEGORIES = {
|
||||
# Simple env-var requirements for toolsets NOT in TOOL_CATEGORIES.
|
||||
# Used as a fallback for tools like vision/moa that just need an API key.
|
||||
TOOLSET_ENV_REQUIREMENTS = {
|
||||
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
}
|
||||
|
||||
|
||||
# ─── Post-Setup Hooks ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _run_post_setup(post_setup_key: str):
|
||||
"""Run post-setup hooks for tools that need extra installation steps."""
|
||||
import shutil
|
||||
|
||||
if post_setup_key == "browserbase":
|
||||
node_modules = PROJECT_ROOT / "node_modules" / "agent-browser"
|
||||
if not node_modules.exists() and shutil.which("npm"):
|
||||
_print_info(" Installing Node.js dependencies for browser tools...")
|
||||
import subprocess
|
||||
|
||||
result = subprocess.run(
|
||||
["npm", "install", "--silent"], capture_output=True, text=True, cwd=str(PROJECT_ROOT)
|
||||
["npm", "install", "--silent"],
|
||||
capture_output=True, text=True, cwd=str(PROJECT_ROOT)
|
||||
)
|
||||
if result.returncode == 0:
|
||||
_print_success(" Node.js dependencies installed")
|
||||
@@ -290,17 +270,16 @@ def _run_post_setup(post_setup_key: str):
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
_print_info(" Installing tinker-atropos submodule...")
|
||||
import subprocess
|
||||
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin:
|
||||
result = subprocess.run(
|
||||
[uv_bin, "pip", "install", "--python", sys.executable, "-e", str(tinker_dir)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
capture_output=True, text=True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)], capture_output=True, text=True
|
||||
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode == 0:
|
||||
_print_success(" tinker-atropos installed")
|
||||
@@ -315,8 +294,7 @@ def _run_post_setup(post_setup_key: str):
|
||||
|
||||
# ─── Platform / Toolset Helpers ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_enabled_platforms() -> list[str]:
|
||||
def _get_enabled_platforms() -> List[str]:
|
||||
"""Return platform keys that are configured (have tokens or are CLI)."""
|
||||
enabled = ["cli"]
|
||||
if get_env_value("TELEGRAM_BOT_TOKEN"):
|
||||
@@ -330,9 +308,9 @@ def _get_enabled_platforms() -> list[str]:
|
||||
return enabled
|
||||
|
||||
|
||||
def _get_platform_tools(config: dict, platform: str) -> set[str]:
|
||||
def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
"""Resolve which individual toolset names are enabled for a platform."""
|
||||
from toolsets import resolve_toolset
|
||||
from toolsets import resolve_toolset, TOOLSETS
|
||||
|
||||
platform_toolsets = config.get("platform_toolsets", {})
|
||||
toolset_names = platform_toolsets.get(platform)
|
||||
@@ -357,7 +335,7 @@ def _get_platform_tools(config: dict, platform: str) -> set[str]:
|
||||
return enabled_toolsets
|
||||
|
||||
|
||||
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: set[str]):
|
||||
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[str]):
|
||||
"""Save the selected toolset keys for a platform to config."""
|
||||
config.setdefault("platform_toolsets", {})
|
||||
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys)
|
||||
@@ -386,7 +364,6 @@ def _toolset_has_keys(ts_key: str) -> bool:
|
||||
|
||||
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
|
||||
rendering bugs in tmux, iTerm, and other non-standard terminals."""
|
||||
@@ -394,7 +371,6 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
# Curses-based single-select — works in tmux, iTerm, and standard terminals
|
||||
try:
|
||||
import curses
|
||||
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
@@ -410,9 +386,8 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
try:
|
||||
stdscr.addnstr(
|
||||
0, 0, question, max_x - 1, curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)
|
||||
)
|
||||
stdscr.addnstr(0, 0, question, max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
@@ -435,14 +410,14 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord("q")):
|
||||
elif key in (27, ord('q')):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
@@ -456,7 +431,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
for i, c in enumerate(choices):
|
||||
marker = "●" if i == default else "○"
|
||||
style = Colors.GREEN if i == default else ""
|
||||
print(color(f" {marker} {i + 1}. {c}", style) if style else f" {marker} {i + 1}. {c}")
|
||||
print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
|
||||
while True:
|
||||
try:
|
||||
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
|
||||
@@ -470,7 +445,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
return default
|
||||
|
||||
|
||||
def _prompt_toolset_checklist(platform_label: str, enabled: set[str]) -> set[str]:
|
||||
def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
|
||||
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
|
||||
|
||||
labels = []
|
||||
@@ -480,13 +455,15 @@ def _prompt_toolset_checklist(platform_label: str, enabled: set[str]) -> set[str
|
||||
suffix = " [no API key]"
|
||||
labels.append(f"{ts_label} ({ts_desc}){suffix}")
|
||||
|
||||
pre_selected_indices = [i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS) if ts_key in enabled]
|
||||
pre_selected_indices = [
|
||||
i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)
|
||||
if ts_key in enabled
|
||||
]
|
||||
|
||||
# Curses-based multi-select — arrow keys + space to toggle + enter to confirm.
|
||||
# simple_term_menu has rendering bugs in tmux, iTerm, and other terminals.
|
||||
try:
|
||||
import curses
|
||||
|
||||
selected = set(pre_selected_indices)
|
||||
result_holder = [None]
|
||||
|
||||
@@ -506,13 +483,7 @@ def _prompt_toolset_checklist(platform_label: str, enabled: set[str]) -> set[str
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
header = f"Tools for {platform_label} — ↑↓ navigate, SPACE toggle, ENTER confirm"
|
||||
try:
|
||||
stdscr.addnstr(
|
||||
0,
|
||||
0,
|
||||
header,
|
||||
max_x - 1,
|
||||
curses.A_BOLD | curses.color_pair(2) if curses.has_colors() else curses.A_BOLD,
|
||||
)
|
||||
stdscr.addnstr(0, 0, header, max_x - 1, curses.A_BOLD | curses.color_pair(2) if curses.has_colors() else curses.A_BOLD)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
@@ -543,11 +514,11 @@ def _prompt_toolset_checklist(platform_label: str, enabled: set[str]) -> set[str
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
cursor = (cursor - 1) % len(labels)
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
cursor = (cursor + 1) % len(labels)
|
||||
elif key == ord(" "):
|
||||
elif key == ord(' '):
|
||||
if cursor in selected:
|
||||
selected.discard(cursor)
|
||||
else:
|
||||
@@ -555,7 +526,7 @@ def _prompt_toolset_checklist(platform_label: str, enabled: set[str]) -> set[str
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = {CONFIGURABLE_TOOLSETS[i][0] for i in selected}
|
||||
return
|
||||
elif key in (27, ord("q")): # ESC or q
|
||||
elif key in (27, ord('q')): # ESC or q
|
||||
result_holder[0] = enabled
|
||||
return
|
||||
|
||||
@@ -594,10 +565,9 @@ def _prompt_toolset_checklist(platform_label: str, enabled: set[str]) -> set[str
|
||||
|
||||
# ─── Provider-Aware Configuration ────────────────────────────────────────────
|
||||
|
||||
|
||||
def _configure_toolset(ts_key: str, config: dict):
|
||||
"""Configure a toolset - provider selection + API keys.
|
||||
|
||||
|
||||
Uses TOOL_CATEGORIES for provider-aware config, falls back to simple
|
||||
env var prompts for toolsets not in TOOL_CATEGORIES.
|
||||
"""
|
||||
@@ -621,9 +591,7 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
req = cat["requires_python"]
|
||||
if sys.version_info < req:
|
||||
print()
|
||||
_print_error(
|
||||
f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})"
|
||||
)
|
||||
_print_error(f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})")
|
||||
_print_info(" Upgrade Python and reinstall to enable this tool.")
|
||||
return
|
||||
|
||||
@@ -642,7 +610,7 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
# Multiple providers - let user choose
|
||||
print()
|
||||
# Use custom title if provided (e.g. "Select Search Provider")
|
||||
title = cat.get("setup_title", "Choose a provider")
|
||||
title = cat.get("setup_title", f"Choose a provider")
|
||||
print(color(f" --- {icon} {name} - {title} ---", Colors.CYAN))
|
||||
if cat.get("setup_note"):
|
||||
_print_info(f" {cat['setup_note']}")
|
||||
@@ -658,11 +626,7 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
configured = " [active]"
|
||||
elif not env_vars:
|
||||
configured = (
|
||||
" [active]"
|
||||
if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "")
|
||||
else ""
|
||||
)
|
||||
configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else ""
|
||||
else:
|
||||
configured = " [configured]"
|
||||
provider_choices.append(f"{p['name']}{tag}{configured}")
|
||||
@@ -724,9 +688,9 @@ def _configure_provider(provider: dict, config: dict):
|
||||
|
||||
if value:
|
||||
save_env_value(var["key"], value)
|
||||
_print_success(" Saved")
|
||||
_print_success(f" Saved")
|
||||
else:
|
||||
_print_warning(" Skipped")
|
||||
_print_warning(f" Skipped")
|
||||
all_configured = False
|
||||
|
||||
# Run post-setup hooks if needed
|
||||
@@ -757,9 +721,9 @@ def _configure_simple_requirements(ts_key: str):
|
||||
value = _prompt(f" {var}", password=True)
|
||||
if value and value.strip():
|
||||
save_env_value(var, value.strip())
|
||||
_print_success(" Saved")
|
||||
_print_success(f" Saved")
|
||||
else:
|
||||
_print_warning(" Skipped")
|
||||
_print_warning(f" Skipped")
|
||||
|
||||
|
||||
def _reconfigure_tool(config: dict):
|
||||
@@ -863,9 +827,9 @@ def _reconfigure_provider(provider: dict, config: dict):
|
||||
value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val)
|
||||
if value and value.strip():
|
||||
save_env_value(var["key"], value.strip())
|
||||
_print_success(" Updated")
|
||||
_print_success(f" Updated")
|
||||
else:
|
||||
_print_info(" Kept current")
|
||||
_print_info(f" Kept current")
|
||||
|
||||
|
||||
def _reconfigure_simple_requirements(ts_key: str):
|
||||
@@ -887,14 +851,13 @@ def _reconfigure_simple_requirements(ts_key: str):
|
||||
value = _prompt(f" {var} (Enter to keep current)", password=True)
|
||||
if value and value.strip():
|
||||
save_env_value(var, value.strip())
|
||||
_print_success(" Updated")
|
||||
_print_success(f" Updated")
|
||||
else:
|
||||
_print_info(" Kept current")
|
||||
_print_info(f" Kept current")
|
||||
|
||||
|
||||
# ─── Main Entry Point ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
"""Entry point for `hermes tools` and `hermes setup tools`.
|
||||
|
||||
@@ -944,8 +907,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
# TTS (Edge vs OpenAI vs ElevenLabs), etc. are shown even when
|
||||
# a free provider exists.
|
||||
to_configure = [
|
||||
ts_key
|
||||
for ts_key in sorted(new_enabled)
|
||||
ts_key for ts_key in sorted(new_enabled)
|
||||
if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)
|
||||
]
|
||||
|
||||
@@ -1019,7 +981,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
# Configure newly enabled toolsets that need API keys
|
||||
for ts_key in sorted(added):
|
||||
if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key):
|
||||
if (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
|
||||
if not _toolset_has_keys(ts_key):
|
||||
_configure_toolset(ts_key, config)
|
||||
|
||||
|
||||
@@ -7,25 +7,23 @@ Provides options for:
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
def log_info(msg: str):
|
||||
print(f"{color('→', Colors.CYAN)} {msg}")
|
||||
|
||||
|
||||
def log_success(msg: str):
|
||||
print(f"{color('✓', Colors.GREEN)} {msg}")
|
||||
|
||||
|
||||
def log_warn(msg: str):
|
||||
print(f"{color('⚠', Colors.YELLOW)} {msg}")
|
||||
|
||||
|
||||
def log_error(msg: str):
|
||||
print(f"{color('✗', Colors.RED)} {msg}")
|
||||
|
||||
@@ -44,7 +42,7 @@ def find_shell_configs() -> list:
|
||||
"""Find shell configuration files that might have PATH entries."""
|
||||
home = Path.home()
|
||||
configs = []
|
||||
|
||||
|
||||
candidates = [
|
||||
home / ".bashrc",
|
||||
home / ".bash_profile",
|
||||
@@ -52,11 +50,11 @@ def find_shell_configs() -> list:
|
||||
home / ".zshrc",
|
||||
home / ".zprofile",
|
||||
]
|
||||
|
||||
|
||||
for config in candidates:
|
||||
if config.exists():
|
||||
configs.append(config)
|
||||
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
@@ -64,45 +62,45 @@ def remove_path_from_shell_configs():
|
||||
"""Remove Hermes PATH entries from shell configuration files."""
|
||||
configs = find_shell_configs()
|
||||
removed_from = []
|
||||
|
||||
|
||||
for config_path in configs:
|
||||
try:
|
||||
content = config_path.read_text()
|
||||
original_content = content
|
||||
|
||||
|
||||
# Remove lines containing hermes-agent or hermes PATH entries
|
||||
new_lines = []
|
||||
skip_next = False
|
||||
|
||||
for line in content.split("\n"):
|
||||
|
||||
for line in content.split('\n'):
|
||||
# Skip the "# Hermes Agent" comment and following line
|
||||
if "# Hermes Agent" in line or "# hermes-agent" in line:
|
||||
if '# Hermes Agent' in line or '# hermes-agent' in line:
|
||||
skip_next = True
|
||||
continue
|
||||
if skip_next and ("hermes" in line.lower() and "PATH" in line):
|
||||
if skip_next and ('hermes' in line.lower() and 'PATH' in line):
|
||||
skip_next = False
|
||||
continue
|
||||
skip_next = False
|
||||
|
||||
|
||||
# Remove any PATH line containing hermes
|
||||
if "hermes" in line.lower() and ("PATH=" in line or "path=" in line.lower()):
|
||||
if 'hermes' in line.lower() and ('PATH=' in line or 'path=' in line.lower()):
|
||||
continue
|
||||
|
||||
|
||||
new_lines.append(line)
|
||||
|
||||
new_content = "\n".join(new_lines)
|
||||
|
||||
|
||||
new_content = '\n'.join(new_lines)
|
||||
|
||||
# Clean up multiple blank lines
|
||||
while "\n\n\n" in new_content:
|
||||
new_content = new_content.replace("\n\n\n", "\n\n")
|
||||
|
||||
while '\n\n\n' in new_content:
|
||||
new_content = new_content.replace('\n\n\n', '\n\n')
|
||||
|
||||
if new_content != original_content:
|
||||
config_path.write_text(new_content)
|
||||
removed_from.append(config_path)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
log_warn(f"Could not update {config_path}: {e}")
|
||||
|
||||
|
||||
return removed_from
|
||||
|
||||
|
||||
@@ -112,49 +110,61 @@ def remove_wrapper_script():
|
||||
Path.home() / ".local" / "bin" / "hermes",
|
||||
Path("/usr/local/bin/hermes"),
|
||||
]
|
||||
|
||||
|
||||
removed = []
|
||||
for wrapper in wrapper_paths:
|
||||
if wrapper.exists():
|
||||
try:
|
||||
# Check if it's our wrapper (contains hermes_cli reference)
|
||||
content = wrapper.read_text()
|
||||
if "hermes_cli" in content or "hermes-agent" in content:
|
||||
if 'hermes_cli' in content or 'hermes-agent' in content:
|
||||
wrapper.unlink()
|
||||
removed.append(wrapper)
|
||||
except Exception as e:
|
||||
log_warn(f"Could not remove {wrapper}: {e}")
|
||||
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def uninstall_gateway_service():
|
||||
"""Stop and uninstall the gateway service if running."""
|
||||
import platform
|
||||
|
||||
|
||||
if platform.system() != "Linux":
|
||||
return False
|
||||
|
||||
|
||||
service_file = Path.home() / ".config" / "systemd" / "user" / "hermes-gateway.service"
|
||||
|
||||
|
||||
if not service_file.exists():
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Stop the service
|
||||
subprocess.run(["systemctl", "--user", "stop", "hermes-gateway"], capture_output=True, check=False)
|
||||
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "stop", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
# Disable the service
|
||||
subprocess.run(["systemctl", "--user", "disable", "hermes-gateway"], capture_output=True, check=False)
|
||||
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "disable", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
# Remove service file
|
||||
service_file.unlink()
|
||||
|
||||
|
||||
# Reload systemd
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], capture_output=True, check=False)
|
||||
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
log_warn(f"Could not fully remove gateway service: {e}")
|
||||
return False
|
||||
@@ -163,20 +173,20 @@ def uninstall_gateway_service():
|
||||
def run_uninstall(args):
|
||||
"""
|
||||
Run the uninstall process.
|
||||
|
||||
|
||||
Options:
|
||||
- Full uninstall: removes code + ~/.hermes/ (configs, data, logs)
|
||||
- Keep data: removes code but keeps ~/.hermes/ for future reinstall
|
||||
"""
|
||||
project_root = get_project_root()
|
||||
hermes_home = get_hermes_home()
|
||||
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.MAGENTA, Colors.BOLD))
|
||||
print(color("│ ⚕ Hermes Agent Uninstaller │", Colors.MAGENTA, Colors.BOLD))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.MAGENTA, Colors.BOLD))
|
||||
print()
|
||||
|
||||
|
||||
# Show what will be affected
|
||||
print(color("Current Installation:", Colors.CYAN, Colors.BOLD))
|
||||
print(f" Code: {project_root}")
|
||||
@@ -184,7 +194,7 @@ def run_uninstall(args):
|
||||
print(f" Secrets: {hermes_home / '.env'}")
|
||||
print(f" Data: {hermes_home / 'cron/'}, {hermes_home / 'sessions/'}, {hermes_home / 'logs/'}")
|
||||
print()
|
||||
|
||||
|
||||
# Ask for confirmation
|
||||
print(color("Uninstall Options:", Colors.YELLOW, Colors.BOLD))
|
||||
print()
|
||||
@@ -196,21 +206,21 @@ def run_uninstall(args):
|
||||
print()
|
||||
print(" 3) " + color("Cancel", Colors.CYAN) + " - Don't uninstall")
|
||||
print()
|
||||
|
||||
|
||||
try:
|
||||
choice = input(color("Select option [1/2/3]: ", Colors.BOLD)).strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
print("Cancelled.")
|
||||
return
|
||||
|
||||
|
||||
if choice == "3" or choice.lower() in ("c", "cancel", "q", "quit", "n", "no"):
|
||||
print()
|
||||
print("Uninstall cancelled.")
|
||||
return
|
||||
|
||||
full_uninstall = choice == "2"
|
||||
|
||||
|
||||
full_uninstall = (choice == "2")
|
||||
|
||||
# Final confirmation
|
||||
print()
|
||||
if full_uninstall:
|
||||
@@ -218,7 +228,7 @@ def run_uninstall(args):
|
||||
print(color(" Including: configs, API keys, sessions, scheduled jobs, logs", Colors.RED))
|
||||
else:
|
||||
print("This will remove the Hermes code but keep your configuration and data.")
|
||||
|
||||
|
||||
print()
|
||||
try:
|
||||
confirm = input(f"Type '{color('yes', Colors.YELLOW)}' to confirm: ").strip().lower()
|
||||
@@ -226,23 +236,23 @@ def run_uninstall(args):
|
||||
print()
|
||||
print("Cancelled.")
|
||||
return
|
||||
|
||||
|
||||
if confirm != "yes":
|
||||
print()
|
||||
print("Uninstall cancelled.")
|
||||
return
|
||||
|
||||
|
||||
print()
|
||||
print(color("Uninstalling...", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
|
||||
|
||||
# 1. Stop and uninstall gateway service
|
||||
log_info("Checking for gateway service...")
|
||||
if uninstall_gateway_service():
|
||||
log_success("Gateway service stopped and removed")
|
||||
else:
|
||||
log_info("No gateway service found")
|
||||
|
||||
|
||||
# 2. Remove PATH entries from shell configs
|
||||
log_info("Removing PATH entries from shell configs...")
|
||||
removed_configs = remove_path_from_shell_configs()
|
||||
@@ -251,7 +261,7 @@ def run_uninstall(args):
|
||||
log_success(f"Updated {config}")
|
||||
else:
|
||||
log_info("No PATH entries found to remove")
|
||||
|
||||
|
||||
# 3. Remove wrapper script
|
||||
log_info("Removing hermes command...")
|
||||
removed_wrappers = remove_wrapper_script()
|
||||
@@ -260,10 +270,10 @@ def run_uninstall(args):
|
||||
log_success(f"Removed {wrapper}")
|
||||
else:
|
||||
log_info("No wrapper script found")
|
||||
|
||||
|
||||
# 4. Remove installation directory (code)
|
||||
log_info("Removing installation directory...")
|
||||
|
||||
log_info(f"Removing installation directory...")
|
||||
|
||||
# Check if we're running from within the install dir
|
||||
# We need to be careful here
|
||||
try:
|
||||
@@ -279,7 +289,7 @@ def run_uninstall(args):
|
||||
except Exception as e:
|
||||
log_warn(f"Could not fully remove {project_root}: {e}")
|
||||
log_info("You may need to manually remove it")
|
||||
|
||||
|
||||
# 5. Optionally remove ~/.hermes/ data directory
|
||||
if full_uninstall:
|
||||
log_info("Removing configuration and data...")
|
||||
@@ -292,27 +302,22 @@ def run_uninstall(args):
|
||||
log_info("You may need to manually remove it")
|
||||
else:
|
||||
log_info(f"Keeping configuration and data in {hermes_home}")
|
||||
|
||||
|
||||
# Done
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.GREEN, Colors.BOLD))
|
||||
print(color("│ ✓ Uninstall Complete! │", Colors.GREEN, Colors.BOLD))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.GREEN, Colors.BOLD))
|
||||
print()
|
||||
|
||||
|
||||
if not full_uninstall:
|
||||
print(color("Your configuration and data have been preserved:", Colors.CYAN))
|
||||
print(f" {hermes_home}/")
|
||||
print()
|
||||
print("To reinstall later with your existing settings:")
|
||||
print(
|
||||
color(
|
||||
" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash",
|
||||
Colors.DIM,
|
||||
)
|
||||
)
|
||||
print(color(" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash", Colors.DIM))
|
||||
print()
|
||||
|
||||
|
||||
print(color("Reload your shell to complete the process:", Colors.YELLOW))
|
||||
print(" source ~/.bashrc # or ~/.zshrc")
|
||||
print()
|
||||
|
||||
@@ -19,7 +19,8 @@ import os
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
|
||||
DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"
|
||||
|
||||
@@ -155,7 +156,8 @@ class SessionDB:
|
||||
# since the title column is guaranteed to exist at this point)
|
||||
try:
|
||||
cursor.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique ON sessions(title) WHERE title IS NOT NULL"
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
|
||||
"ON sessions(title) WHERE title IS NOT NULL"
|
||||
)
|
||||
except sqlite3.OperationalError:
|
||||
pass # Index already exists
|
||||
@@ -183,7 +185,7 @@ class SessionDB:
|
||||
session_id: str,
|
||||
source: str,
|
||||
model: str = None,
|
||||
model_config: dict[str, Any] = None,
|
||||
model_config: Dict[str, Any] = None,
|
||||
system_prompt: str = None,
|
||||
user_id: str = None,
|
||||
parent_session_id: str = None,
|
||||
@@ -223,7 +225,9 @@ class SessionDB:
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def update_token_counts(self, session_id: str, input_tokens: int = 0, output_tokens: int = 0) -> None:
|
||||
def update_token_counts(
|
||||
self, session_id: str, input_tokens: int = 0, output_tokens: int = 0
|
||||
) -> None:
|
||||
"""Increment token counters on a session."""
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
@@ -234,9 +238,11 @@ class SessionDB:
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_session(self, session_id: str) -> dict[str, Any] | None:
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a session by ID."""
|
||||
cursor = self._conn.execute("SELECT * FROM sessions WHERE id = ?", (session_id,))
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
@@ -244,7 +250,7 @@ class SessionDB:
|
||||
MAX_TITLE_LENGTH = 100
|
||||
|
||||
@staticmethod
|
||||
def sanitize_title(title: str | None) -> str | None:
|
||||
def sanitize_title(title: Optional[str]) -> Optional[str]:
|
||||
"""Validate and sanitize a session title.
|
||||
|
||||
- Strips leading/trailing whitespace
|
||||
@@ -265,26 +271,27 @@ class SessionDB:
|
||||
# Remove ASCII control characters (0x00-0x1F, 0x7F) but keep
|
||||
# whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be
|
||||
# normalized to spaces by the whitespace collapsing step below
|
||||
cleaned = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", title)
|
||||
cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title)
|
||||
|
||||
# Remove problematic Unicode control characters:
|
||||
# - Zero-width chars (U+200B-U+200F, U+FEFF)
|
||||
# - Directional overrides (U+202A-U+202E, U+2066-U+2069)
|
||||
# - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB)
|
||||
cleaned = re.sub(
|
||||
r"[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]",
|
||||
"",
|
||||
cleaned,
|
||||
r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]',
|
||||
'', cleaned,
|
||||
)
|
||||
|
||||
# Collapse internal whitespace runs and strip
|
||||
cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
||||
cleaned = re.sub(r'\s+', ' ', cleaned).strip()
|
||||
|
||||
if not cleaned:
|
||||
return None
|
||||
|
||||
if len(cleaned) > SessionDB.MAX_TITLE_LENGTH:
|
||||
raise ValueError(f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})")
|
||||
raise ValueError(
|
||||
f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})"
|
||||
)
|
||||
|
||||
return cleaned
|
||||
|
||||
@@ -305,7 +312,9 @@ class SessionDB:
|
||||
)
|
||||
conflict = cursor.fetchone()
|
||||
if conflict:
|
||||
raise ValueError(f"Title '{title}' is already in use by session {conflict['id']}")
|
||||
raise ValueError(
|
||||
f"Title '{title}' is already in use by session {conflict['id']}"
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"UPDATE sessions SET title = ? WHERE id = ?",
|
||||
(title, session_id),
|
||||
@@ -313,19 +322,23 @@ class SessionDB:
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_session_title(self, session_id: str) -> str | None:
|
||||
def get_session_title(self, session_id: str) -> Optional[str]:
|
||||
"""Get the title for a session, or None."""
|
||||
cursor = self._conn.execute("SELECT title FROM sessions WHERE id = ?", (session_id,))
|
||||
cursor = self._conn.execute(
|
||||
"SELECT title FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return row["title"] if row else None
|
||||
|
||||
def get_session_by_title(self, title: str) -> dict[str, Any] | None:
|
||||
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
|
||||
"""Look up a session by exact title. Returns session dict or None."""
|
||||
cursor = self._conn.execute("SELECT * FROM sessions WHERE title = ?", (title,))
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE title = ?", (title,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def resolve_session_by_title(self, title: str) -> str | None:
|
||||
def resolve_session_by_title(self, title: str) -> Optional[str]:
|
||||
"""Resolve a title to a session ID, preferring the latest in a lineage.
|
||||
|
||||
If the exact title exists, returns that session's ID.
|
||||
@@ -340,7 +353,8 @@ class SessionDB:
|
||||
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
|
||||
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id, title, started_at FROM sessions WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
||||
"SELECT id, title, started_at FROM sessions "
|
||||
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
||||
(f"{escaped} #%",),
|
||||
)
|
||||
numbered = cursor.fetchall()
|
||||
@@ -359,9 +373,8 @@ class SessionDB:
|
||||
the highest existing number and increments.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Strip existing #N suffix to find the true base
|
||||
match = re.match(r"^(.*?) #(\d+)$", base_title)
|
||||
match = re.match(r'^(.*?) #(\d+)$', base_title)
|
||||
if match:
|
||||
base = match.group(1)
|
||||
else:
|
||||
@@ -382,7 +395,7 @@ class SessionDB:
|
||||
# Find the highest number
|
||||
max_num = 1 # The unnumbered original counts as #1
|
||||
for t in existing:
|
||||
m = re.match(r"^.* #(\d+)$", t)
|
||||
m = re.match(r'^.* #(\d+)$', t)
|
||||
if m:
|
||||
max_num = max(max_num, int(m.group(1)))
|
||||
|
||||
@@ -393,7 +406,7 @@ class SessionDB:
|
||||
source: str = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List sessions with preview (first user message) and last active timestamp.
|
||||
|
||||
Returns dicts with keys: id, source, model, title, started_at, ended_at,
|
||||
@@ -493,7 +506,7 @@ class SessionDB:
|
||||
self._conn.commit()
|
||||
return msg_id
|
||||
|
||||
def get_messages(self, session_id: str) -> list[dict[str, Any]]:
|
||||
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""Load all messages for a session, ordered by timestamp."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
||||
@@ -511,7 +524,7 @@ class SessionDB:
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
|
||||
def get_messages_as_conversation(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load messages in the OpenAI conversation format (role + content dicts).
|
||||
Used by the gateway to restore conversation history.
|
||||
@@ -543,11 +556,11 @@ class SessionDB:
|
||||
def search_messages(
|
||||
self,
|
||||
query: str,
|
||||
source_filter: list[str] = None,
|
||||
role_filter: list[str] = None,
|
||||
source_filter: List[str] = None,
|
||||
role_filter: List[str] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Full-text search across session messages using FTS5.
|
||||
|
||||
@@ -615,7 +628,8 @@ class SessionDB:
|
||||
(match["session_id"], match["id"], match["id"]),
|
||||
)
|
||||
context_msgs = [
|
||||
{"role": r["role"], "content": (r["content"] or "")[:200]} for r in ctx_cursor.fetchall()
|
||||
{"role": r["role"], "content": (r["content"] or "")[:200]}
|
||||
for r in ctx_cursor.fetchall()
|
||||
]
|
||||
match["context"] = context_msgs
|
||||
except Exception:
|
||||
@@ -631,7 +645,7 @@ class SessionDB:
|
||||
source: str = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List sessions, optionally filtered by source."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
@@ -652,7 +666,9 @@ class SessionDB:
|
||||
def session_count(self, source: str = None) -> int:
|
||||
"""Count sessions, optionally filtered by source."""
|
||||
if source:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions WHERE source = ?", (source,))
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE source = ?", (source,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
return cursor.fetchone()[0]
|
||||
@@ -660,7 +676,9 @@ class SessionDB:
|
||||
def message_count(self, session_id: str = None) -> int:
|
||||
"""Count messages, optionally for a specific session."""
|
||||
if session_id:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,))
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM messages")
|
||||
return cursor.fetchone()[0]
|
||||
@@ -669,7 +687,7 @@ class SessionDB:
|
||||
# Export and cleanup
|
||||
# =========================================================================
|
||||
|
||||
def export_session(self, session_id: str) -> dict[str, Any] | None:
|
||||
def export_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Export a single session with all its messages as a dict."""
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
@@ -677,7 +695,7 @@ class SessionDB:
|
||||
messages = self.get_messages(session_id)
|
||||
return {**session, "messages": messages}
|
||||
|
||||
def export_all(self, source: str = None) -> list[dict[str, Any]]:
|
||||
def export_all(self, source: str = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Export all sessions (with messages) as a list of dicts.
|
||||
Suitable for writing to a JSONL file for backup/analysis.
|
||||
@@ -691,7 +709,9 @@ class SessionDB:
|
||||
|
||||
def clear_messages(self, session_id: str) -> None:
|
||||
"""Delete all messages for a session and reset its counters."""
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
self._conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||
(session_id,),
|
||||
@@ -700,7 +720,9 @@ class SessionDB:
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session and all its messages. Returns True if found."""
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,))
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
@@ -714,7 +736,6 @@ class SessionDB:
|
||||
Only prunes ended sessions (not active ones).
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
cutoff = _time.time() - (older_than_days * 86400)
|
||||
|
||||
if source:
|
||||
|
||||
@@ -20,10 +20,11 @@ Public API (signatures preserved from the original 2,400-line version):
|
||||
check_tool_availability(quiet) -> tuple
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
from tools.registry import registry
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
@@ -35,7 +36,6 @@ logger = logging.getLogger(__name__)
|
||||
# Async Bridging (single source of truth -- used by registry.dispatch too)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync context.
|
||||
|
||||
@@ -56,7 +56,6 @@ def _run_async(coro):
|
||||
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, coro)
|
||||
return future.result(timeout=300)
|
||||
@@ -67,7 +66,6 @@ def _run_async(coro):
|
||||
# Tool Discovery (importing each module triggers its registry.register calls)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _discover_tools():
|
||||
"""Import all tool modules to trigger their registry.register() calls.
|
||||
|
||||
@@ -99,7 +97,6 @@ def _discover_tools():
|
||||
"tools.homeassistant_tool",
|
||||
]
|
||||
import importlib
|
||||
|
||||
for mod_name in _modules:
|
||||
try:
|
||||
importlib.import_module(mod_name)
|
||||
@@ -112,7 +109,6 @@ _discover_tools()
|
||||
# MCP tool discovery (external MCP servers from config)
|
||||
try:
|
||||
from tools.mcp_tool import discover_mcp_tools
|
||||
|
||||
discover_mcp_tools()
|
||||
except Exception as e:
|
||||
logger.debug("MCP tool discovery failed: %s", e)
|
||||
@@ -122,13 +118,13 @@ except Exception as e:
|
||||
# Backward-compat constants (built once after discovery)
|
||||
# =============================================================================
|
||||
|
||||
TOOL_TO_TOOLSET_MAP: dict[str, str] = registry.get_tool_to_toolset_map()
|
||||
TOOL_TO_TOOLSET_MAP: Dict[str, str] = registry.get_tool_to_toolset_map()
|
||||
|
||||
TOOLSET_REQUIREMENTS: dict[str, dict] = registry.get_toolset_requirements()
|
||||
TOOLSET_REQUIREMENTS: Dict[str, dict] = registry.get_toolset_requirements()
|
||||
|
||||
# Resolved tool names from the last get_tool_definitions() call.
|
||||
# Used by code_execution_tool to know which tools are available in this session.
|
||||
_last_resolved_tool_names: list[str] = []
|
||||
_last_resolved_tool_names: List[str] = []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -143,29 +139,18 @@ _LEGACY_TOOLSET_MAP = {
|
||||
"image_tools": ["image_generate"],
|
||||
"skills_tools": ["skills_list", "skill_view", "skill_manage"],
|
||||
"browser_tools": [
|
||||
"browser_navigate",
|
||||
"browser_snapshot",
|
||||
"browser_click",
|
||||
"browser_type",
|
||||
"browser_scroll",
|
||||
"browser_back",
|
||||
"browser_press",
|
||||
"browser_close",
|
||||
"browser_get_images",
|
||||
"browser_vision",
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
],
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"rl_tools": [
|
||||
"rl_list_environments",
|
||||
"rl_select_environment",
|
||||
"rl_get_current_config",
|
||||
"rl_edit_config",
|
||||
"rl_start_training",
|
||||
"rl_check_status",
|
||||
"rl_stop_training",
|
||||
"rl_get_results",
|
||||
"rl_list_runs",
|
||||
"rl_test_inference",
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
],
|
||||
"file_tools": ["read_file", "write_file", "patch", "search_files"],
|
||||
"tts_tools": ["text_to_speech"],
|
||||
@@ -176,12 +161,11 @@ _LEGACY_TOOLSET_MAP = {
|
||||
# get_tool_definitions (the main schema provider)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_tool_definitions(
|
||||
enabled_toolsets: list[str] = None,
|
||||
disabled_toolsets: list[str] = None,
|
||||
enabled_toolsets: List[str] = None,
|
||||
disabled_toolsets: List[str] = None,
|
||||
quiet_mode: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for model API calls with toolset-based filtering.
|
||||
|
||||
@@ -216,7 +200,6 @@ def get_tool_definitions(
|
||||
|
||||
elif disabled_toolsets:
|
||||
from toolsets import get_all_toolsets
|
||||
|
||||
for ts_name in get_all_toolsets():
|
||||
tools_to_include.update(resolve_toolset(ts_name))
|
||||
|
||||
@@ -236,7 +219,6 @@ def get_tool_definitions(
|
||||
print(f"⚠️ Unknown toolset: {toolset_name}")
|
||||
else:
|
||||
from toolsets import get_all_toolsets
|
||||
|
||||
for ts_name in get_all_toolsets():
|
||||
tools_to_include.update(resolve_toolset(ts_name))
|
||||
|
||||
@@ -248,7 +230,6 @@ def get_tool_definitions(
|
||||
# execute_code" even when the user disabled the web toolset (#560-discord).
|
||||
if "execute_code" in tools_to_include:
|
||||
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
|
||||
|
||||
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
|
||||
dynamic_schema = build_execute_code_schema(sandbox_enabled)
|
||||
for i, td in enumerate(filtered_tools):
|
||||
@@ -282,9 +263,9 @@ _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: dict[str, Any],
|
||||
task_id: str | None = None,
|
||||
user_task: str | None = None,
|
||||
function_args: Dict[str, Any],
|
||||
task_id: Optional[str] = None,
|
||||
user_task: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Main function call dispatcher that routes calls to the tool registry.
|
||||
@@ -304,15 +285,13 @@ def handle_function_call(
|
||||
|
||||
if function_name == "execute_code":
|
||||
return registry.dispatch(
|
||||
function_name,
|
||||
function_args,
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
enabled_tools=_last_resolved_tool_names,
|
||||
)
|
||||
|
||||
return registry.dispatch(
|
||||
function_name,
|
||||
function_args,
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
user_task=user_task,
|
||||
)
|
||||
@@ -327,27 +306,26 @@ def handle_function_call(
|
||||
# Backward-compat wrapper functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_tool_names() -> list[str]:
|
||||
def get_all_tool_names() -> List[str]:
|
||||
"""Return all registered tool names."""
|
||||
return registry.get_all_tool_names()
|
||||
|
||||
|
||||
def get_toolset_for_tool(tool_name: str) -> str | None:
|
||||
def get_toolset_for_tool(tool_name: str) -> Optional[str]:
|
||||
"""Return the toolset a tool belongs to."""
|
||||
return registry.get_toolset_for_tool(tool_name)
|
||||
|
||||
|
||||
def get_available_toolsets() -> dict[str, dict]:
|
||||
def get_available_toolsets() -> Dict[str, dict]:
|
||||
"""Return toolset availability info for UI display."""
|
||||
return registry.get_available_toolsets()
|
||||
|
||||
|
||||
def check_toolset_requirements() -> dict[str, bool]:
|
||||
def check_toolset_requirements() -> Dict[str, bool]:
|
||||
"""Return {toolset: available_bool} for every registered toolset."""
|
||||
return registry.check_toolset_requirements()
|
||||
|
||||
|
||||
def check_tool_availability(quiet: bool = False) -> tuple[list[str], list[dict]]:
|
||||
def check_tool_availability(quiet: bool = False) -> Tuple[List[str], List[dict]]:
|
||||
"""Return (available_toolsets, unavailable_info)."""
|
||||
return registry.check_tool_availability(quiet=quiet)
|
||||
|
||||
@@ -40,7 +40,7 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
modal = ["swe-rex[modal]>=1.4.0"]
|
||||
daytona = ["daytona>=0.148.0"]
|
||||
dev = ["pytest", "pytest-asyncio", "mcp>=1.2.0", "ruff", "pre-commit", "watchfiles"]
|
||||
dev = ["pytest", "pytest-asyncio"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
cron = ["croniter"]
|
||||
slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
@@ -76,46 +76,6 @@ py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajector
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron", "honcho_integration"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "I", "UP", "B", "SIM"]
|
||||
ignore = [
|
||||
"E402", # late imports — intentional throughout codebase
|
||||
"E501", # line too long — handled by formatter where it can
|
||||
"E731", # lambda assignments — used in registry pattern
|
||||
"E741", # ambiguous variable name — existing patterns
|
||||
"F811", # redefined unused — intentional overrides
|
||||
"F841", # unused variable — cleanup separately
|
||||
"B007", # unused loop variable — cleanup separately
|
||||
"B904", # raise from — too noisy to gate on
|
||||
"B905", # zip strict — cleanup separately
|
||||
"B027", # empty method without abstract decorator
|
||||
"SIM102", # collapsible if — readability preference
|
||||
"SIM103", # needless bool — readability preference
|
||||
"SIM105", # suppressible exception — existing pattern
|
||||
"SIM108", # ternary — readability preference
|
||||
"SIM110", # reimplemented builtin
|
||||
"SIM112", # uncapitalized env var
|
||||
"SIM115", # open file with context handler
|
||||
"SIM117", # multiple with statements
|
||||
"SIM118", # in-dict-keys — cleanup separately
|
||||
"SIM212", # if-expr twisted arms
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"batch_runner.py" = ["F821"]
|
||||
"tools/patch_parser.py" = ["F821"]
|
||||
"gateway/run.py" = ["F821"]
|
||||
"gateway/channel_directory.py" = ["F401"]
|
||||
"hermes_cli/doctor.py" = ["F401"]
|
||||
"tools/image_generation_tool.py" = ["F401"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["tools", "hermes_cli", "gateway", "agent", "cron"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
markers = [
|
||||
|
||||
1550
run_agent.py
1550
run_agent.py
File diff suppressed because it is too large
Load Diff
161
skills/gaming/pokemon-player/SKILL.md
Normal file
161
skills/gaming/pokemon-player/SKILL.md
Normal file
@@ -0,0 +1,161 @@
|
||||
---
|
||||
name: pokemon-player
|
||||
description: Play Pokémon games autonomously via headless emulation. Starts a game server, reads structured game state from RAM, makes strategic decisions, and sends button inputs — all from the terminal.
|
||||
tags: [gaming, pokemon, emulator, pyboy, gameplay, gameboy]
|
||||
---
|
||||
# Pokémon Player
|
||||
|
||||
Play Pokémon games via headless emulation using the `pokemon-agent` package.
|
||||
|
||||
## When to Use
|
||||
- User says "play pokemon", "start pokemon", "pokemon game"
|
||||
- User asks about Pokemon Red, Blue, Yellow, FireRed, etc.
|
||||
- User wants to watch an AI play Pokemon
|
||||
- User references a ROM file (.gb, .gbc, .gba)
|
||||
|
||||
## First-Time Setup
|
||||
|
||||
### 1. Install the package
|
||||
```bash
|
||||
pip install pokemon-agent[dashboard] pyboy
|
||||
```
|
||||
|
||||
### 2. Get the ROM
|
||||
Ask the user for their ROM file path. Do NOT attempt to download ROMs.
|
||||
|
||||
### 3. Start the game server
|
||||
```bash
|
||||
pokemon-agent serve --rom <ROM_PATH> --port 8765 &
|
||||
```
|
||||
Wait about 3 seconds for the server to start, then verify that it responds:
|
||||
```bash
|
||||
curl -s http://localhost:8765/health
|
||||
```
|
||||
|
||||
## The Gameplay Loop
|
||||
|
||||
### Step 1: OBSERVE
|
||||
```bash
|
||||
curl -s http://localhost:8765/state
|
||||
```
|
||||
|
||||
### Step 2: ORIENT
|
||||
- Dialog active → advance text
|
||||
- In battle → fight
|
||||
- Party hurt → heal
|
||||
- Near objective → navigate
|
||||
|
||||
### Step 3: DECIDE
|
||||
Priority order:
|
||||
1. If dialog active → a_until_dialog_end
|
||||
2. If in battle → choose best move
|
||||
3. If any Pokémon is below 20% HP → heal at a Pokémon Center
|
||||
4. If near story objective → navigate to it
|
||||
5. If underleveled → train in grass
|
||||
6. Otherwise → explore
|
||||
|
||||
### Step 4: ACT
|
||||
```bash
|
||||
curl -s -X POST http://localhost:8765/action \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"actions": ["walk_up", "walk_up", "press_a"]}'
|
||||
```
|
||||
|
||||
Action reference:
|
||||
- press_a — confirm, talk, select
|
||||
- press_b — cancel, close menu
|
||||
- press_start — open game menu
|
||||
- walk_up/down/left/right — move one tile
|
||||
- a_until_dialog_end — advance all dialog
|
||||
- wait_60 — wait ~1 second
|
||||
|
||||
### Step 5: VERIFY
|
||||
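Check `state_after` in the response to confirm the action had the intended effect. If the state has not changed for 3 or more turns, try the following in order: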
|
||||
1. Press B several times
|
||||
2. Try different directions
|
||||
3. Take a screenshot and inspect it with vision_analyze
|
||||
4. Load last save if truly stuck
|
||||
|
||||
### Step 6: RECORD
|
||||
```
|
||||
memory add: PKM:OBJECTIVE: Heading to Pewter City to challenge Brock
|
||||
memory add: PKM:PROGRESS: Got Squirtle, Got Pokedex, → Pewter City
|
||||
```
|
||||
|
||||
### Step 7: SAVE
|
||||
Save every 20-30 turns and ALWAYS before gym battles:
|
||||
```bash
|
||||
curl -s -X POST http://localhost:8765/save \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"name": "before_brock"}'
|
||||
```
|
||||
|
||||
## Battle Strategy
|
||||
|
||||
### Decision Tree
|
||||
1. Want to catch? → Weaken then throw Poké Ball
|
||||
2. Wild Pokémon you don't need? → RUN
|
||||
3. Type advantage? → Use super-effective move
|
||||
4. No advantage? → Use strongest STAB move
|
||||
5. Low HP? → Switch or use Potion
|
||||
|
||||
### Type Chart
|
||||
- Water beats Fire, Ground, Rock
|
||||
- Fire beats Grass, Bug, Ice
|
||||
- Grass beats Water, Ground, Rock
|
||||
- Electric beats Water, Flying
|
||||
- Ground beats Fire, Electric, Rock, Poison
|
||||
- Psychic beats Fighting, Poison (dominant in Gen 1!)
|
||||
|
||||
### Gen 1 Quirks
|
||||
- Special stat is both offense AND defense for special moves
|
||||
- Psychic type is overpowered (Ghost moves are bugged and have no effect on Psychic types, so its intended weakness does not work)
|
||||
- Critical hits based on Speed stat
|
||||
- Wrap/Bind prevent opponent from acting
|
||||
|
||||
## Memory Conventions
|
||||
| Prefix | Purpose | Example |
|
||||
|--------|---------|---------|
|
||||
| PKM:OBJECTIVE | Current goal | Defeat Brock in Pewter City |
|
||||
| PKM:MAP | Navigation knowledge | Viridian Forest: go north |
|
||||
| PKM:STRATEGY | Battle/team plans | Need Grass type before Misty |
|
||||
| PKM:PROGRESS | Milestone tracker | ✓ Boulder Badge → Cascade Badge |
|
||||
| PKM:STUCK | Stuck situations | Got stuck in Cerulean Cave |
|
||||
| PKM:TEAM | Team notes | Squirtle is Water/Ice coverage |
|
||||
|
||||
## Progression Milestones
|
||||
- ☐ Choose starter
|
||||
- ☐ Deliver Oak's Parcel → receive Pokédex
|
||||
- ☐ Boulder Badge — Brock (Rock) → use Water/Grass
|
||||
- ☐ Cascade Badge — Misty (Water) → use Grass/Electric
|
||||
- ☐ Thunder Badge — Lt. Surge (Electric) → use Ground
|
||||
- ☐ Rainbow Badge — Erika (Grass) → use Fire/Ice/Flying
|
||||
- ☐ Soul Badge — Koga (Poison) → use Ground/Psychic
|
||||
- ☐ Marsh Badge — Sabrina (Psychic)
|
||||
- ☐ Volcano Badge — Blaine (Fire) → use Water/Ground
|
||||
- ☐ Earth Badge — Giovanni (Ground) → use Water/Grass/Ice
|
||||
- ☐ Elite Four → Champion!
|
||||
|
||||
## Stopping Play
|
||||
1. Save the game:
|
||||
```bash
|
||||
curl -s -X POST http://localhost:8765/save \
|
||||
-H "Content-Type: application/json" -d '{"name": "session_end"}'
|
||||
```
|
||||
2. Update memory with progress
|
||||
3. Tell user: "Game saved! Say 'play pokemon' to resume."
|
||||
4. Kill the background server process
|
||||
|
||||
## Dashboard
|
||||
If `pokemon-agent[dashboard]` is installed, open:
|
||||
http://localhost:8765/dashboard
|
||||
|
||||
Live features: game screen, AI reasoning stream, team status, action log.
|
||||
|
||||
## Pitfalls
|
||||
- NEVER download or provide ROM files — always ask the user
|
||||
- Don't send more than 15 actions per /action call
|
||||
- Always wait for dialog to clear before moving
|
||||
- Save BEFORE gym battles
|
||||
- Take screenshots sparingly — they cost vision tokens
|
||||
- Verify server is running with /health before any commands
|
||||
@@ -176,18 +176,14 @@ class TestVisionClientFallback:
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
|
||||
def test_vision_auto_falls_back_to_custom_endpoint(self, monkeypatch):
|
||||
"""Custom endpoint is used as fallback in vision auto mode.
|
||||
|
||||
Many local models (Qwen-VL, LLaVA, etc.) support vision.
|
||||
When no OpenRouter/Nous/Codex is available, try the custom endpoint.
|
||||
"""
|
||||
def test_vision_auto_skips_custom_endpoint(self, monkeypatch):
|
||||
"""Custom endpoint is skipped in vision auto mode."""
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:1234/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is not None # Custom endpoint picked up as fallback
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_uses_openrouter_when_available(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
@@ -1,532 +0,0 @@
|
||||
"""
|
||||
Tests for Slack platform adapter.
|
||||
|
||||
Covers: app_mention handler, send_document, send_video,
|
||||
incoming document handling, message routing.
|
||||
|
||||
Note: slack-bolt may not be installed in the test environment.
|
||||
We mock the slack modules at import time to avoid collection errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock the slack-bolt package if it's not installed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_slack_mock():
|
||||
"""Install mock slack modules so SlackAdapter can be imported."""
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return # Real library installed
|
||||
|
||||
slack_bolt = MagicMock()
|
||||
slack_bolt.async_app.AsyncApp = MagicMock
|
||||
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
|
||||
|
||||
slack_sdk = MagicMock()
|
||||
slack_sdk.web.async_client.AsyncWebClient = MagicMock
|
||||
|
||||
for name, mod in [
|
||||
("slack_bolt", slack_bolt),
|
||||
("slack_bolt.async_app", slack_bolt.async_app),
|
||||
("slack_bolt.adapter", slack_bolt.adapter),
|
||||
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
|
||||
("slack_bolt.adapter.socket_mode.async_handler", slack_bolt.adapter.socket_mode.async_handler),
|
||||
("slack_sdk", slack_sdk),
|
||||
("slack_sdk.web", slack_sdk.web),
|
||||
("slack_sdk.web.async_client", slack_sdk.web.async_client),
|
||||
]:
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_ensure_slack_mock()
|
||||
|
||||
# Patch SLACK_AVAILABLE before importing the adapter
|
||||
import gateway.platforms.slack as _slack_mod
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter():
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake-token")
|
||||
a = SlackAdapter(config)
|
||||
# Mock the Slack app client
|
||||
a._app = MagicMock()
|
||||
a._app.client = AsyncMock()
|
||||
a._bot_user_id = "U_BOT"
|
||||
a._running = True
|
||||
# Capture events instead of processing them
|
||||
a.handle_message = AsyncMock()
|
||||
return a
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point document cache to tmp_path so tests don't touch ~/.hermes."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestAppMentionHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAppMentionHandler:
|
||||
"""Verify that the app_mention event handler is registered."""
|
||||
|
||||
def test_app_mention_registered_on_connect(self):
|
||||
"""connect() should register both 'message' and 'app_mention' handlers."""
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake")
|
||||
adapter = SlackAdapter(config)
|
||||
|
||||
# Track which events get registered
|
||||
registered_events = []
|
||||
registered_commands = []
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
def mock_event(event_type):
|
||||
def decorator(fn):
|
||||
registered_events.append(event_type)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
def mock_command(cmd):
|
||||
def decorator(fn):
|
||||
registered_commands.append(cmd)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
mock_app.event = mock_event
|
||||
mock_app.command = mock_command
|
||||
mock_app.client = AsyncMock()
|
||||
mock_app.client.auth_test = AsyncMock(return_value={
|
||||
"user_id": "U_BOT",
|
||||
"user": "testbot",
|
||||
})
|
||||
|
||||
with patch.object(_slack_mod, "AsyncApp", return_value=mock_app), \
|
||||
patch.object(_slack_mod, "AsyncSocketModeHandler", return_value=MagicMock()), \
|
||||
patch.dict(os.environ, {"SLACK_APP_TOKEN": "xapp-fake"}), \
|
||||
patch("asyncio.create_task"):
|
||||
asyncio.get_event_loop().run_until_complete(adapter.connect())
|
||||
|
||||
assert "message" in registered_events
|
||||
assert "app_mention" in registered_events
|
||||
assert "/hermes" in registered_commands
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendDocument
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendDocument:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_success(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "report.pdf"
|
||||
test_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
caption="Here's the report",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
adapter._app.client.files_upload_v2.assert_called_once()
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["channel"] == "C123"
|
||||
assert call_kwargs["file"] == str(test_file)
|
||||
assert call_kwargs["filename"] == "report.pdf"
|
||||
assert call_kwargs["initial_comment"] == "Here's the report"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_custom_name(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "data.csv"
|
||||
test_file.write_bytes(b"a,b,c\n1,2,3")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
file_name="quarterly-report.csv",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["filename"] == "quarterly-report.csv"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_missing_file(self, adapter):
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path="/nonexistent/file.pdf",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_not_connected(self, adapter):
|
||||
adapter._app = None
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path="/some/file.pdf",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_api_error_falls_back(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "doc.pdf"
|
||||
test_file.write_bytes(b"content")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=RuntimeError("Slack API error")
|
||||
)
|
||||
|
||||
# Should fall back to base class (text message)
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
)
|
||||
|
||||
# Base class send() is also mocked, so check it was attempted
|
||||
adapter._app.client.chat_postMessage.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_with_thread(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "notes.txt"
|
||||
test_file.write_bytes(b"some notes")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
reply_to="1234567890.123456",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["thread_ts"] == "1234567890.123456"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendVideo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendVideo:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_success(self, adapter, tmp_path):
|
||||
video = tmp_path / "clip.mp4"
|
||||
video.write_bytes(b"fake video data")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path=str(video),
|
||||
caption="Check this out",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["filename"] == "clip.mp4"
|
||||
assert call_kwargs["initial_comment"] == "Check this out"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_missing_file(self, adapter):
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path="/nonexistent/video.mp4",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_not_connected(self, adapter):
|
||||
adapter._app = None
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path="/some/video.mp4",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_api_error_falls_back(self, adapter, tmp_path):
|
||||
video = tmp_path / "clip.mp4"
|
||||
video.write_bytes(b"fake video")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=RuntimeError("Slack API error")
|
||||
)
|
||||
|
||||
# Should fall back to base class (text message)
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path=str(video),
|
||||
)
|
||||
|
||||
adapter._app.client.chat_postMessage.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestIncomingDocumentHandling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIncomingDocumentHandling:
|
||||
def _make_event(self, files=None, text="hello", channel_type="im"):
|
||||
"""Build a mock Slack message event with file attachments."""
|
||||
return {
|
||||
"text": text,
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": channel_type,
|
||||
"ts": "1234567890.000001",
|
||||
"files": files or [],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_document_cached(self, adapter):
|
||||
"""A PDF attachment should be downloaded, cached, and set as DOCUMENT type."""
|
||||
pdf_bytes = b"%PDF-1.4 fake content"
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = pdf_bytes
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/pdf",
|
||||
"name": "report.pdf",
|
||||
"url_private_download": "https://files.slack.com/report.pdf",
|
||||
"size": len(pdf_bytes),
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.message_type == MessageType.DOCUMENT
|
||||
assert len(msg_event.media_urls) == 1
|
||||
assert os.path.exists(msg_event.media_urls[0])
|
||||
assert msg_event.media_types == ["application/pdf"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_txt_document_injects_content(self, adapter):
|
||||
"""A .txt file under 100KB should have its content injected into event text."""
|
||||
content = b"Hello from a text file"
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = content
|
||||
event = self._make_event(
|
||||
text="summarize this",
|
||||
files=[{
|
||||
"mimetype": "text/plain",
|
||||
"name": "notes.txt",
|
||||
"url_private_download": "https://files.slack.com/notes.txt",
|
||||
"size": len(content),
|
||||
}],
|
||||
)
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "Hello from a text file" in msg_event.text
|
||||
assert "[Content of notes.txt]" in msg_event.text
|
||||
assert "summarize this" in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_md_document_injects_content(self, adapter):
|
||||
"""A .md file under 100KB should have its content injected."""
|
||||
content = b"# Title\nSome markdown content"
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = content
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "text/markdown",
|
||||
"name": "readme.md",
|
||||
"url_private_download": "https://files.slack.com/readme.md",
|
||||
"size": len(content),
|
||||
}], text="")
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "# Title" in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_txt_not_injected(self, adapter):
|
||||
"""A .txt file over 100KB should be cached but NOT injected."""
|
||||
content = b"x" * (200 * 1024)
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = content
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "text/plain",
|
||||
"name": "big.txt",
|
||||
"url_private_download": "https://files.slack.com/big.txt",
|
||||
"size": len(content),
|
||||
}], text="")
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert len(msg_event.media_urls) == 1
|
||||
assert "[Content of" not in (msg_event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_file_type_skipped(self, adapter):
|
||||
"""A .zip file should be silently skipped."""
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/zip",
|
||||
"name": "archive.zip",
|
||||
"url_private_download": "https://files.slack.com/archive.zip",
|
||||
"size": 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.message_type == MessageType.TEXT
|
||||
assert len(msg_event.media_urls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_document_skipped(self, adapter):
|
||||
"""A document over 20MB should be skipped."""
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/pdf",
|
||||
"name": "huge.pdf",
|
||||
"url_private_download": "https://files.slack.com/huge.pdf",
|
||||
"size": 25 * 1024 * 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert len(msg_event.media_urls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_download_error_handled(self, adapter):
|
||||
"""If document download fails, handler should not crash."""
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.side_effect = RuntimeError("download failed")
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/pdf",
|
||||
"name": "report.pdf",
|
||||
"url_private_download": "https://files.slack.com/report.pdf",
|
||||
"size": 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
# Handler should still be called (the exception is caught)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_still_handled(self, adapter):
|
||||
"""Image attachments should still go through the image path, not document."""
|
||||
with patch.object(adapter, "_download_slack_file", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = "/tmp/cached_image.jpg"
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "image/jpeg",
|
||||
"name": "photo.jpg",
|
||||
"url_private_download": "https://files.slack.com/photo.jpg",
|
||||
"size": 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.message_type == MessageType.PHOTO
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMessageRouting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMessageRouting:
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_processed_without_mention(self, adapter):
|
||||
"""DM messages should be processed without requiring a bot mention."""
|
||||
event = {
|
||||
"text": "hello",
|
||||
"user": "U_USER",
|
||||
"channel": "D123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_message_requires_mention(self, adapter):
|
||||
"""Channel messages without a bot mention should be ignored."""
|
||||
event = {
|
||||
"text": "just talking",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "channel",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_mention_strips_bot_id(self, adapter):
|
||||
"""When mentioned in a channel, the bot mention should be stripped."""
|
||||
event = {
|
||||
"text": "<@U_BOT> what's the weather?",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "channel",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.text == "what's the weather?"
|
||||
assert "<@U_BOT>" not in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bot_messages_ignored(self, adapter):
|
||||
"""Messages from bots should be ignored."""
|
||||
event = {
|
||||
"text": "bot response",
|
||||
"bot_id": "B_OTHER",
|
||||
"channel": "C123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_edits_ignored(self, adapter):
|
||||
"""Message edits should be ignored."""
|
||||
event = {
|
||||
"text": "edited message",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
"subtype": "message_changed",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
@@ -505,25 +505,6 @@ class TestExpandPath:
|
||||
assert result == str(Path.home())
|
||||
_assert_clean(result)
|
||||
|
||||
def test_tilde_injection_blocked(self, ops):
|
||||
"""Paths like ~; rm -rf / must NOT execute shell commands."""
|
||||
malicious = "~; echo PWNED > /tmp/_hermes_injection_test"
|
||||
result = ops._expand_path(malicious)
|
||||
# The invalid username (contains ";") should prevent shell expansion.
|
||||
# The path should be returned as-is (no expansion).
|
||||
assert result == malicious
|
||||
# Verify the injected command did NOT execute
|
||||
import os
|
||||
assert not os.path.exists("/tmp/_hermes_injection_test")
|
||||
|
||||
def test_tilde_username_with_subpath(self, ops):
|
||||
"""~root/file.txt should attempt expansion (valid username)."""
|
||||
result = ops._expand_path("~root/file.txt")
|
||||
# On most systems ~root expands to /root
|
||||
if result != "~root/file.txt":
|
||||
assert result.endswith("/file.txt")
|
||||
assert "~" not in result
|
||||
|
||||
|
||||
# ── Terminal output cleanliness ──────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
"""Tests for tools/vision_tools.py — URL validation, type hints, error logging."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Awaitable
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.vision_tools import (
|
||||
_validate_image_url,
|
||||
_handle_vision_analyze,
|
||||
_determine_mime_type,
|
||||
_image_to_base64_data_url,
|
||||
vision_analyze_tool,
|
||||
check_vision_requirements,
|
||||
get_debug_session_info,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_image_url — urlparse-based validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateImageUrl:
|
||||
"""Tests for URL validation, including urlparse-based netloc check."""
|
||||
|
||||
def test_valid_https_url(self):
|
||||
assert _validate_image_url("https://example.com/image.jpg") is True
|
||||
|
||||
def test_valid_http_url(self):
|
||||
assert _validate_image_url("http://cdn.example.org/photo.png") is True
|
||||
|
||||
def test_valid_url_without_extension(self):
|
||||
"""CDN endpoints that redirect to images should still pass."""
|
||||
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
|
||||
|
||||
def test_valid_url_with_query_params(self):
|
||||
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
|
||||
|
||||
def test_valid_url_with_port(self):
|
||||
assert _validate_image_url("http://localhost:8080/image.png") is True
|
||||
|
||||
def test_valid_url_with_path_only(self):
|
||||
assert _validate_image_url("https://example.com/") is True
|
||||
|
||||
def test_rejects_empty_string(self):
|
||||
assert _validate_image_url("") is False
|
||||
|
||||
def test_rejects_none(self):
|
||||
assert _validate_image_url(None) is False
|
||||
|
||||
def test_rejects_non_string(self):
|
||||
assert _validate_image_url(12345) is False
|
||||
|
||||
def test_rejects_ftp_scheme(self):
|
||||
assert _validate_image_url("ftp://files.example.com/image.jpg") is False
|
||||
|
||||
def test_rejects_file_scheme(self):
|
||||
assert _validate_image_url("file:///etc/passwd") is False
|
||||
|
||||
def test_rejects_no_scheme(self):
|
||||
assert _validate_image_url("example.com/image.jpg") is False
|
||||
|
||||
def test_rejects_javascript_scheme(self):
|
||||
assert _validate_image_url("javascript:alert(1)") is False
|
||||
|
||||
def test_rejects_http_without_netloc(self):
|
||||
"""http:// alone has no network location — urlparse catches this."""
|
||||
assert _validate_image_url("http://") is False
|
||||
|
||||
def test_rejects_https_without_netloc(self):
|
||||
assert _validate_image_url("https://") is False
|
||||
|
||||
def test_rejects_http_colon_only(self):
|
||||
assert _validate_image_url("http:") is False
|
||||
|
||||
def test_rejects_data_url(self):
|
||||
assert _validate_image_url("data:image/png;base64,iVBOR") is False
|
||||
|
||||
def test_rejects_whitespace_only(self):
|
||||
assert _validate_image_url(" ") is False
|
||||
|
||||
def test_rejects_boolean(self):
|
||||
assert _validate_image_url(True) is False
|
||||
|
||||
def test_rejects_list(self):
|
||||
assert _validate_image_url(["https://example.com"]) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _determine_mime_type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDetermineMimeType:
|
||||
def test_jpg(self):
|
||||
assert _determine_mime_type(Path("photo.jpg")) == "image/jpeg"
|
||||
|
||||
def test_jpeg(self):
|
||||
assert _determine_mime_type(Path("photo.jpeg")) == "image/jpeg"
|
||||
|
||||
def test_png(self):
|
||||
assert _determine_mime_type(Path("screenshot.png")) == "image/png"
|
||||
|
||||
def test_gif(self):
|
||||
assert _determine_mime_type(Path("anim.gif")) == "image/gif"
|
||||
|
||||
def test_webp(self):
|
||||
assert _determine_mime_type(Path("modern.webp")) == "image/webp"
|
||||
|
||||
def test_unknown_extension_defaults_to_jpeg(self):
|
||||
assert _determine_mime_type(Path("file.xyz")) == "image/jpeg"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _image_to_base64_data_url
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestImageToBase64DataUrl:
|
||||
def test_returns_data_url(self, tmp_path):
|
||||
img = tmp_path / "test.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
|
||||
result = _image_to_base64_data_url(img)
|
||||
assert result.startswith("data:image/png;base64,")
|
||||
|
||||
def test_custom_mime_type(self, tmp_path):
|
||||
img = tmp_path / "test.bin"
|
||||
img.write_bytes(b"\x00" * 16)
|
||||
result = _image_to_base64_data_url(img, mime_type="image/webp")
|
||||
assert result.startswith("data:image/webp;base64,")
|
||||
|
||||
def test_file_not_found_raises(self, tmp_path):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
_image_to_base64_data_url(tmp_path / "nonexistent.png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_vision_analyze — type signature & behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHandleVisionAnalyze:
|
||||
"""Verify _handle_vision_analyze returns an Awaitable and builds correct prompt."""
|
||||
|
||||
def test_returns_awaitable(self):
|
||||
"""The handler must return an Awaitable (coroutine) since it's registered as async."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
result = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "What is this?"}
|
||||
)
|
||||
# It should be an Awaitable (coroutine)
|
||||
assert isinstance(result, Awaitable)
|
||||
# Clean up the coroutine to avoid RuntimeWarning
|
||||
result.close()
|
||||
|
||||
def test_prompt_contains_question(self):
|
||||
"""The full prompt should incorporate the user's question."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "Describe the cat"}
|
||||
)
|
||||
# Clean up coroutine
|
||||
coro.close()
|
||||
call_args = mock_tool.call_args
|
||||
full_prompt = call_args[0][1] # second positional arg
|
||||
assert "Describe the cat" in full_prompt
|
||||
assert "Fully describe and explain" in full_prompt
|
||||
|
||||
def test_uses_auxiliary_vision_model_env(self):
|
||||
"""AUXILIARY_VISION_MODEL env var should override DEFAULT_VISION_MODEL."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
||||
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}):
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "test"}
|
||||
)
|
||||
coro.close()
|
||||
call_args = mock_tool.call_args
|
||||
model = call_args[0][2] # third positional arg
|
||||
assert model == "custom/model-v1"
|
||||
|
||||
def test_falls_back_to_default_model(self):
|
||||
"""Without AUXILIARY_VISION_MODEL, should use DEFAULT_VISION_MODEL or fallback."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
||||
patch.dict(os.environ, {}, clear=False):
|
||||
# Ensure AUXILIARY_VISION_MODEL is not set
|
||||
os.environ.pop("AUXILIARY_VISION_MODEL", None)
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "test"}
|
||||
)
|
||||
coro.close()
|
||||
call_args = mock_tool.call_args
|
||||
model = call_args[0][2]
|
||||
# Should be DEFAULT_VISION_MODEL or the hardcoded fallback
|
||||
assert model is not None
|
||||
assert len(model) > 0
|
||||
|
||||
def test_empty_args_graceful(self):
|
||||
"""Missing keys should default to empty strings, not raise."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
result = _handle_vision_analyze({})
|
||||
assert isinstance(result, Awaitable)
|
||||
result.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error logging with exc_info — verify tracebacks are logged
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestErrorLoggingExcInfo:
|
||||
"""Verify that exc_info=True is used in error/warning log calls."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_failure_logs_exc_info(self, tmp_path, caplog):
|
||||
"""After max retries, the download error should include exc_info."""
|
||||
from tools.vision_tools import _download_image
|
||||
|
||||
with patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(side_effect=ConnectionError("network down"))
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
dest = tmp_path / "image.jpg"
|
||||
with caplog.at_level(logging.ERROR, logger="tools.vision_tools"), \
|
||||
pytest.raises(ConnectionError):
|
||||
await _download_image("https://example.com/img.jpg", dest, max_retries=1)
|
||||
|
||||
# Should have logged with exc_info (traceback present)
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
assert len(error_records) >= 1
|
||||
assert error_records[0].exc_info is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analysis_error_logs_exc_info(self, caplog):
|
||||
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
|
||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
||||
patch("tools.vision_tools._download_image", new_callable=AsyncMock,
|
||||
side_effect=Exception("download boom")), \
|
||||
caplog.at_level(logging.ERROR, logger="tools.vision_tools"):
|
||||
|
||||
result = await vision_analyze_tool(
|
||||
"https://example.com/img.jpg", "describe this", "test/model"
|
||||
)
|
||||
result_data = json.loads(result)
|
||||
# Error response uses "success": False, not an "error" key
|
||||
assert result_data["success"] is False
|
||||
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
assert any(r.exc_info is not None for r in error_records)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_error_logs_exc_info(self, tmp_path, caplog):
|
||||
"""Temp file cleanup failure should log warning with exc_info."""
|
||||
# Create a real temp file that will be "downloaded"
|
||||
temp_dir = tmp_path / "temp_vision_images"
|
||||
temp_dir.mkdir()
|
||||
|
||||
async def fake_download(url, dest, max_retries=3):
|
||||
"""Simulate download by writing file to the expected destination."""
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
|
||||
return dest
|
||||
|
||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
||||
patch("tools.vision_tools._download_image", side_effect=fake_download), \
|
||||
patch("tools.vision_tools._image_to_base64_data_url",
|
||||
return_value="data:image/jpeg;base64,abc"), \
|
||||
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None), \
|
||||
patch("agent.auxiliary_client.auxiliary_max_tokens_param", return_value={"max_tokens": 2000}), \
|
||||
caplog.at_level(logging.WARNING, logger="tools.vision_tools"):
|
||||
|
||||
# Mock the vision client
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = "A test image description"
|
||||
mock_response.choices = [mock_choice]
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch module-level _aux_async_client so the tool doesn't bail early
|
||||
with patch("tools.vision_tools._aux_async_client", mock_client), \
|
||||
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"):
|
||||
|
||||
# Make unlink fail to trigger cleanup warning
|
||||
original_unlink = Path.unlink
|
||||
def failing_unlink(self, *args, **kwargs):
|
||||
raise PermissionError("no permission")
|
||||
|
||||
with patch.object(Path, "unlink", failing_unlink):
|
||||
result = await vision_analyze_tool(
|
||||
"https://example.com/tempimg.jpg", "describe", "test/model"
|
||||
)
|
||||
|
||||
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING
|
||||
and "temporary file" in r.getMessage().lower()]
|
||||
assert len(warning_records) >= 1
|
||||
assert warning_records[0].exc_info is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_vision_requirements & get_debug_session_info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVisionRequirements:
|
||||
def test_check_requirements_returns_bool(self):
|
||||
result = check_vision_requirements()
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_debug_session_info_returns_dict(self):
|
||||
info = get_debug_session_info()
|
||||
assert isinstance(info, dict)
|
||||
# DebugSession.get_session_info() returns these keys
|
||||
assert "enabled" in info
|
||||
assert "session_id" in info
|
||||
assert "total_calls" in info
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: registry entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVisionRegistration:
|
||||
def test_vision_analyze_registered(self):
|
||||
from tools.registry import registry
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
assert entry is not None
|
||||
assert entry.toolset == "vision"
|
||||
assert entry.is_async is True
|
||||
|
||||
def test_schema_has_required_fields(self):
|
||||
from tools.registry import registry
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
schema = entry.schema
|
||||
assert schema["name"] == "vision_analyze"
|
||||
params = schema.get("parameters", {})
|
||||
props = params.get("properties", {})
|
||||
assert "image_url" in props
|
||||
assert "question" in props
|
||||
|
||||
def test_handler_is_callable(self):
|
||||
from tools.registry import registry
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
assert callable(entry.handler)
|
||||
@@ -16,222 +16,249 @@ for the AI agent to access all capabilities.
|
||||
"""
|
||||
|
||||
# Export all tools for easy importing
|
||||
from .web_tools import (
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_crawl_tool,
|
||||
check_firecrawl_api_key
|
||||
)
|
||||
|
||||
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
|
||||
from .terminal_tool import (
|
||||
terminal_tool,
|
||||
check_terminal_requirements,
|
||||
cleanup_vm,
|
||||
cleanup_all_environments,
|
||||
get_active_environments_info,
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
TERMINAL_TOOL_DESCRIPTION
|
||||
)
|
||||
|
||||
from .vision_tools import (
|
||||
vision_analyze_tool,
|
||||
check_vision_requirements
|
||||
)
|
||||
|
||||
from .mixture_of_agents_tool import (
|
||||
mixture_of_agents_tool,
|
||||
check_moa_requirements
|
||||
)
|
||||
|
||||
from .image_generation_tool import (
|
||||
image_generate_tool,
|
||||
check_image_generation_requirements
|
||||
)
|
||||
|
||||
from .skills_tool import (
|
||||
skills_list,
|
||||
skill_view,
|
||||
check_skills_requirements,
|
||||
SKILLS_TOOL_DESCRIPTION
|
||||
)
|
||||
|
||||
from .skill_manager_tool import (
|
||||
skill_manage,
|
||||
check_skill_manage_requirements,
|
||||
SKILL_MANAGE_SCHEMA
|
||||
)
|
||||
|
||||
# Browser automation tools (agent-browser + Browserbase)
|
||||
from .browser_tool import (
|
||||
BROWSER_TOOL_SCHEMAS,
|
||||
browser_back,
|
||||
browser_navigate,
|
||||
browser_snapshot,
|
||||
browser_click,
|
||||
browser_type,
|
||||
browser_scroll,
|
||||
browser_back,
|
||||
browser_press,
|
||||
browser_close,
|
||||
browser_get_images,
|
||||
browser_navigate,
|
||||
browser_press,
|
||||
browser_scroll,
|
||||
browser_snapshot,
|
||||
browser_type,
|
||||
browser_vision,
|
||||
check_browser_requirements,
|
||||
cleanup_all_browsers,
|
||||
cleanup_browser,
|
||||
cleanup_all_browsers,
|
||||
get_active_browser_sessions,
|
||||
)
|
||||
|
||||
# Clarifying questions tool (interactive Q&A with the user)
|
||||
from .clarify_tool import (
|
||||
CLARIFY_SCHEMA,
|
||||
check_clarify_requirements,
|
||||
clarify_tool,
|
||||
)
|
||||
|
||||
# Code execution sandbox (programmatic tool calling)
|
||||
from .code_execution_tool import (
|
||||
EXECUTE_CODE_SCHEMA,
|
||||
check_sandbox_requirements,
|
||||
execute_code,
|
||||
check_browser_requirements,
|
||||
BROWSER_TOOL_SCHEMAS
|
||||
)
|
||||
|
||||
# Cronjob management tools (CLI-only, hermes-cli toolset)
|
||||
from .cronjob_tools import (
|
||||
LIST_CRONJOBS_SCHEMA,
|
||||
REMOVE_CRONJOB_SCHEMA,
|
||||
SCHEDULE_CRONJOB_SCHEMA,
|
||||
check_cronjob_requirements,
|
||||
get_cronjob_tool_definitions,
|
||||
schedule_cronjob,
|
||||
list_cronjobs,
|
||||
remove_cronjob,
|
||||
schedule_cronjob,
|
||||
check_cronjob_requirements,
|
||||
get_cronjob_tool_definitions,
|
||||
SCHEDULE_CRONJOB_SCHEMA,
|
||||
LIST_CRONJOBS_SCHEMA,
|
||||
REMOVE_CRONJOB_SCHEMA
|
||||
)
|
||||
|
||||
# Subagent delegation (spawn child agents with isolated context)
|
||||
from .delegate_tool import (
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
check_delegate_requirements,
|
||||
delegate_task,
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from .rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_list_runs,
|
||||
rl_test_inference,
|
||||
check_rl_api_keys,
|
||||
get_missing_keys,
|
||||
)
|
||||
|
||||
# File manipulation tools (read, write, patch, search)
|
||||
from .file_tools import (
|
||||
clear_file_ops_cache,
|
||||
get_file_tools,
|
||||
patch_tool,
|
||||
read_file_tool,
|
||||
search_tool,
|
||||
write_file_tool,
|
||||
)
|
||||
from .image_generation_tool import check_image_generation_requirements, image_generate_tool
|
||||
from .mixture_of_agents_tool import check_moa_requirements, mixture_of_agents_tool
|
||||
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from .rl_training_tool import (
|
||||
check_rl_api_keys,
|
||||
get_missing_keys,
|
||||
rl_check_status,
|
||||
rl_edit_config,
|
||||
rl_get_current_config,
|
||||
rl_get_results,
|
||||
rl_list_environments,
|
||||
rl_list_runs,
|
||||
rl_select_environment,
|
||||
rl_start_training,
|
||||
rl_stop_training,
|
||||
rl_test_inference,
|
||||
)
|
||||
from .skill_manager_tool import SKILL_MANAGE_SCHEMA, check_skill_manage_requirements, skill_manage
|
||||
from .skills_tool import SKILLS_TOOL_DESCRIPTION, check_skills_requirements, skill_view, skills_list
|
||||
|
||||
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
|
||||
from .terminal_tool import (
|
||||
TERMINAL_TOOL_DESCRIPTION,
|
||||
check_terminal_requirements,
|
||||
cleanup_all_environments,
|
||||
cleanup_vm,
|
||||
clear_task_env_overrides,
|
||||
get_active_environments_info,
|
||||
register_task_env_overrides,
|
||||
terminal_tool,
|
||||
)
|
||||
|
||||
# Planning & task management tool
|
||||
from .todo_tool import (
|
||||
TODO_SCHEMA,
|
||||
TodoStore,
|
||||
check_todo_requirements,
|
||||
todo_tool,
|
||||
patch_tool,
|
||||
search_tool,
|
||||
get_file_tools,
|
||||
clear_file_ops_cache,
|
||||
)
|
||||
|
||||
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
|
||||
from .tts_tool import (
|
||||
check_tts_requirements,
|
||||
text_to_speech_tool,
|
||||
check_tts_requirements,
|
||||
)
|
||||
from .vision_tools import check_vision_requirements, vision_analyze_tool
|
||||
from .web_tools import check_firecrawl_api_key, web_crawl_tool, web_extract_tool, web_search_tool
|
||||
|
||||
# Planning & task management tool
|
||||
from .todo_tool import (
|
||||
todo_tool,
|
||||
check_todo_requirements,
|
||||
TODO_SCHEMA,
|
||||
TodoStore,
|
||||
)
|
||||
|
||||
# Clarifying questions tool (interactive Q&A with the user)
|
||||
from .clarify_tool import (
|
||||
clarify_tool,
|
||||
check_clarify_requirements,
|
||||
CLARIFY_SCHEMA,
|
||||
)
|
||||
|
||||
# Code execution sandbox (programmatic tool calling)
|
||||
from .code_execution_tool import (
|
||||
execute_code,
|
||||
check_sandbox_requirements,
|
||||
EXECUTE_CODE_SCHEMA,
|
||||
)
|
||||
|
||||
# Subagent delegation (spawn child agents with isolated context)
|
||||
from .delegate_tool import (
|
||||
delegate_task,
|
||||
check_delegate_requirements,
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
)
|
||||
|
||||
# File tools have no external requirements - they use the terminal backend
|
||||
def check_file_requirements():
|
||||
"""File tools only require terminal backend to be available."""
|
||||
from .terminal_tool import check_terminal_requirements
|
||||
|
||||
return check_terminal_requirements()
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Web tools
|
||||
"web_search_tool",
|
||||
"web_extract_tool",
|
||||
"web_crawl_tool",
|
||||
"check_firecrawl_api_key",
|
||||
'web_search_tool',
|
||||
'web_extract_tool',
|
||||
'web_crawl_tool',
|
||||
'check_firecrawl_api_key',
|
||||
# Terminal tools (mini-swe-agent backend)
|
||||
"terminal_tool",
|
||||
"check_terminal_requirements",
|
||||
"cleanup_vm",
|
||||
"cleanup_all_environments",
|
||||
"get_active_environments_info",
|
||||
"register_task_env_overrides",
|
||||
"clear_task_env_overrides",
|
||||
"TERMINAL_TOOL_DESCRIPTION",
|
||||
'terminal_tool',
|
||||
'check_terminal_requirements',
|
||||
'cleanup_vm',
|
||||
'cleanup_all_environments',
|
||||
'get_active_environments_info',
|
||||
'register_task_env_overrides',
|
||||
'clear_task_env_overrides',
|
||||
'TERMINAL_TOOL_DESCRIPTION',
|
||||
# Vision tools
|
||||
"vision_analyze_tool",
|
||||
"check_vision_requirements",
|
||||
'vision_analyze_tool',
|
||||
'check_vision_requirements',
|
||||
# MoA tools
|
||||
"mixture_of_agents_tool",
|
||||
"check_moa_requirements",
|
||||
'mixture_of_agents_tool',
|
||||
'check_moa_requirements',
|
||||
# Image generation tools
|
||||
"image_generate_tool",
|
||||
"check_image_generation_requirements",
|
||||
'image_generate_tool',
|
||||
'check_image_generation_requirements',
|
||||
# Skills tools
|
||||
"skills_list",
|
||||
"skill_view",
|
||||
"check_skills_requirements",
|
||||
"SKILLS_TOOL_DESCRIPTION",
|
||||
'skills_list',
|
||||
'skill_view',
|
||||
'check_skills_requirements',
|
||||
'SKILLS_TOOL_DESCRIPTION',
|
||||
# Skill management
|
||||
"skill_manage",
|
||||
"check_skill_manage_requirements",
|
||||
"SKILL_MANAGE_SCHEMA",
|
||||
'skill_manage',
|
||||
'check_skill_manage_requirements',
|
||||
'SKILL_MANAGE_SCHEMA',
|
||||
# Browser automation tools
|
||||
"browser_navigate",
|
||||
"browser_snapshot",
|
||||
"browser_click",
|
||||
"browser_type",
|
||||
"browser_scroll",
|
||||
"browser_back",
|
||||
"browser_press",
|
||||
"browser_close",
|
||||
"browser_get_images",
|
||||
"browser_vision",
|
||||
"cleanup_browser",
|
||||
"cleanup_all_browsers",
|
||||
"get_active_browser_sessions",
|
||||
"check_browser_requirements",
|
||||
"BROWSER_TOOL_SCHEMAS",
|
||||
'browser_navigate',
|
||||
'browser_snapshot',
|
||||
'browser_click',
|
||||
'browser_type',
|
||||
'browser_scroll',
|
||||
'browser_back',
|
||||
'browser_press',
|
||||
'browser_close',
|
||||
'browser_get_images',
|
||||
'browser_vision',
|
||||
'cleanup_browser',
|
||||
'cleanup_all_browsers',
|
||||
'get_active_browser_sessions',
|
||||
'check_browser_requirements',
|
||||
'BROWSER_TOOL_SCHEMAS',
|
||||
# Cronjob management tools (CLI-only)
|
||||
"schedule_cronjob",
|
||||
"list_cronjobs",
|
||||
"remove_cronjob",
|
||||
"check_cronjob_requirements",
|
||||
"get_cronjob_tool_definitions",
|
||||
"SCHEDULE_CRONJOB_SCHEMA",
|
||||
"LIST_CRONJOBS_SCHEMA",
|
||||
"REMOVE_CRONJOB_SCHEMA",
|
||||
'schedule_cronjob',
|
||||
'list_cronjobs',
|
||||
'remove_cronjob',
|
||||
'check_cronjob_requirements',
|
||||
'get_cronjob_tool_definitions',
|
||||
'SCHEDULE_CRONJOB_SCHEMA',
|
||||
'LIST_CRONJOBS_SCHEMA',
|
||||
'REMOVE_CRONJOB_SCHEMA',
|
||||
# RL Training tools
|
||||
"rl_list_environments",
|
||||
"rl_select_environment",
|
||||
"rl_get_current_config",
|
||||
"rl_edit_config",
|
||||
"rl_start_training",
|
||||
"rl_check_status",
|
||||
"rl_stop_training",
|
||||
"rl_get_results",
|
||||
"rl_list_runs",
|
||||
"rl_test_inference",
|
||||
"check_rl_api_keys",
|
||||
"get_missing_keys",
|
||||
'rl_list_environments',
|
||||
'rl_select_environment',
|
||||
'rl_get_current_config',
|
||||
'rl_edit_config',
|
||||
'rl_start_training',
|
||||
'rl_check_status',
|
||||
'rl_stop_training',
|
||||
'rl_get_results',
|
||||
'rl_list_runs',
|
||||
'rl_test_inference',
|
||||
'check_rl_api_keys',
|
||||
'get_missing_keys',
|
||||
# File manipulation tools
|
||||
"read_file_tool",
|
||||
"write_file_tool",
|
||||
"patch_tool",
|
||||
"search_tool",
|
||||
"get_file_tools",
|
||||
"clear_file_ops_cache",
|
||||
"check_file_requirements",
|
||||
'read_file_tool',
|
||||
'write_file_tool',
|
||||
'patch_tool',
|
||||
'search_tool',
|
||||
'get_file_tools',
|
||||
'clear_file_ops_cache',
|
||||
'check_file_requirements',
|
||||
# Text-to-speech tools
|
||||
"text_to_speech_tool",
|
||||
"check_tts_requirements",
|
||||
'text_to_speech_tool',
|
||||
'check_tts_requirements',
|
||||
# Planning & task management tool
|
||||
"todo_tool",
|
||||
"check_todo_requirements",
|
||||
"TODO_SCHEMA",
|
||||
"TodoStore",
|
||||
'todo_tool',
|
||||
'check_todo_requirements',
|
||||
'TODO_SCHEMA',
|
||||
'TodoStore',
|
||||
# Clarifying questions tool
|
||||
"clarify_tool",
|
||||
"check_clarify_requirements",
|
||||
"CLARIFY_SCHEMA",
|
||||
'clarify_tool',
|
||||
'check_clarify_requirements',
|
||||
'CLARIFY_SCHEMA',
|
||||
# Code execution sandbox
|
||||
"execute_code",
|
||||
"check_sandbox_requirements",
|
||||
"EXECUTE_CODE_SCHEMA",
|
||||
'execute_code',
|
||||
'check_sandbox_requirements',
|
||||
'EXECUTE_CODE_SCHEMA',
|
||||
# Subagent delegation
|
||||
"delegate_task",
|
||||
"check_delegate_requirements",
|
||||
"DELEGATE_TASK_SCHEMA",
|
||||
'delegate_task',
|
||||
'check_delegate_requirements',
|
||||
'DELEGATE_TASK_SCHEMA',
|
||||
]
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,32 +21,32 @@ logger = logging.getLogger(__name__)
|
||||
# =========================================================================
|
||||
|
||||
DANGEROUS_PATTERNS = [
|
||||
(r"\brm\s+(-[^\s]*\s+)*/", "delete in root path"),
|
||||
(r"\brm\s+-[^\s]*r", "recursive delete"),
|
||||
(r"\brm\s+--recursive\b", "recursive delete (long flag)"),
|
||||
(r"\bchmod\s+(-[^\s]*\s+)*777\b", "world-writable permissions"),
|
||||
(r"\bchmod\s+--recursive\b.*777", "recursive world-writable (long flag)"),
|
||||
(r"\bchown\s+(-[^\s]*)?R\s+root", "recursive chown to root"),
|
||||
(r"\bchown\s+--recursive\b.*root", "recursive chown to root (long flag)"),
|
||||
(r"\bmkfs\b", "format filesystem"),
|
||||
(r"\bdd\s+.*if=", "disk copy"),
|
||||
(r">\s*/dev/sd", "write to block device"),
|
||||
(r"\bDROP\s+(TABLE|DATABASE)\b", "SQL DROP"),
|
||||
(r"\bDELETE\s+FROM\b(?!.*\bWHERE\b)", "SQL DELETE without WHERE"),
|
||||
(r"\bTRUNCATE\s+(TABLE)?\s*\w", "SQL TRUNCATE"),
|
||||
(r">\s*/etc/", "overwrite system config"),
|
||||
(r"\bsystemctl\s+(stop|disable|mask)\b", "stop/disable system service"),
|
||||
(r"\bkill\s+-9\s+-1\b", "kill all processes"),
|
||||
(r"\bpkill\s+-9\b", "force kill processes"),
|
||||
(r":()\s*{\s*:\s*\|\s*:&\s*}\s*;:", "fork bomb"),
|
||||
(r"\b(bash|sh|zsh)\s+-c\s+", "shell command via -c flag"),
|
||||
(r"\b(python[23]?|perl|ruby|node)\s+-[ec]\s+", "script execution via -e/-c flag"),
|
||||
(r"\b(curl|wget)\b.*\|\s*(ba)?sh\b", "pipe remote content to shell"),
|
||||
(r"\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b", "execute remote script via process substitution"),
|
||||
(r"\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)", "overwrite system file via tee"),
|
||||
(r"\bxargs\s+.*\brm\b", "xargs with rm"),
|
||||
(r"\bfind\b.*-exec\s+(/\S*/)?rm\b", "find -exec rm"),
|
||||
(r"\bfind\b.*-delete\b", "find -delete"),
|
||||
(r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"),
|
||||
(r'\brm\s+-[^\s]*r', "recursive delete"),
|
||||
(r'\brm\s+--recursive\b', "recursive delete (long flag)"),
|
||||
(r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"),
|
||||
(r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"),
|
||||
(r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"),
|
||||
(r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"),
|
||||
(r'\bmkfs\b', "format filesystem"),
|
||||
(r'\bdd\s+.*if=', "disk copy"),
|
||||
(r'>\s*/dev/sd', "write to block device"),
|
||||
(r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"),
|
||||
(r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"),
|
||||
(r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"),
|
||||
(r'>\s*/etc/', "overwrite system config"),
|
||||
(r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"),
|
||||
(r'\bkill\s+-9\s+-1\b', "kill all processes"),
|
||||
(r'\bpkill\s+-9\b', "force kill processes"),
|
||||
(r':()\s*{\s*:\s*\|\s*:&\s*}\s*;:', "fork bomb"),
|
||||
(r'\b(bash|sh|zsh)\s+-c\s+', "shell command via -c flag"),
|
||||
(r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"),
|
||||
(r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"),
|
||||
(r'\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b', "execute remote script via process substitution"),
|
||||
(r'\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)', "overwrite system file via tee"),
|
||||
(r'\bxargs\s+.*\brm\b', "xargs with rm"),
|
||||
(r'\bfind\b.*-exec\s+(/\S*/)?rm\b', "find -exec rm"),
|
||||
(r'\bfind\b.*-delete\b', "find -delete"),
|
||||
]
|
||||
|
||||
|
||||
@@ -53,7 +54,6 @@ DANGEROUS_PATTERNS = [
|
||||
# Detection
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def detect_dangerous_command(command: str) -> tuple:
|
||||
"""Check if a command matches any dangerous patterns.
|
||||
|
||||
@@ -63,7 +63,7 @@ def detect_dangerous_command(command: str) -> tuple:
|
||||
command_lower = command.lower()
|
||||
for pattern, description in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, command_lower, re.IGNORECASE | re.DOTALL):
|
||||
pattern_key = pattern.split(r"\b")[1] if r"\b" in pattern else pattern[:20]
|
||||
pattern_key = pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20]
|
||||
return (True, pattern_key, description)
|
||||
return (False, None, None)
|
||||
|
||||
@@ -84,7 +84,7 @@ def submit_pending(session_key: str, approval: dict):
|
||||
_pending[session_key] = approval
|
||||
|
||||
|
||||
def pop_pending(session_key: str) -> dict | None:
|
||||
def pop_pending(session_key: str) -> Optional[dict]:
|
||||
"""Retrieve and remove a pending approval for a session."""
|
||||
with _lock:
|
||||
return _pending.pop(session_key, None)
|
||||
@@ -133,7 +133,6 @@ def clear_session(session_key: str):
|
||||
# Config persistence for permanent allowlist
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def load_permanent_allowlist() -> set:
|
||||
"""Load permanently allowed command patterns from config.
|
||||
|
||||
@@ -142,7 +141,6 @@ def load_permanent_allowlist() -> set:
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
patterns = set(config.get("command_allowlist", []) or [])
|
||||
if patterns:
|
||||
@@ -156,7 +154,6 @@ def save_permanent_allowlist(patterns: set):
|
||||
"""Save permanently allowed command patterns to config."""
|
||||
try:
|
||||
from hermes_cli.config import load_config, save_config
|
||||
|
||||
config = load_config()
|
||||
config["command_allowlist"] = list(patterns)
|
||||
save_config(config)
|
||||
@@ -168,8 +165,9 @@ def save_permanent_allowlist(patterns: set):
|
||||
# Approval prompting + orchestration
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def prompt_dangerous_approval(command: str, description: str, timeout_seconds: int = 60, approval_callback=None) -> str:
|
||||
def prompt_dangerous_approval(command: str, description: str,
|
||||
timeout_seconds: int = 60,
|
||||
approval_callback=None) -> str:
|
||||
"""Prompt the user to approve a dangerous command (CLI only).
|
||||
|
||||
Args:
|
||||
@@ -190,7 +188,7 @@ def prompt_dangerous_approval(command: str, description: str, timeout_seconds: i
|
||||
print(f" ⚠️ DANGEROUS COMMAND: {description}")
|
||||
print(f" {command[:80]}{'...' if len(command) > 80 else ''}")
|
||||
print()
|
||||
print(" [o]nce | [s]ession | [a]lways | [d]eny")
|
||||
print(f" [o]nce | [s]ession | [a]lways | [d]eny")
|
||||
print()
|
||||
sys.stdout.flush()
|
||||
|
||||
@@ -211,13 +209,13 @@ def prompt_dangerous_approval(command: str, description: str, timeout_seconds: i
|
||||
return "deny"
|
||||
|
||||
choice = result["choice"]
|
||||
if choice in ("o", "once"):
|
||||
if choice in ('o', 'once'):
|
||||
print(" ✓ Allowed once")
|
||||
return "once"
|
||||
elif choice in ("s", "session"):
|
||||
elif choice in ('s', 'session'):
|
||||
print(" ✓ Allowed for this session")
|
||||
return "session"
|
||||
elif choice in ("a", "always"):
|
||||
elif choice in ('a', 'always'):
|
||||
print(" ✓ Added to permanent allowlist")
|
||||
return "always"
|
||||
else:
|
||||
@@ -234,7 +232,8 @@ def prompt_dangerous_approval(command: str, description: str, timeout_seconds: i
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def check_dangerous_command(command: str, env_type: str, approval_callback=None) -> dict:
|
||||
def check_dangerous_command(command: str, env_type: str,
|
||||
approval_callback=None) -> dict:
|
||||
"""Check if a command is dangerous and handle approval.
|
||||
|
||||
This is the main entry point called by terminal_tool before executing
|
||||
@@ -266,14 +265,11 @@ def check_dangerous_command(command: str, env_type: str, approval_callback=None)
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
if is_gateway or os.getenv("HERMES_EXEC_ASK"):
|
||||
submit_pending(
|
||||
session_key,
|
||||
{
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
submit_pending(session_key, {
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"description": description,
|
||||
})
|
||||
return {
|
||||
"approved": False,
|
||||
"pattern_key": pattern_key,
|
||||
@@ -283,7 +279,8 @@ def check_dangerous_command(command: str, env_type: str, approval_callback=None)
|
||||
"message": f"⚠️ This command is potentially dangerous ({description}). Asking the user for approval...",
|
||||
}
|
||||
|
||||
choice = prompt_dangerous_approval(command, description, approval_callback=approval_callback)
|
||||
choice = prompt_dangerous_approval(command, description,
|
||||
approval_callback=approval_callback)
|
||||
|
||||
if choice == "deny":
|
||||
return {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,8 @@ a thin dispatcher that delegates to a platform-provided callback.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
|
||||
|
||||
# Maximum number of predefined choices the agent can offer.
|
||||
# A 5th "Other (type your answer)" option is always appended by the UI.
|
||||
@@ -21,8 +22,8 @@ MAX_CHOICES = 4
|
||||
|
||||
def clarify_tool(
|
||||
question: str,
|
||||
choices: list[str] | None = None,
|
||||
callback: Callable | None = None,
|
||||
choices: Optional[List[str]] = None,
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Ask the user a question, optionally with multiple-choice options.
|
||||
@@ -67,14 +68,11 @@ def clarify_tool(
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"question": question,
|
||||
"choices_offered": choices,
|
||||
"user_response": str(user_response).strip(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
return json.dumps({
|
||||
"question": question,
|
||||
"choices_offered": choices,
|
||||
"user_response": str(user_response).strip(),
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_clarify_requirements() -> bool:
|
||||
@@ -135,7 +133,8 @@ registry.register(
|
||||
toolset="clarify",
|
||||
schema=CLARIFY_SCHEMA,
|
||||
handler=lambda args, **kw: clarify_tool(
|
||||
question=args.get("question", ""), choices=args.get("choices"), callback=kw.get("callback")
|
||||
),
|
||||
question=args.get("question", ""),
|
||||
choices=args.get("choices"),
|
||||
callback=kw.get("callback")),
|
||||
check_fn=check_clarify_requirements,
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ import time
|
||||
import uuid
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Availability gate: UDS requires a POSIX OS
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,23 +40,21 @@ SANDBOX_AVAILABLE = sys.platform != "win32"
|
||||
|
||||
# The 7 tools allowed inside the sandbox. The intersection of this list
|
||||
# and the session's enabled tools determines which stubs are generated.
|
||||
SANDBOX_ALLOWED_TOOLS = frozenset(
|
||||
[
|
||||
"web_search",
|
||||
"web_extract",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"search_files",
|
||||
"patch",
|
||||
"terminal",
|
||||
]
|
||||
)
|
||||
SANDBOX_ALLOWED_TOOLS = frozenset([
|
||||
"web_search",
|
||||
"web_extract",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"search_files",
|
||||
"patch",
|
||||
"terminal",
|
||||
])
|
||||
|
||||
# Resource limit defaults (overridable via config.yaml → code_execution.*)
|
||||
DEFAULT_TIMEOUT = 300 # 5 minutes
|
||||
DEFAULT_TIMEOUT = 300 # 5 minutes
|
||||
DEFAULT_MAX_TOOL_CALLS = 50
|
||||
MAX_STDOUT_BYTES = 50_000 # 50 KB
|
||||
MAX_STDERR_BYTES = 10_000 # 10 KB
|
||||
MAX_STDOUT_BYTES = 50_000 # 50 KB
|
||||
MAX_STDERR_BYTES = 10_000 # 10 KB
|
||||
|
||||
|
||||
def check_sandbox_requirements() -> bool:
|
||||
@@ -116,7 +114,7 @@ _TOOL_STUBS = {
|
||||
}
|
||||
|
||||
|
||||
def generate_hermes_tools_module(enabled_tools: list[str]) -> str:
|
||||
def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
|
||||
"""
|
||||
Build the source code for the hermes_tools.py stub module.
|
||||
|
||||
@@ -130,7 +128,11 @@ def generate_hermes_tools_module(enabled_tools: list[str]) -> str:
|
||||
if tool_name not in _TOOL_STUBS:
|
||||
continue
|
||||
func_name, sig, doc, args_expr = _TOOL_STUBS[tool_name]
|
||||
stub_functions.append(f"def {func_name}({sig}):\n {doc}\n return _call({func_name!r}, {args_expr})\n")
|
||||
stub_functions.append(
|
||||
f"def {func_name}({sig}):\n"
|
||||
f" {doc}\n"
|
||||
f" return _call({func_name!r}, {args_expr})\n"
|
||||
)
|
||||
export_names.append(func_name)
|
||||
|
||||
header = '''\
|
||||
@@ -221,7 +223,7 @@ def _rpc_server_loop(
|
||||
server_sock: socket.socket,
|
||||
task_id: str,
|
||||
tool_call_log: list,
|
||||
tool_call_counter: list, # mutable [int] so the thread can increment
|
||||
tool_call_counter: list, # mutable [int] so the thread can increment
|
||||
max_tool_calls: int,
|
||||
allowed_tools: frozenset,
|
||||
):
|
||||
@@ -241,7 +243,7 @@ def _rpc_server_loop(
|
||||
while True:
|
||||
try:
|
||||
chunk = conn.recv(65536)
|
||||
except TimeoutError:
|
||||
except socket.timeout:
|
||||
break
|
||||
if not chunk:
|
||||
break
|
||||
@@ -268,22 +270,23 @@ def _rpc_server_loop(
|
||||
# Enforce the allow-list
|
||||
if tool_name not in allowed_tools:
|
||||
available = ", ".join(sorted(allowed_tools))
|
||||
resp = json.dumps(
|
||||
{"error": (f"Tool '{tool_name}' is not available in execute_code. Available: {available}")}
|
||||
)
|
||||
resp = json.dumps({
|
||||
"error": (
|
||||
f"Tool '{tool_name}' is not available in execute_code. "
|
||||
f"Available: {available}"
|
||||
)
|
||||
})
|
||||
conn.sendall((resp + "\n").encode())
|
||||
continue
|
||||
|
||||
# Enforce tool call limit
|
||||
if tool_call_counter[0] >= max_tool_calls:
|
||||
resp = json.dumps(
|
||||
{
|
||||
"error": (
|
||||
f"Tool call limit reached ({max_tool_calls}). "
|
||||
"No more tool calls allowed in this execution."
|
||||
)
|
||||
}
|
||||
)
|
||||
resp = json.dumps({
|
||||
"error": (
|
||||
f"Tool call limit reached ({max_tool_calls}). "
|
||||
"No more tool calls allowed in this execution."
|
||||
)
|
||||
})
|
||||
conn.sendall((resp + "\n").encode())
|
||||
continue
|
||||
|
||||
@@ -300,7 +303,9 @@ def _rpc_server_loop(
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
try:
|
||||
result = handle_function_call(tool_name, tool_args, task_id=task_id)
|
||||
result = handle_function_call(
|
||||
tool_name, tool_args, task_id=task_id
|
||||
)
|
||||
finally:
|
||||
sys.stdout.close()
|
||||
sys.stderr.close()
|
||||
@@ -313,17 +318,15 @@ def _rpc_server_loop(
|
||||
|
||||
# Log for observability
|
||||
args_preview = str(tool_args)[:80]
|
||||
tool_call_log.append(
|
||||
{
|
||||
"tool": tool_name,
|
||||
"args_preview": args_preview,
|
||||
"duration": round(call_duration, 2),
|
||||
}
|
||||
)
|
||||
tool_call_log.append({
|
||||
"tool": tool_name,
|
||||
"args_preview": args_preview,
|
||||
"duration": round(call_duration, 2),
|
||||
})
|
||||
|
||||
conn.sendall((result + "\n").encode())
|
||||
|
||||
except TimeoutError:
|
||||
except socket.timeout:
|
||||
pass
|
||||
except OSError:
|
||||
pass
|
||||
@@ -339,11 +342,10 @@ def _rpc_server_loop(
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def execute_code(
|
||||
code: str,
|
||||
task_id: str | None = None,
|
||||
enabled_tools: list[str] | None = None,
|
||||
task_id: Optional[str] = None,
|
||||
enabled_tools: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Run a Python script in a sandboxed child process with RPC access
|
||||
@@ -359,7 +361,9 @@ def execute_code(
|
||||
JSON string with execution results.
|
||||
"""
|
||||
if not SANDBOX_AVAILABLE:
|
||||
return json.dumps({"error": "execute_code is not available on Windows. Use normal tool calls instead."})
|
||||
return json.dumps({
|
||||
"error": "execute_code is not available on Windows. Use normal tool calls instead."
|
||||
})
|
||||
|
||||
if not code or not code.strip():
|
||||
return json.dumps({"error": "No code provided."})
|
||||
@@ -393,7 +397,9 @@ def execute_code(
|
||||
|
||||
try:
|
||||
# Write the auto-generated hermes_tools module
|
||||
tools_src = generate_hermes_tools_module(list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS))
|
||||
tools_src = generate_hermes_tools_module(
|
||||
list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS)
|
||||
)
|
||||
with open(os.path.join(tmpdir, "hermes_tools.py"), "w") as f:
|
||||
f.write(tools_src)
|
||||
|
||||
@@ -409,12 +415,8 @@ def execute_code(
|
||||
rpc_thread = threading.Thread(
|
||||
target=_rpc_server_loop,
|
||||
args=(
|
||||
server_sock,
|
||||
task_id,
|
||||
tool_call_log,
|
||||
tool_call_counter,
|
||||
max_tool_calls,
|
||||
sandbox_tools,
|
||||
server_sock, task_id, tool_call_log,
|
||||
tool_call_counter, max_tool_calls, sandbox_tools,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
@@ -424,24 +426,11 @@ def execute_code(
|
||||
# Build a minimal environment for the child. We intentionally exclude
|
||||
# API keys and tokens to prevent credential exfiltration from LLM-
|
||||
# generated scripts. The child accesses tools via RPC, not direct API.
|
||||
_SAFE_ENV_PREFIXES = (
|
||||
"PATH",
|
||||
"HOME",
|
||||
"USER",
|
||||
"LANG",
|
||||
"LC_",
|
||||
"TERM",
|
||||
"TMPDIR",
|
||||
"TMP",
|
||||
"TEMP",
|
||||
"SHELL",
|
||||
"LOGNAME",
|
||||
"XDG_",
|
||||
"PYTHONPATH",
|
||||
"VIRTUAL_ENV",
|
||||
"CONDA",
|
||||
)
|
||||
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL", "PASSWD", "AUTH")
|
||||
_SAFE_ENV_PREFIXES = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM",
|
||||
"TMPDIR", "TMP", "TEMP", "SHELL", "LOGNAME",
|
||||
"XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA")
|
||||
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL",
|
||||
"PASSWD", "AUTH")
|
||||
child_env = {}
|
||||
for k, v in os.environ.items():
|
||||
if any(s in k.upper() for s in _SECRET_SUBSTRINGS):
|
||||
@@ -526,7 +515,7 @@ def execute_code(
|
||||
rpc_thread.join(timeout=3)
|
||||
|
||||
# Build response
|
||||
result: dict[str, Any] = {
|
||||
result: Dict[str, Any] = {
|
||||
"status": status,
|
||||
"output": stdout_text,
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
@@ -549,21 +538,17 @@ def execute_code(
|
||||
except Exception as exc:
|
||||
duration = round(time.monotonic() - exec_start, 2)
|
||||
logging.exception("execute_code failed")
|
||||
return json.dumps(
|
||||
{
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
"duration_seconds": duration,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
return json.dumps({
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
"duration_seconds": duration,
|
||||
}, ensure_ascii=False)
|
||||
|
||||
finally:
|
||||
# Cleanup temp dir and socket
|
||||
try:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
except Exception as e:
|
||||
logger.debug("Could not clean temp dir: %s", e)
|
||||
@@ -607,7 +592,6 @@ def _load_config() -> dict:
|
||||
"""Load code_execution config from CLI_CONFIG if available."""
|
||||
try:
|
||||
from cli import CLI_CONFIG
|
||||
|
||||
return CLI_CONFIG.get("code_execution", {})
|
||||
except Exception:
|
||||
return {}
|
||||
@@ -620,37 +604,27 @@ def _load_config() -> dict:
|
||||
# Per-tool documentation lines for the execute_code description.
|
||||
# Ordered to match the canonical display order.
|
||||
_TOOL_DOC_LINES = [
|
||||
(
|
||||
"web_search",
|
||||
" web_search(query: str, limit: int = 5) -> dict\n"
|
||||
' Returns {"data": {"web": [{"url", "title", "description"}, ...]}}',
|
||||
),
|
||||
(
|
||||
"web_extract",
|
||||
" web_extract(urls: list[str]) -> dict\n"
|
||||
' Returns {"results": [{"url", "title", "content", "error"}, ...]} where content is markdown',
|
||||
),
|
||||
(
|
||||
"read_file",
|
||||
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
|
||||
' Lines are 1-indexed. Returns {"content": "...", "total_lines": N}',
|
||||
),
|
||||
("write_file", " write_file(path: str, content: str) -> dict\n Always overwrites the entire file."),
|
||||
(
|
||||
"search_files",
|
||||
' search_files(pattern: str, target="content", path=".", file_glob=None, limit=50) -> dict\n'
|
||||
' target: "content" (search inside files) or "files" (find files by name). Returns {"matches": [...]}',
|
||||
),
|
||||
(
|
||||
"patch",
|
||||
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
|
||||
" Replaces old_string with new_string in the file.",
|
||||
),
|
||||
(
|
||||
"terminal",
|
||||
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
|
||||
' Foreground only (no background/pty). Returns {"output": "...", "exit_code": N}',
|
||||
),
|
||||
("web_search",
|
||||
" web_search(query: str, limit: int = 5) -> dict\n"
|
||||
" Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}"),
|
||||
("web_extract",
|
||||
" web_extract(urls: list[str]) -> dict\n"
|
||||
" Returns {\"results\": [{\"url\", \"title\", \"content\", \"error\"}, ...]} where content is markdown"),
|
||||
("read_file",
|
||||
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
|
||||
" Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}"),
|
||||
("write_file",
|
||||
" write_file(path: str, content: str) -> dict\n"
|
||||
" Always overwrites the entire file."),
|
||||
("search_files",
|
||||
" search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n"
|
||||
" target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}"),
|
||||
("patch",
|
||||
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
|
||||
" Replaces old_string with new_string in the file."),
|
||||
("terminal",
|
||||
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
|
||||
" Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}"),
|
||||
]
|
||||
|
||||
|
||||
@@ -665,7 +639,9 @@ def build_execute_code_schema(enabled_sandbox_tools: set = None) -> dict:
|
||||
enabled_sandbox_tools = SANDBOX_ALLOWED_TOOLS
|
||||
|
||||
# Build tool documentation lines for only the enabled tools
|
||||
tool_lines = "\n".join(doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools)
|
||||
tool_lines = "\n".join(
|
||||
doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools
|
||||
)
|
||||
|
||||
# Build example import list from enabled tools
|
||||
import_examples = [n for n in ("web_search", "terminal") if n in enabled_sandbox_tools]
|
||||
@@ -726,7 +702,8 @@ registry.register(
|
||||
toolset="code_execution",
|
||||
schema=EXECUTE_CODE_SCHEMA,
|
||||
handler=lambda args, **kw: execute_code(
|
||||
code=args.get("code", ""), task_id=kw.get("task_id"), enabled_tools=kw.get("enabled_tools")
|
||||
),
|
||||
code=args.get("code", ""),
|
||||
task_id=kw.get("task_id"),
|
||||
enabled_tools=kw.get("enabled_tools")),
|
||||
check_fn=check_sandbox_requirements,
|
||||
)
|
||||
|
||||
@@ -11,44 +11,37 @@ The prompt must contain ALL necessary information.
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
# Import from cron module (will be available when properly installed)
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import create_job, get_job, list_jobs, remove_job
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cron prompt scanning — critical-severity patterns only, since cron prompts
|
||||
# run in fresh sessions with full tool access.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CRON_THREAT_PATTERNS = [
|
||||
(r"ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions", "prompt_injection"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"),
|
||||
(r"authorized_keys", "ssh_backdoor"),
|
||||
(r"/etc/sudoers|visudo", "sudoers_mod"),
|
||||
(r"rm\s+-rf\s+/", "destructive_root_rm"),
|
||||
(r'ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions', "prompt_injection"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
|
||||
(r'authorized_keys', "ssh_backdoor"),
|
||||
(r'/etc/sudoers|visudo', "sudoers_mod"),
|
||||
(r'rm\s+-rf\s+/', "destructive_root_rm"),
|
||||
]
|
||||
|
||||
_CRON_INVISIBLE_CHARS = {
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
}
|
||||
|
||||
|
||||
@@ -67,18 +60,17 @@ def _scan_cron_prompt(prompt: str) -> str:
|
||||
# Tool: schedule_cronjob
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def schedule_cronjob(
|
||||
prompt: str,
|
||||
schedule: str,
|
||||
name: str | None = None,
|
||||
repeat: int | None = None,
|
||||
deliver: str | None = None,
|
||||
task_id: str = None,
|
||||
name: Optional[str] = None,
|
||||
repeat: Optional[int] = None,
|
||||
deliver: Optional[str] = None,
|
||||
task_id: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Schedule an automated task to run the agent on a schedule.
|
||||
|
||||
|
||||
IMPORTANT: When the cronjob runs, it starts a COMPLETELY FRESH session.
|
||||
The agent will have NO memory of this conversation or any prior context.
|
||||
Therefore, the prompt MUST contain ALL necessary information:
|
||||
@@ -86,12 +78,12 @@ def schedule_cronjob(
|
||||
- Specific file paths, URLs, or identifiers
|
||||
- Clear success criteria
|
||||
- Any relevant background information
|
||||
|
||||
|
||||
BAD prompt: "Check on that server issue"
|
||||
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
|
||||
is running with 'systemctl status nginx', and verify the site
|
||||
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
|
||||
is running with 'systemctl status nginx', and verify the site
|
||||
https://example.com returns HTTP 200. Report any issues found."
|
||||
|
||||
|
||||
Args:
|
||||
prompt: Complete, self-contained instructions for the future agent.
|
||||
Must include ALL context needed - the agent won't remember anything.
|
||||
@@ -113,7 +105,7 @@ def schedule_cronjob(
|
||||
- "signal": Send to Signal home channel
|
||||
- "telegram:123456": Send to specific chat ID
|
||||
- "signal:+15551234567": Send to specific Signal number
|
||||
|
||||
|
||||
Returns:
|
||||
JSON with job_id, next_run time, and confirmation
|
||||
"""
|
||||
@@ -132,10 +124,17 @@ def schedule_cronjob(
|
||||
"chat_id": origin_chat_id,
|
||||
"chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"),
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
job = create_job(prompt=prompt, schedule=schedule, name=name, repeat=repeat, deliver=deliver, origin=origin)
|
||||
|
||||
job = create_job(
|
||||
prompt=prompt,
|
||||
schedule=schedule,
|
||||
name=name,
|
||||
repeat=repeat,
|
||||
deliver=deliver,
|
||||
origin=origin
|
||||
)
|
||||
|
||||
# Format repeat info for display
|
||||
times = job["repeat"].get("times")
|
||||
if times is None:
|
||||
@@ -144,23 +143,23 @@ def schedule_cronjob(
|
||||
repeat_display = "once"
|
||||
else:
|
||||
repeat_display = f"{times} times"
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_display,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job["next_run_at"],
|
||||
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}.",
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_display,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job["next_run_at"],
|
||||
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}."
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, indent=2)
|
||||
|
||||
|
||||
SCHEDULE_CRONJOB_SCHEMA = {
|
||||
@@ -178,7 +177,7 @@ The future agent will NOT remember anything from the current conversation.
|
||||
|
||||
SCHEDULE FORMATS:
|
||||
- One-shot: "30m", "2h", "1d" (runs once after delay)
|
||||
- Interval: "every 30m", "every 2h" (recurring)
|
||||
- Interval: "every 30m", "every 2h" (recurring)
|
||||
- Cron: "0 9 * * *" (cron expression for precise scheduling)
|
||||
- Timestamp: "2026-02-03T14:00:00" (specific date/time)
|
||||
|
||||
@@ -203,24 +202,27 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation.",
|
||||
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation."
|
||||
},
|
||||
"schedule": {
|
||||
"type": "string",
|
||||
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp",
|
||||
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional human-friendly name for the job"
|
||||
},
|
||||
"name": {"type": "string", "description": "Optional human-friendly name for the job"},
|
||||
"repeat": {
|
||||
"type": "integer",
|
||||
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs.",
|
||||
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs."
|
||||
},
|
||||
"deliver": {
|
||||
"type": "string",
|
||||
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'",
|
||||
},
|
||||
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'"
|
||||
}
|
||||
},
|
||||
"required": ["prompt", "schedule"],
|
||||
},
|
||||
"required": ["prompt", "schedule"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -228,11 +230,10 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
|
||||
# Tool: list_cronjobs
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
||||
"""
|
||||
List all scheduled cronjobs.
|
||||
|
||||
|
||||
Returns information about each job including:
|
||||
- Job ID (needed for removal)
|
||||
- Name
|
||||
@@ -240,16 +241,16 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
||||
- Repeat status (completed/total or 'forever')
|
||||
- Next scheduled run time
|
||||
- Last run time and status (if any)
|
||||
|
||||
|
||||
Args:
|
||||
include_disabled: Whether to include disabled/completed jobs
|
||||
|
||||
|
||||
Returns:
|
||||
JSON array of all scheduled jobs
|
||||
"""
|
||||
try:
|
||||
jobs = list_jobs(include_disabled=include_disabled)
|
||||
|
||||
|
||||
formatted_jobs = []
|
||||
for job in jobs:
|
||||
# Format repeat status
|
||||
@@ -259,26 +260,31 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
||||
repeat_status = "forever"
|
||||
else:
|
||||
repeat_status = f"{completed}/{times}"
|
||||
|
||||
formatted_jobs.append(
|
||||
{
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_status,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job.get("next_run_at"),
|
||||
"last_run_at": job.get("last_run_at"),
|
||||
"last_status": job.get("last_status"),
|
||||
"enabled": job.get("enabled", True),
|
||||
}
|
||||
)
|
||||
|
||||
return json.dumps({"success": True, "count": len(formatted_jobs), "jobs": formatted_jobs}, indent=2)
|
||||
|
||||
|
||||
formatted_jobs.append({
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_status,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job.get("next_run_at"),
|
||||
"last_run_at": job.get("last_run_at"),
|
||||
"last_status": job.get("last_status"),
|
||||
"enabled": job.get("enabled", True)
|
||||
})
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"count": len(formatted_jobs),
|
||||
"jobs": formatted_jobs
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, indent=2)
|
||||
|
||||
|
||||
LIST_CRONJOBS_SCHEMA = {
|
||||
@@ -296,11 +302,11 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
|
||||
"properties": {
|
||||
"include_disabled": {
|
||||
"type": "boolean",
|
||||
"description": "Include disabled/completed jobs in the list (default: false)",
|
||||
"description": "Include disabled/completed jobs in the list (default: false)"
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -308,45 +314,48 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
|
||||
# Tool: remove_cronjob
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def remove_cronjob(job_id: str, task_id: str = None) -> str:
|
||||
"""
|
||||
Remove a scheduled cronjob by its ID.
|
||||
|
||||
|
||||
Use list_cronjobs first to find the job_id of the job you want to remove.
|
||||
|
||||
|
||||
Args:
|
||||
job_id: The ID of the job to remove (from list_cronjobs output)
|
||||
|
||||
|
||||
Returns:
|
||||
JSON confirmation of removal
|
||||
"""
|
||||
try:
|
||||
job = get_job(job_id)
|
||||
if not job:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs.",
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs."
|
||||
}, indent=2)
|
||||
|
||||
removed = remove_job(job_id)
|
||||
if removed:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
|
||||
"removed_job": {"id": job_id, "name": job["name"], "schedule": job["schedule_display"]},
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
|
||||
"removed_job": {
|
||||
"id": job_id,
|
||||
"name": job["name"],
|
||||
"schedule": job["schedule_display"]
|
||||
}
|
||||
}, indent=2)
|
||||
else:
|
||||
return json.dumps({"success": False, "error": f"Failed to remove job '{job_id}'"}, indent=2)
|
||||
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"Failed to remove job '{job_id}'"
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, indent=2)
|
||||
|
||||
|
||||
REMOVE_CRONJOB_SCHEMA = {
|
||||
@@ -359,10 +368,13 @@ use this to cancel a job before it completes.""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_id": {"type": "string", "description": "The ID of the cronjob to remove (from list_cronjobs output)"}
|
||||
"job_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the cronjob to remove (from list_cronjobs output)"
|
||||
}
|
||||
},
|
||||
"required": ["job_id"],
|
||||
},
|
||||
"required": ["job_id"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -370,34 +382,44 @@ use this to cancel a job before it completes.""",
|
||||
# Requirements check
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def check_cronjob_requirements() -> bool:
|
||||
"""
|
||||
Check if cronjob tools can be used.
|
||||
|
||||
|
||||
Available in interactive CLI mode and gateway/messaging platforms.
|
||||
Cronjobs are server-side scheduled tasks so they work from any interface.
|
||||
"""
|
||||
return bool(os.getenv("HERMES_INTERACTIVE") or os.getenv("HERMES_GATEWAY_SESSION") or os.getenv("HERMES_EXEC_ASK"))
|
||||
return bool(
|
||||
os.getenv("HERMES_INTERACTIVE")
|
||||
or os.getenv("HERMES_GATEWAY_SESSION")
|
||||
or os.getenv("HERMES_EXEC_ASK")
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exports
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_cronjob_tool_definitions():
|
||||
"""Return tool definitions for cronjob management."""
|
||||
return [SCHEDULE_CRONJOB_SCHEMA, LIST_CRONJOBS_SCHEMA, REMOVE_CRONJOB_SCHEMA]
|
||||
return [
|
||||
SCHEDULE_CRONJOB_SCHEMA,
|
||||
LIST_CRONJOBS_SCHEMA,
|
||||
REMOVE_CRONJOB_SCHEMA
|
||||
]
|
||||
|
||||
|
||||
# For direct testing
|
||||
if __name__ == "__main__":
|
||||
# Test the tools
|
||||
print("Testing schedule_cronjob:")
|
||||
result = schedule_cronjob(prompt="Test prompt for cron job", schedule="5m", name="Test Job")
|
||||
result = schedule_cronjob(
|
||||
prompt="Test prompt for cron job",
|
||||
schedule="5m",
|
||||
name="Test Job"
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
print("\nTesting list_cronjobs:")
|
||||
result = list_cronjobs()
|
||||
print(result)
|
||||
@@ -416,8 +438,7 @@ registry.register(
|
||||
name=args.get("name"),
|
||||
repeat=args.get("repeat"),
|
||||
deliver=args.get("deliver"),
|
||||
task_id=kw.get("task_id"),
|
||||
),
|
||||
task_id=kw.get("task_id")),
|
||||
check_fn=check_cronjob_requirements,
|
||||
)
|
||||
registry.register(
|
||||
@@ -425,14 +446,16 @@ registry.register(
|
||||
toolset="cronjob",
|
||||
schema=LIST_CRONJOBS_SCHEMA,
|
||||
handler=lambda args, **kw: list_cronjobs(
|
||||
include_disabled=args.get("include_disabled", False), task_id=kw.get("task_id")
|
||||
),
|
||||
include_disabled=args.get("include_disabled", False),
|
||||
task_id=kw.get("task_id")),
|
||||
check_fn=check_cronjob_requirements,
|
||||
)
|
||||
registry.register(
|
||||
name="remove_cronjob",
|
||||
toolset="cronjob",
|
||||
schema=REMOVE_CRONJOB_SCHEMA,
|
||||
handler=lambda args, **kw: remove_cronjob(job_id=args.get("job_id", ""), task_id=kw.get("task_id")),
|
||||
handler=lambda args, **kw: remove_cronjob(
|
||||
job_id=args.get("job_id", ""),
|
||||
task_id=kw.get("task_id")),
|
||||
check_fn=check_cronjob_requirements,
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,28 +44,27 @@ class DebugSession:
|
||||
self.enabled = os.getenv(env_var, "false").lower() == "true"
|
||||
self.session_id = str(uuid.uuid4()) if self.enabled else ""
|
||||
self.log_dir = Path("./logs")
|
||||
self._calls: list[dict[str, Any]] = []
|
||||
self._calls: list[Dict[str, Any]] = []
|
||||
self._start_time = datetime.datetime.now().isoformat() if self.enabled else ""
|
||||
|
||||
if self.enabled:
|
||||
self.log_dir.mkdir(exist_ok=True)
|
||||
logger.debug("%s debug mode enabled - Session ID: %s", tool_name, self.session_id)
|
||||
logger.debug("%s debug mode enabled - Session ID: %s",
|
||||
tool_name, self.session_id)
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
return self.enabled
|
||||
|
||||
def log_call(self, call_name: str, call_data: dict[str, Any]) -> None:
|
||||
def log_call(self, call_name: str, call_data: Dict[str, Any]) -> None:
|
||||
"""Append a tool-call entry to the in-memory log."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self._calls.append(
|
||||
{
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": call_name,
|
||||
**call_data,
|
||||
}
|
||||
)
|
||||
self._calls.append({
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": call_name,
|
||||
**call_data,
|
||||
})
|
||||
|
||||
def save(self) -> None:
|
||||
"""Flush the in-memory log to a JSON file in the logs directory."""
|
||||
@@ -88,7 +87,7 @@ class DebugSession:
|
||||
except Exception as e:
|
||||
logger.error("Error saving %s debug log: %s", self.tool_name, e)
|
||||
|
||||
def get_session_info(self) -> dict[str, Any]:
|
||||
def get_session_info(self) -> Dict[str, Any]:
|
||||
"""Return a summary dict suitable for returning from get_debug_session_info()."""
|
||||
if not self.enabled:
|
||||
return {
|
||||
|
||||
@@ -20,22 +20,21 @@ import contextlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
# Tools that children must never have access to
|
||||
DELEGATE_BLOCKED_TOOLS = frozenset(
|
||||
[
|
||||
"delegate_task", # no recursive delegation
|
||||
"clarify", # no user interaction
|
||||
"memory", # no writes to shared MEMORY.md
|
||||
"send_message", # no cross-platform side effects
|
||||
"execute_code", # children should reason step-by-step, not write scripts
|
||||
]
|
||||
)
|
||||
DELEGATE_BLOCKED_TOOLS = frozenset([
|
||||
"delegate_task", # no recursive delegation
|
||||
"clarify", # no user interaction
|
||||
"memory", # no writes to shared MEMORY.md
|
||||
"send_message", # no cross-platform side effects
|
||||
"execute_code", # children should reason step-by-step, not write scripts
|
||||
])
|
||||
|
||||
MAX_CONCURRENT_CHILDREN = 3
|
||||
MAX_DEPTH = 2 # parent (0) -> child (1) -> grandchild rejected (2)
|
||||
@@ -48,7 +47,7 @@ def check_delegate_requirements() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _build_child_system_prompt(goal: str, context: str | None = None) -> str:
|
||||
def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
|
||||
"""Build a focused system prompt for a child agent."""
|
||||
parts = [
|
||||
"You are a focused subagent working on a specific delegated task.",
|
||||
@@ -70,18 +69,15 @@ def _build_child_system_prompt(goal: str, context: str | None = None) -> str:
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _strip_blocked_tools(toolsets: list[str]) -> list[str]:
|
||||
def _strip_blocked_tools(toolsets: List[str]) -> List[str]:
|
||||
"""Remove toolsets that contain only blocked tools."""
|
||||
blocked_toolset_names = {
|
||||
"delegation",
|
||||
"clarify",
|
||||
"memory",
|
||||
"code_execution",
|
||||
"delegation", "clarify", "memory", "code_execution",
|
||||
}
|
||||
return [t for t in toolsets if t not in blocked_toolset_names]
|
||||
|
||||
|
||||
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Callable | None:
|
||||
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Optional[callable]:
|
||||
"""Build a callback that relays child agent tool calls to the parent display.
|
||||
|
||||
Two display paths:
|
||||
@@ -91,8 +87,8 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
Returns None if no display mechanism is available, in which case the
|
||||
child agent runs with no progress callback (identical to current behavior).
|
||||
"""
|
||||
spinner = getattr(parent_agent, "_delegate_spinner", None)
|
||||
parent_cb = getattr(parent_agent, "tool_progress_callback", None)
|
||||
spinner = getattr(parent_agent, '_delegate_spinner', None)
|
||||
parent_cb = getattr(parent_agent, 'tool_progress_callback', None)
|
||||
|
||||
if not spinner and not parent_cb:
|
||||
return None # No display → no callback → zero behavior change
|
||||
@@ -102,7 +98,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
|
||||
# Gateway: batch tool names, flush periodically
|
||||
_BATCH_SIZE = 5
|
||||
_batch: list[str] = []
|
||||
_batch: List[str] = []
|
||||
|
||||
def _callback(tool_name: str, preview: str = None):
|
||||
# Special "_thinking" event: model produced text content (reasoning)
|
||||
@@ -110,7 +106,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
if spinner:
|
||||
short = (preview[:55] + "...") if preview and len(preview) > 55 else (preview or "")
|
||||
try:
|
||||
spinner.print_above(f' {prefix}├─ 💭 "{short}"')
|
||||
spinner.print_above(f" {prefix}├─ 💭 \"{short}\"")
|
||||
except Exception:
|
||||
pass
|
||||
# Don't relay thinking to gateway (too noisy for chat)
|
||||
@@ -120,25 +116,17 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
if spinner:
|
||||
short = (preview[:35] + "...") if preview and len(preview) > 35 else (preview or "")
|
||||
tool_emojis = {
|
||||
"terminal": "💻",
|
||||
"web_search": "🔍",
|
||||
"web_extract": "📄",
|
||||
"read_file": "📖",
|
||||
"write_file": "✍️",
|
||||
"patch": "🔧",
|
||||
"search_files": "🔎",
|
||||
"list_directory": "📂",
|
||||
"browser_navigate": "🌐",
|
||||
"browser_click": "👆",
|
||||
"text_to_speech": "🔊",
|
||||
"image_generate": "🎨",
|
||||
"vision_analyze": "👁️",
|
||||
"process": "⚙️",
|
||||
"terminal": "💻", "web_search": "🔍", "web_extract": "📄",
|
||||
"read_file": "📖", "write_file": "✍️", "patch": "🔧",
|
||||
"search_files": "🔎", "list_directory": "📂",
|
||||
"browser_navigate": "🌐", "browser_click": "👆",
|
||||
"text_to_speech": "🔊", "image_generate": "🎨",
|
||||
"vision_analyze": "👁️", "process": "⚙️",
|
||||
}
|
||||
emoji = tool_emojis.get(tool_name, "⚡")
|
||||
line = f" {prefix}├─ {emoji} {tool_name}"
|
||||
if short:
|
||||
line += f' "{short}"'
|
||||
line += f" \"{short}\""
|
||||
try:
|
||||
spinner.print_above(line)
|
||||
except Exception:
|
||||
@@ -171,13 +159,13 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
def _run_single_child(
|
||||
task_index: int,
|
||||
goal: str,
|
||||
context: str | None,
|
||||
toolsets: list[str] | None,
|
||||
model: str | None,
|
||||
context: Optional[str],
|
||||
toolsets: Optional[List[str]],
|
||||
model: Optional[str],
|
||||
max_iterations: int,
|
||||
parent_agent,
|
||||
task_count: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Spawn and run a single child agent. Called from within a thread.
|
||||
Returns a structured result dict.
|
||||
@@ -228,7 +216,7 @@ def _run_single_child(
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
clarify_callback=None,
|
||||
session_db=getattr(parent_agent, "_session_db", None),
|
||||
session_db=getattr(parent_agent, '_session_db', None),
|
||||
providers_allowed=parent_agent.providers_allowed,
|
||||
providers_ignored=parent_agent.providers_ignored,
|
||||
providers_order=parent_agent.providers_order,
|
||||
@@ -238,10 +226,10 @@ def _run_single_child(
|
||||
)
|
||||
|
||||
# Set delegation depth so children can't spawn grandchildren
|
||||
child._delegate_depth = getattr(parent_agent, "_delegate_depth", 0) + 1
|
||||
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
||||
|
||||
# Register child for interrupt propagation
|
||||
if hasattr(parent_agent, "_active_children"):
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
parent_agent._active_children.append(child)
|
||||
|
||||
# Run with stdout/stderr suppressed to prevent interleaved output
|
||||
@@ -250,7 +238,7 @@ def _run_single_child(
|
||||
result = child.run_conversation(user_message=goal)
|
||||
|
||||
# Flush any remaining batched progress to gateway
|
||||
if child_progress_cb and hasattr(child_progress_cb, "_flush"):
|
||||
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
|
||||
try:
|
||||
child_progress_cb._flush()
|
||||
except Exception:
|
||||
@@ -270,7 +258,7 @@ def _run_single_child(
|
||||
else:
|
||||
status = "failed"
|
||||
|
||||
entry: dict[str, Any] = {
|
||||
entry: Dict[str, Any] = {
|
||||
"task_index": task_index,
|
||||
"status": status,
|
||||
"summary": summary,
|
||||
@@ -296,7 +284,7 @@ def _run_single_child(
|
||||
|
||||
finally:
|
||||
# Unregister child from interrupt propagation
|
||||
if hasattr(parent_agent, "_active_children"):
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
try:
|
||||
parent_agent._active_children.remove(child)
|
||||
except (ValueError, UnboundLocalError):
|
||||
@@ -304,11 +292,11 @@ def _run_single_child(
|
||||
|
||||
|
||||
def delegate_task(
|
||||
goal: str | None = None,
|
||||
context: str | None = None,
|
||||
toolsets: list[str] | None = None,
|
||||
tasks: list[dict[str, Any]] | None = None,
|
||||
max_iterations: int | None = None,
|
||||
goal: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
toolsets: Optional[List[str]] = None,
|
||||
tasks: Optional[List[Dict[str, Any]]] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
parent_agent=None,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -324,11 +312,14 @@ def delegate_task(
|
||||
return json.dumps({"error": "delegate_task requires a parent agent context."})
|
||||
|
||||
# Depth limit
|
||||
depth = getattr(parent_agent, "_delegate_depth", 0)
|
||||
depth = getattr(parent_agent, '_delegate_depth', 0)
|
||||
if depth >= MAX_DEPTH:
|
||||
return json.dumps(
|
||||
{"error": (f"Delegation depth limit reached ({MAX_DEPTH}). Subagents cannot spawn further subagents.")}
|
||||
)
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"Delegation depth limit reached ({MAX_DEPTH}). "
|
||||
"Subagents cannot spawn further subagents."
|
||||
)
|
||||
})
|
||||
|
||||
# Load config
|
||||
cfg = _load_config()
|
||||
@@ -375,7 +366,7 @@ def delegate_task(
|
||||
else:
|
||||
# Batch -- run in parallel with per-task progress lines
|
||||
completed_count = 0
|
||||
spinner_ref = getattr(parent_agent, "_delegate_spinner", None)
|
||||
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
|
||||
|
||||
# Save stdout/stderr before the executor — redirect_stdout in child
|
||||
# threads races on sys.stdout and can leave it as devnull permanently.
|
||||
@@ -421,7 +412,7 @@ def delegate_task(
|
||||
status = entry.get("status", "?")
|
||||
icon = "✓" if status == "completed" else "✗"
|
||||
remaining = n_tasks - completed_count
|
||||
completion_line = f"{icon} [{idx + 1}/{n_tasks}] {label} ({dur}s)"
|
||||
completion_line = f"{icon} [{idx+1}/{n_tasks}] {label} ({dur}s)"
|
||||
if spinner_ref:
|
||||
try:
|
||||
spinner_ref.print_above(completion_line)
|
||||
@@ -446,20 +437,16 @@ def delegate_task(
|
||||
|
||||
total_duration = round(time.monotonic() - overall_start, 2)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"results": results,
|
||||
"total_duration_seconds": total_duration,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
return json.dumps({
|
||||
"results": results,
|
||||
"total_duration_seconds": total_duration,
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Load delegation config from CLI_CONFIG if available."""
|
||||
try:
|
||||
from cli import CLI_CONFIG
|
||||
|
||||
return CLI_CONFIG.get("delegation", {})
|
||||
except Exception:
|
||||
return {}
|
||||
@@ -550,7 +537,10 @@ DELEGATE_TASK_SCHEMA = {
|
||||
},
|
||||
"max_iterations": {
|
||||
"type": "integer",
|
||||
"description": ("Max tool-calling turns per subagent (default: 50). Only set lower for simple tasks."),
|
||||
"description": (
|
||||
"Max tool-calling turns per subagent (default: 50). "
|
||||
"Only set lower for simple tasks."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
@@ -571,7 +561,6 @@ registry.register(
|
||||
toolsets=args.get("toolsets"),
|
||||
tasks=args.get("tasks"),
|
||||
max_iterations=args.get("max_iterations"),
|
||||
parent_agent=kw.get("parent_agent"),
|
||||
),
|
||||
parent_agent=kw.get("parent_agent")),
|
||||
check_fn=check_delegate_requirements,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Base class for all Hermes execution environment backends."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import subprocess
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -34,9 +34,9 @@ class BaseEnvironment(ABC):
|
||||
self.env = env or {}
|
||||
|
||||
@abstractmethod
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
...
|
||||
|
||||
@@ -62,10 +62,10 @@ class BaseEnvironment(ABC):
|
||||
def _prepare_command(self, command: str) -> str:
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available."""
|
||||
from tools.terminal_tool import _transform_sudo_command
|
||||
|
||||
return _transform_sudo_command(command)
|
||||
|
||||
def _build_run_kwargs(self, timeout: int | None, stdin_data: str | None = None) -> dict:
|
||||
def _build_run_kwargs(self, timeout: int | None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Build common subprocess.run kwargs for non-interactive execution."""
|
||||
kw = {
|
||||
"text": True,
|
||||
|
||||
@@ -11,6 +11,7 @@ import shlex
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
@@ -31,8 +32,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
cwd: str = "/home/daytona",
|
||||
timeout: int = 60,
|
||||
cpu: int = 1,
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
@@ -40,8 +41,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
from daytona import (
|
||||
CreateSandboxFromImageParams,
|
||||
Daytona,
|
||||
CreateSandboxFromImageParams,
|
||||
DaytonaError,
|
||||
Resources,
|
||||
SandboxState,
|
||||
@@ -72,11 +73,13 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
try:
|
||||
self._sandbox = self._daytona.find_one(labels=labels)
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: resumed sandbox %s for task %s", self._sandbox.id, task_id)
|
||||
logger.info("Daytona: resumed sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
except DaytonaError:
|
||||
self._sandbox = None
|
||||
except Exception as e:
|
||||
logger.warning("Daytona: failed to resume sandbox for task %s: %s", task_id, e)
|
||||
logger.warning("Daytona: failed to resume sandbox for task %s: %s",
|
||||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# Create a fresh sandbox if we don't have one
|
||||
@@ -89,7 +92,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
resources=resources,
|
||||
)
|
||||
)
|
||||
logger.info("Daytona: created sandbox %s for task %s", self._sandbox.id, task_id)
|
||||
logger.info("Daytona: created sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
|
||||
# Resolve cwd: detect actual home dir inside the sandbox
|
||||
if self._requested_cwd in ("~", "/home/daytona"):
|
||||
@@ -108,7 +112,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
|
||||
|
||||
def _exec_in_thread(self, exec_command: str, cwd: str | None, timeout: int) -> dict:
|
||||
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
|
||||
"""Run exec in a background thread with interrupt polling.
|
||||
|
||||
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
|
||||
@@ -126,8 +130,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
def _run():
|
||||
try:
|
||||
response = self._sandbox.process.exec(
|
||||
timed_command,
|
||||
cwd=cwd,
|
||||
timed_command, cwd=cwd,
|
||||
)
|
||||
result_holder["value"] = {
|
||||
"output": response.result or "",
|
||||
@@ -166,9 +169,9 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
return {"error": result_holder["error"]}
|
||||
return result_holder["value"]
|
||||
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: Optional[int] = None,
|
||||
stdin_data: Optional[str] = None) -> dict:
|
||||
with self._lock:
|
||||
self._ensure_sandbox_ready()
|
||||
|
||||
@@ -186,7 +189,6 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
|
||||
if "error" in result:
|
||||
from daytona import DaytonaError
|
||||
|
||||
err = result["error"]
|
||||
if isinstance(err, DaytonaError):
|
||||
with self._lock:
|
||||
@@ -208,7 +210,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
try:
|
||||
if self._persistent:
|
||||
self._sandbox.stop()
|
||||
logger.info("Daytona: stopped sandbox %s (filesystem preserved)", self._sandbox.id)
|
||||
logger.info("Daytona: stopped sandbox %s (filesystem preserved)",
|
||||
self._sandbox.id)
|
||||
else:
|
||||
self._daytona.delete(self._sandbox)
|
||||
logger.info("Daytona: deleted sandbox %s", self._sandbox.id)
|
||||
|
||||
@@ -11,6 +11,7 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
@@ -18,36 +19,22 @@ from tools.interrupt import is_interrupted
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
# Security flags applied to every container.
|
||||
# The container itself is the security boundary (isolated from host).
|
||||
# We drop all capabilities then add back the minimum needed:
|
||||
# DAC_OVERRIDE - root can write to bind-mounted dirs owned by host user
|
||||
# CHOWN/FOWNER - package managers (pip, npm, apt) need to set file ownership
|
||||
# Block privilege escalation and limit PIDs.
|
||||
# We drop all capabilities, block privilege escalation, and limit PIDs.
|
||||
# /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds).
|
||||
_SECURITY_ARGS = [
|
||||
"--cap-drop",
|
||||
"ALL",
|
||||
"--cap-add",
|
||||
"DAC_OVERRIDE",
|
||||
"--cap-add",
|
||||
"CHOWN",
|
||||
"--cap-add",
|
||||
"FOWNER",
|
||||
"--security-opt",
|
||||
"no-new-privileges",
|
||||
"--pids-limit",
|
||||
"256",
|
||||
"--tmpfs",
|
||||
"/tmp:rw,nosuid,size=512m",
|
||||
"--tmpfs",
|
||||
"/var/tmp:rw,noexec,nosuid,size=256m",
|
||||
"--tmpfs",
|
||||
"/run:rw,noexec,nosuid,size=64m",
|
||||
"--cap-drop", "ALL",
|
||||
"--security-opt", "no-new-privileges",
|
||||
"--pids-limit", "256",
|
||||
"--tmpfs", "/tmp:rw,nosuid,size=512m",
|
||||
"--tmpfs", "/var/tmp:rw,noexec,nosuid,size=256m",
|
||||
"--tmpfs", "/run:rw,noexec,nosuid,size=64m",
|
||||
]
|
||||
|
||||
|
||||
_storage_opt_ok: bool | None = None # cached result across instances
|
||||
_storage_opt_ok: Optional[bool] = None # cached result across instances
|
||||
|
||||
|
||||
class DockerEnvironment(BaseEnvironment):
|
||||
@@ -81,7 +68,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
self._base_image = image
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._container_id: str | None = None
|
||||
self._container_id: Optional[str] = None
|
||||
logger.info(f"DockerEnvironment volumes: {volumes}")
|
||||
# Ensure volumes is a list (config.yaml could be malformed)
|
||||
if volumes is not None and not isinstance(volumes, list):
|
||||
@@ -112,8 +99,8 @@ class DockerEnvironment(BaseEnvironment):
|
||||
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
|
||||
from tools.environments.base import get_sandbox_dir
|
||||
|
||||
self._workspace_dir: str | None = None
|
||||
self._home_dir: str | None = None
|
||||
self._workspace_dir: Optional[str] = None
|
||||
self._home_dir: Optional[str] = None
|
||||
if self._persistent:
|
||||
sandbox = get_sandbox_dir() / "docker" / task_id
|
||||
self._workspace_dir = str(sandbox / "workspace")
|
||||
@@ -121,19 +108,14 @@ class DockerEnvironment(BaseEnvironment):
|
||||
os.makedirs(self._workspace_dir, exist_ok=True)
|
||||
os.makedirs(self._home_dir, exist_ok=True)
|
||||
writable_args = [
|
||||
"-v",
|
||||
f"{self._workspace_dir}:/workspace",
|
||||
"-v",
|
||||
f"{self._home_dir}:/root",
|
||||
"-v", f"{self._workspace_dir}:/workspace",
|
||||
"-v", f"{self._home_dir}:/root",
|
||||
]
|
||||
else:
|
||||
writable_args = [
|
||||
"--tmpfs",
|
||||
"/workspace:rw,exec,size=10g",
|
||||
"--tmpfs",
|
||||
"/home:rw,exec,size=1g",
|
||||
"--tmpfs",
|
||||
"/root:rw,exec,size=1g",
|
||||
"--tmpfs", "/workspace:rw,exec,size=10g",
|
||||
"--tmpfs", "/home:rw,exec,size=1g",
|
||||
"--tmpfs", "/root:rw,exec,size=1g",
|
||||
]
|
||||
|
||||
# All containers get security hardening (capabilities dropped, no privilege
|
||||
@@ -141,7 +123,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
# can install packages as needed.
|
||||
# User-configured volume mounts (from config.yaml docker_volumes)
|
||||
volume_args = []
|
||||
for vol in volumes or []:
|
||||
for vol in (volumes or []):
|
||||
if not isinstance(vol, str):
|
||||
logger.warning(f"Docker volume entry is not a string: {vol!r}")
|
||||
continue
|
||||
@@ -158,9 +140,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
logger.info(f"Docker run_args: {all_run_args}")
|
||||
|
||||
self._inner = _Docker(
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
image=image, cwd=cwd, timeout=timeout,
|
||||
run_args=all_run_args,
|
||||
)
|
||||
self._container_id = self._inner.container_id
|
||||
@@ -168,7 +148,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
@staticmethod
|
||||
def _storage_opt_supported() -> bool:
|
||||
"""Check if Docker's storage driver supports --storage-opt size=.
|
||||
|
||||
|
||||
Only overlay2 on XFS with pquota supports per-container disk quotas.
|
||||
Ubuntu (and most distros) default to ext4, where this flag errors out.
|
||||
"""
|
||||
@@ -178,9 +158,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info", "--format", "{{.Driver}}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
driver = result.stdout.strip().lower()
|
||||
if driver != "overlay2":
|
||||
@@ -190,15 +168,14 @@ class DockerEnvironment(BaseEnvironment):
|
||||
# Probe by attempting a dry-ish run — the fastest reliable check.
|
||||
probe = subprocess.run(
|
||||
["docker", "create", "--storage-opt", "size=1m", "hello-world"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if probe.returncode == 0:
|
||||
# Clean up the created container
|
||||
container_id = probe.stdout.strip()
|
||||
if container_id:
|
||||
subprocess.run(["docker", "rm", container_id], capture_output=True, timeout=5)
|
||||
subprocess.run(["docker", "rm", container_id],
|
||||
capture_output=True, timeout=5)
|
||||
_storage_opt_ok = True
|
||||
else:
|
||||
_storage_opt_ok = False
|
||||
@@ -207,9 +184,9 @@ class DockerEnvironment(BaseEnvironment):
|
||||
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
|
||||
return _storage_opt_ok
|
||||
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
exec_command = self._prepare_command(command)
|
||||
work_dir = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
@@ -235,8 +212,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
@@ -287,7 +263,6 @@ class DockerEnvironment(BaseEnvironment):
|
||||
|
||||
if not self._persistent:
|
||||
import shutil
|
||||
|
||||
for d in (self._workspace_dir, self._home_dir):
|
||||
if d:
|
||||
shutil.rmtree(d, ignore_errors=True)
|
||||
|
||||
@@ -154,9 +154,9 @@ class LocalEnvironment(BaseEnvironment):
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
|
||||
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
||||
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
|
||||
work_dir = cwd or self.cwd or os.getcwd()
|
||||
@@ -172,7 +172,11 @@ class LocalEnvironment(BaseEnvironment):
|
||||
# Wrap with output fences so we can later extract the real
|
||||
# command output and discard shell init/exit noise.
|
||||
fenced_cmd = (
|
||||
f"printf '{_OUTPUT_FENCE}'; {exec_command}; __hermes_rc=$?; printf '{_OUTPUT_FENCE}'; exit $__hermes_rc"
|
||||
f"printf '{_OUTPUT_FENCE}';"
|
||||
f" {exec_command};"
|
||||
f" __hermes_rc=$?;"
|
||||
f" printf '{_OUTPUT_FENCE}';"
|
||||
f" exit $__hermes_rc"
|
||||
)
|
||||
# Ensure PATH always includes standard dirs — systemd services
|
||||
# and some terminal multiplexers inherit a minimal PATH.
|
||||
@@ -196,14 +200,12 @@ class LocalEnvironment(BaseEnvironment):
|
||||
)
|
||||
|
||||
if stdin_data is not None:
|
||||
|
||||
def _write_stdin():
|
||||
try:
|
||||
proc.stdin.write(stdin_data)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
||||
|
||||
_output_chunks: list[str] = []
|
||||
|
||||
@@ -8,9 +8,10 @@ project files, and config changes survive across sessions.
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
@@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
_SNAPSHOT_STORE = Path.home() / ".hermes" / "modal_snapshots.json"
|
||||
|
||||
|
||||
def _load_snapshots() -> dict[str, str]:
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
"""Load snapshot ID mapping from disk."""
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
@@ -30,7 +31,7 @@ def _load_snapshots() -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_snapshots(data: dict[str, str]) -> None:
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
"""Persist snapshot ID mapping to disk."""
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
@@ -51,7 +52,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||
image: str,
|
||||
cwd: str = "~",
|
||||
timeout: int = 60,
|
||||
modal_sandbox_kwargs: dict[str, Any] | None = None,
|
||||
modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
@@ -60,7 +61,6 @@ class ModalEnvironment(BaseEnvironment):
|
||||
if not ModalEnvironment._patches_applied:
|
||||
try:
|
||||
from environments.patches import apply_patches
|
||||
|
||||
apply_patches()
|
||||
except ImportError:
|
||||
pass
|
||||
@@ -79,7 +79,6 @@ class ModalEnvironment(BaseEnvironment):
|
||||
if snapshot_id:
|
||||
try:
|
||||
import modal
|
||||
|
||||
restored_image = modal.Image.from_id(snapshot_id)
|
||||
logger.info("Modal: restoring from snapshot %s", snapshot_id[:20])
|
||||
except Exception as e:
|
||||
@@ -89,7 +88,6 @@ class ModalEnvironment(BaseEnvironment):
|
||||
effective_image = restored_image if restored_image else image
|
||||
|
||||
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
|
||||
|
||||
self._inner = SwerexModalEnvironment(
|
||||
image=effective_image,
|
||||
cwd=cwd,
|
||||
@@ -99,9 +97,9 @@ class ModalEnvironment(BaseEnvironment):
|
||||
modal_sandbox_kwargs=sandbox_kwargs,
|
||||
)
|
||||
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
@@ -141,29 +139,29 @@ class ModalEnvironment(BaseEnvironment):
|
||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
if self._persistent:
|
||||
try:
|
||||
sandbox = getattr(self._inner, "deployment", None)
|
||||
sandbox = getattr(sandbox, "_sandbox", None) if sandbox else None
|
||||
sandbox = getattr(self._inner, 'deployment', None)
|
||||
sandbox = getattr(sandbox, '_sandbox', None) if sandbox else None
|
||||
if sandbox:
|
||||
import asyncio
|
||||
|
||||
async def _snapshot():
|
||||
img = await sandbox.snapshot_filesystem.aio()
|
||||
return img.object_id
|
||||
|
||||
try:
|
||||
snapshot_id = asyncio.run(_snapshot())
|
||||
except RuntimeError:
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
snapshot_id = pool.submit(asyncio.run, _snapshot()).result(timeout=60)
|
||||
snapshot_id = pool.submit(
|
||||
asyncio.run, _snapshot()
|
||||
).result(timeout=60)
|
||||
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = snapshot_id
|
||||
_save_snapshots(snapshots)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s", snapshot_id[:20], self._task_id)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20], self._task_id)
|
||||
except Exception as e:
|
||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||
|
||||
if hasattr(self._inner, "stop"):
|
||||
if hasattr(self._inner, 'stop'):
|
||||
self._inner.stop()
|
||||
|
||||
@@ -10,9 +10,11 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
@@ -22,7 +24,7 @@ logger = logging.getLogger(__name__)
|
||||
_SNAPSHOT_STORE = Path.home() / ".hermes" / "singularity_snapshots.json"
|
||||
|
||||
|
||||
def _load_snapshots() -> dict[str, str]:
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
@@ -31,7 +33,7 @@ def _load_snapshots() -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_snapshots(data: dict[str, str]) -> None:
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
|
||||
@@ -40,7 +42,6 @@ def _save_snapshots(data: dict[str, str]) -> None:
|
||||
# Singularity helpers (scratch dir, SIF cache, SIF building)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_scratch_dir() -> Path:
|
||||
"""Get the best directory for Singularity sandboxes.
|
||||
|
||||
@@ -57,7 +58,6 @@ def _get_scratch_dir() -> Path:
|
||||
return scratch_path
|
||||
|
||||
from tools.environments.base import get_sandbox_dir
|
||||
|
||||
sandbox = get_sandbox_dir() / "singularity"
|
||||
|
||||
scratch = Path("/scratch")
|
||||
@@ -93,12 +93,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
Returns the path unchanged if it's already a .sif file.
|
||||
For docker:// URLs, checks the cache and builds if needed.
|
||||
"""
|
||||
if image.endswith(".sif") and Path(image).exists():
|
||||
if image.endswith('.sif') and Path(image).exists():
|
||||
return image
|
||||
if not image.startswith("docker://"):
|
||||
if not image.startswith('docker://'):
|
||||
return image
|
||||
|
||||
image_name = image.replace("docker://", "").replace("/", "-").replace(":", "-")
|
||||
image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-')
|
||||
cache_dir = _get_apptainer_cache_dir()
|
||||
sif_path = cache_dir / f"{image_name}.sif"
|
||||
|
||||
@@ -123,10 +123,7 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[executable, "build", str(sif_path), image],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600,
|
||||
env=env,
|
||||
capture_output=True, text=True, timeout=600, env=env,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning("SIF build failed, falling back to docker:// URL")
|
||||
@@ -148,7 +145,6 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
# SingularityEnvironment
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SingularityEnvironment(BaseEnvironment):
|
||||
"""Hardened Singularity/Apptainer container with resource limits and persistence.
|
||||
|
||||
@@ -178,7 +174,7 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
self._instance_started = False
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._overlay_dir: Path | None = None
|
||||
self._overlay_dir: Optional[Path] = None
|
||||
|
||||
# Resource limits
|
||||
self._cpu = cpu
|
||||
@@ -219,13 +215,14 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to start instance: {result.stderr}")
|
||||
self._instance_started = True
|
||||
logger.info("Singularity instance %s started (persistent=%s)", self.instance_id, self._persistent)
|
||||
logger.info("Singularity instance %s started (persistent=%s)",
|
||||
self.instance_id, self._persistent)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError("Instance start timed out")
|
||||
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if not self._instance_started:
|
||||
return {"output": "Instance not started", "returncode": -1}
|
||||
|
||||
@@ -238,16 +235,16 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
exec_command = f"cd {work_dir} && {exec_command}"
|
||||
work_dir = "/tmp"
|
||||
|
||||
cmd = [self.executable, "exec", "--pwd", work_dir, f"instance://{self.instance_id}", "bash", "-c", exec_command]
|
||||
cmd = [self.executable, "exec", "--pwd", work_dir,
|
||||
f"instance://{self.instance_id}",
|
||||
"bash", "-c", exec_command]
|
||||
|
||||
try:
|
||||
import time as _time
|
||||
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
@@ -298,9 +295,7 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
try:
|
||||
subprocess.run(
|
||||
[self.executable, "instance", "stop", self.instance_id],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
capture_output=True, text=True, timeout=30,
|
||||
)
|
||||
logger.info("Singularity instance %s stopped", self.instance_id)
|
||||
except Exception as e:
|
||||
|
||||
@@ -24,7 +24,8 @@ class SSHEnvironment(BaseEnvironment):
|
||||
and a remote kill is attempted over the ControlMaster socket.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, user: str, cwd: str = "~", timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||
def __init__(self, host: str, user: str, cwd: str = "~",
|
||||
timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
self.host = host
|
||||
self.user = user
|
||||
@@ -64,12 +65,12 @@ class SSHEnvironment(BaseEnvironment):
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
|
||||
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command = self._prepare_command(command)
|
||||
wrapped = f"cd {work_dir} && {exec_command}"
|
||||
wrapped = f'cd {work_dir} && {exec_command}'
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
cmd = self._build_ssh_command()
|
||||
@@ -135,7 +136,8 @@ class SSHEnvironment(BaseEnvironment):
|
||||
def cleanup(self):
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", "-O", "exit", f"{self.user}@{self.host}"]
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
"-O", "exit", f"{self.user}@{self.host}"]
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (OSError, subprocess.SubprocessError):
|
||||
pass
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,10 +3,11 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from agent.redact import redact_sensitive_text
|
||||
from typing import Optional
|
||||
from tools.file_operations import ShellFileOperations
|
||||
from agent.redact import redact_sensitive_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,18 +25,13 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
||||
Thread-safe: uses the same per-task creation locks as terminal_tool to
|
||||
prevent duplicate sandbox creation from concurrent tool calls.
|
||||
"""
|
||||
import time
|
||||
|
||||
from tools.terminal_tool import (
|
||||
_active_environments,
|
||||
_create_environment,
|
||||
_creation_locks,
|
||||
_creation_locks_lock,
|
||||
_env_lock,
|
||||
_get_env_config,
|
||||
_last_activity,
|
||||
_start_cleanup_thread,
|
||||
_active_environments, _env_lock, _create_environment,
|
||||
_get_env_config, _last_activity, _start_cleanup_thread,
|
||||
_check_disk_usage_warning,
|
||||
_creation_locks, _creation_locks_lock,
|
||||
)
|
||||
import time
|
||||
|
||||
# Fast path: check cache -- but also verify the underlying environment
|
||||
# is still alive (it may have been killed by the cleanup thread).
|
||||
@@ -147,23 +143,17 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
||||
result = file_ops.write_file(path, content)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
except Exception as e:
|
||||
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
|
||||
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def patch_tool(
|
||||
mode: str = "replace",
|
||||
path: str = None,
|
||||
old_string: str = None,
|
||||
new_string: str = None,
|
||||
replace_all: bool = False,
|
||||
patch: str = None,
|
||||
task_id: str = "default",
|
||||
) -> str:
|
||||
def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||
new_string: str = None, replace_all: bool = False, patch: str = None,
|
||||
task_id: str = "default") -> str:
|
||||
"""Patch a file using replace mode or V4A patch format."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
|
||||
|
||||
if mode == "replace":
|
||||
if not path:
|
||||
return json.dumps({"error": "path required"})
|
||||
@@ -176,7 +166,7 @@ def patch_tool(
|
||||
result = file_ops.patch_v4a(patch)
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown mode: {mode}"})
|
||||
|
||||
|
||||
result_dict = result.to_dict()
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
# Hint when old_string not found — saves iterations where the agent
|
||||
@@ -188,33 +178,20 @@ def patch_tool(
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def search_tool(
|
||||
pattern: str,
|
||||
target: str = "content",
|
||||
path: str = ".",
|
||||
file_glob: str = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
output_mode: str = "content",
|
||||
context: int = 0,
|
||||
task_id: str = "default",
|
||||
) -> str:
|
||||
def search_tool(pattern: str, target: str = "content", path: str = ".",
|
||||
file_glob: str = None, limit: int = 50, offset: int = 0,
|
||||
output_mode: str = "content", context: int = 0,
|
||||
task_id: str = "default") -> str:
|
||||
"""Search for content or files."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.search(
|
||||
pattern=pattern,
|
||||
path=path,
|
||||
target=target,
|
||||
file_glob=file_glob,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
output_mode=output_mode,
|
||||
context=context,
|
||||
pattern=pattern, path=path, target=target, file_glob=file_glob,
|
||||
limit=limit, offset=offset, output_mode=output_mode, context=context
|
||||
)
|
||||
if hasattr(result, "matches"):
|
||||
if hasattr(result, 'matches'):
|
||||
for m in result.matches:
|
||||
if hasattr(m, "content") and m.content:
|
||||
if hasattr(m, 'content') and m.content:
|
||||
m.content = redact_sensitive_text(m.content)
|
||||
result_dict = result.to_dict()
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
@@ -232,7 +209,7 @@ FILE_TOOLS = [
|
||||
{"name": "read_file", "function": read_file_tool},
|
||||
{"name": "write_file", "function": write_file_tool},
|
||||
{"name": "patch", "function": patch_tool},
|
||||
{"name": "search_files", "function": search_tool},
|
||||
{"name": "search_files", "function": search_tool}
|
||||
]
|
||||
|
||||
|
||||
@@ -250,10 +227,8 @@ from tools.registry import registry
|
||||
def _check_file_reqs():
|
||||
"""Lazy wrapper to avoid circular import with tools/__init__.py."""
|
||||
from tools import check_file_requirements
|
||||
|
||||
return check_file_requirements()
|
||||
|
||||
|
||||
READ_FILE_SCHEMA = {
|
||||
"name": "read_file",
|
||||
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. NOTE: Cannot read images or binary files — use vision_analyze for images.",
|
||||
@@ -261,21 +236,11 @@ READ_FILE_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the file to read (absolute, relative, or ~/path)"},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, default: 1)",
|
||||
"default": 1,
|
||||
"minimum": 1,
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (default: 500, max: 2000)",
|
||||
"default": 500,
|
||||
"maximum": 2000,
|
||||
},
|
||||
"offset": {"type": "integer", "description": "Line number to start reading from (1-indexed, default: 1)", "default": 1, "minimum": 1},
|
||||
"limit": {"type": "integer", "description": "Maximum number of lines to read (default: 500, max: 2000)", "default": 500, "maximum": 2000}
|
||||
},
|
||||
"required": ["path"],
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
|
||||
WRITE_FILE_SCHEMA = {
|
||||
@@ -284,14 +249,11 @@ WRITE_FILE_SCHEMA = {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)",
|
||||
},
|
||||
"content": {"type": "string", "description": "Complete content to write to the file"},
|
||||
"path": {"type": "string", "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"},
|
||||
"content": {"type": "string", "description": "Complete content to write to the file"}
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
}
|
||||
}
|
||||
|
||||
PATCH_SCHEMA = {
|
||||
@@ -300,33 +262,15 @@ PATCH_SCHEMA = {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["replace", "patch"],
|
||||
"description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches",
|
||||
"default": "replace",
|
||||
},
|
||||
"mode": {"type": "string", "enum": ["replace", "patch"], "description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches", "default": "replace"},
|
||||
"path": {"type": "string", "description": "File path to edit (required for 'replace' mode)"},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness.",
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text.",
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences instead of requiring a unique match (default: false)",
|
||||
"default": False,
|
||||
},
|
||||
"patch": {
|
||||
"type": "string",
|
||||
"description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch",
|
||||
},
|
||||
"old_string": {"type": "string", "description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness."},
|
||||
"new_string": {"type": "string", "description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text."},
|
||||
"replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match (default: false)", "default": False},
|
||||
"patch": {"type": "string", "description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch"}
|
||||
},
|
||||
"required": ["mode"],
|
||||
},
|
||||
"required": ["mode"]
|
||||
}
|
||||
}
|
||||
|
||||
SEARCH_FILES_SCHEMA = {
|
||||
@@ -335,57 +279,23 @@ SEARCH_FILES_SCHEMA = {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files"],
|
||||
"description": "'content' searches inside file contents, 'files' searches for files by name",
|
||||
"default": "content",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory or file to search in (default: current working directory)",
|
||||
"default": ".",
|
||||
},
|
||||
"file_glob": {
|
||||
"type": "string",
|
||||
"description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return (default: 50)",
|
||||
"default": 50,
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Skip first N results for pagination (default: 0)",
|
||||
"default": 0,
|
||||
},
|
||||
"output_mode": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files_only", "count"],
|
||||
"description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file",
|
||||
"default": "content",
|
||||
},
|
||||
"context": {
|
||||
"type": "integer",
|
||||
"description": "Number of context lines before and after each match (grep mode only)",
|
||||
"default": 0,
|
||||
},
|
||||
"pattern": {"type": "string", "description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search"},
|
||||
"target": {"type": "string", "enum": ["content", "files"], "description": "'content' searches inside file contents, 'files' searches for files by name", "default": "content"},
|
||||
"path": {"type": "string", "description": "Directory or file to search in (default: current working directory)", "default": "."},
|
||||
"file_glob": {"type": "string", "description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)"},
|
||||
"limit": {"type": "integer", "description": "Maximum number of results to return (default: 50)", "default": 50},
|
||||
"offset": {"type": "integer", "description": "Skip first N results for pagination (default: 0)", "default": 0},
|
||||
"output_mode": {"type": "string", "enum": ["content", "files_only", "count"], "description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file", "default": "content"},
|
||||
"context": {"type": "integer", "description": "Number of context lines before and after each match (grep mode only)", "default": 0}
|
||||
},
|
||||
"required": ["pattern"],
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _handle_read_file(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return read_file_tool(
|
||||
path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid
|
||||
)
|
||||
return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid)
|
||||
|
||||
|
||||
def _handle_write_file(args, **kw):
|
||||
@@ -396,14 +306,9 @@ def _handle_write_file(args, **kw):
|
||||
def _handle_patch(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return patch_tool(
|
||||
mode=args.get("mode", "replace"),
|
||||
path=args.get("path"),
|
||||
old_string=args.get("old_string"),
|
||||
new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False),
|
||||
patch=args.get("patch"),
|
||||
task_id=tid,
|
||||
)
|
||||
mode=args.get("mode", "replace"), path=args.get("path"),
|
||||
old_string=args.get("old_string"), new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid)
|
||||
|
||||
|
||||
def _handle_search_files(args, **kw):
|
||||
@@ -412,29 +317,12 @@ def _handle_search_files(args, **kw):
|
||||
raw_target = args.get("target", "content")
|
||||
target = target_map.get(raw_target, raw_target)
|
||||
return search_tool(
|
||||
pattern=args.get("pattern", ""),
|
||||
target=target,
|
||||
path=args.get("path", "."),
|
||||
file_glob=args.get("file_glob"),
|
||||
limit=args.get("limit", 50),
|
||||
offset=args.get("offset", 0),
|
||||
output_mode=args.get("output_mode", "content"),
|
||||
context=args.get("context", 0),
|
||||
task_id=tid,
|
||||
)
|
||||
pattern=args.get("pattern", ""), target=target, path=args.get("path", "."),
|
||||
file_glob=args.get("file_glob"), limit=args.get("limit", 50), offset=args.get("offset", 0),
|
||||
output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid)
|
||||
|
||||
|
||||
registry.register(
|
||||
name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs
|
||||
)
|
||||
registry.register(
|
||||
name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs
|
||||
)
|
||||
registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs)
|
||||
registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs)
|
||||
registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs)
|
||||
registry.register(
|
||||
name="search_files",
|
||||
toolset="file",
|
||||
schema=SEARCH_FILES_SCHEMA,
|
||||
handler=_handle_search_files,
|
||||
check_fn=_check_file_reqs,
|
||||
)
|
||||
registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs)
|
||||
|
||||
@@ -19,7 +19,7 @@ The 9-strategy chain (inspired by OpenCode):
|
||||
|
||||
Usage:
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
|
||||
new_content, match_count, error = fuzzy_find_and_replace(
|
||||
content="def foo():\\n pass",
|
||||
old_string="def foo():",
|
||||
@@ -29,22 +29,21 @@ Usage:
|
||||
"""
|
||||
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
def fuzzy_find_and_replace(
|
||||
content: str, old_string: str, new_string: str, replace_all: bool = False
|
||||
) -> tuple[str, int, str | None]:
|
||||
def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
||||
replace_all: bool = False) -> Tuple[str, int, Optional[str]]:
|
||||
"""
|
||||
Find and replace text using a chain of increasingly fuzzy matching strategies.
|
||||
|
||||
|
||||
Args:
|
||||
content: The file content to search in
|
||||
old_string: The text to find
|
||||
new_string: The replacement text
|
||||
replace_all: If True, replace all occurrences; if False, require uniqueness
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (new_content, match_count, error_message)
|
||||
- If successful: (modified_content, number_of_replacements, None)
|
||||
@@ -52,12 +51,12 @@ def fuzzy_find_and_replace(
|
||||
"""
|
||||
if not old_string:
|
||||
return content, 0, "old_string cannot be empty"
|
||||
|
||||
|
||||
if old_string == new_string:
|
||||
return content, 0, "old_string and new_string are identical"
|
||||
|
||||
|
||||
# Try each matching strategy in order
|
||||
strategies: list[tuple[str, Callable]] = [
|
||||
strategies: List[Tuple[str, Callable]] = [
|
||||
("exact", _strategy_exact),
|
||||
("line_trimmed", _strategy_line_trimmed),
|
||||
("whitespace_normalized", _strategy_whitespace_normalized),
|
||||
@@ -67,50 +66,46 @@ def fuzzy_find_and_replace(
|
||||
("block_anchor", _strategy_block_anchor),
|
||||
("context_aware", _strategy_context_aware),
|
||||
]
|
||||
|
||||
|
||||
for strategy_name, strategy_fn in strategies:
|
||||
matches = strategy_fn(content, old_string)
|
||||
|
||||
|
||||
if matches:
|
||||
# Found matches with this strategy
|
||||
if len(matches) > 1 and not replace_all:
|
||||
return (
|
||||
content,
|
||||
0,
|
||||
(
|
||||
f"Found {len(matches)} matches for old_string. "
|
||||
f"Provide more context to make it unique, or use replace_all=True."
|
||||
),
|
||||
return content, 0, (
|
||||
f"Found {len(matches)} matches for old_string. "
|
||||
f"Provide more context to make it unique, or use replace_all=True."
|
||||
)
|
||||
|
||||
|
||||
# Perform replacement
|
||||
new_content = _apply_replacements(content, matches, new_string)
|
||||
return new_content, len(matches), None
|
||||
|
||||
|
||||
# No strategy found a match
|
||||
return content, 0, "Could not find a match for old_string in the file"
|
||||
|
||||
|
||||
def _apply_replacements(content: str, matches: list[tuple[int, int]], new_string: str) -> str:
|
||||
def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str:
|
||||
"""
|
||||
Apply replacements at the given positions.
|
||||
|
||||
|
||||
Args:
|
||||
content: Original content
|
||||
matches: List of (start, end) positions to replace
|
||||
new_string: Replacement text
|
||||
|
||||
|
||||
Returns:
|
||||
Content with replacements applied
|
||||
"""
|
||||
# Sort matches by position (descending) to replace from end to start
|
||||
# This preserves positions of earlier matches
|
||||
sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True)
|
||||
|
||||
|
||||
result = content
|
||||
for start, end in sorted_matches:
|
||||
result = result[:start] + new_string + result[end:]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -118,8 +113,7 @@ def _apply_replacements(content: str, matches: list[tuple[int, int]], new_string
|
||||
# Matching Strategies
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _strategy_exact(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""Strategy 1: Exact string match."""
|
||||
matches = []
|
||||
start = 0
|
||||
@@ -132,201 +126,206 @@ def _strategy_exact(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_line_trimmed(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 2: Match with line-by-line whitespace trimming.
|
||||
|
||||
|
||||
Strips leading/trailing whitespace from each line before matching.
|
||||
"""
|
||||
# Normalize pattern and content by trimming each line
|
||||
pattern_lines = [line.strip() for line in pattern.split("\n")]
|
||||
pattern_normalized = "\n".join(pattern_lines)
|
||||
|
||||
content_lines = content.split("\n")
|
||||
pattern_lines = [line.strip() for line in pattern.split('\n')]
|
||||
pattern_normalized = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
content_normalized_lines = [line.strip() for line in content_lines]
|
||||
|
||||
|
||||
# Build mapping from normalized positions back to original positions
|
||||
return _find_normalized_matches(content, content_lines, content_normalized_lines, pattern, pattern_normalized)
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_normalized_lines,
|
||||
pattern, pattern_normalized
|
||||
)
|
||||
|
||||
|
||||
def _strategy_whitespace_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 3: Collapse multiple whitespace to single space.
|
||||
"""
|
||||
|
||||
def normalize(s):
|
||||
# Collapse multiple spaces/tabs to single space, preserve newlines
|
||||
return re.sub(r"[ \t]+", " ", s)
|
||||
|
||||
return re.sub(r'[ \t]+', ' ', s)
|
||||
|
||||
pattern_normalized = normalize(pattern)
|
||||
content_normalized = normalize(content)
|
||||
|
||||
|
||||
# Find in normalized, map back to original
|
||||
matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized)
|
||||
|
||||
|
||||
if not matches_in_normalized:
|
||||
return []
|
||||
|
||||
|
||||
# Map positions back to original content
|
||||
return _map_normalized_positions(content, content_normalized, matches_in_normalized)
|
||||
|
||||
|
||||
def _strategy_indentation_flexible(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 4: Ignore indentation differences entirely.
|
||||
|
||||
|
||||
Strips all leading whitespace from lines before matching.
|
||||
"""
|
||||
|
||||
def strip_indent(s):
|
||||
return "\n".join(line.lstrip() for line in s.split("\n"))
|
||||
|
||||
return '\n'.join(line.lstrip() for line in s.split('\n'))
|
||||
|
||||
pattern_stripped = strip_indent(pattern)
|
||||
|
||||
content_lines = content.split("\n")
|
||||
|
||||
content_lines = content.split('\n')
|
||||
content_stripped_lines = [line.lstrip() for line in content_lines]
|
||||
pattern_lines = [line.lstrip() for line in pattern.split("\n")]
|
||||
|
||||
return _find_normalized_matches(content, content_lines, content_stripped_lines, pattern, "\n".join(pattern_lines))
|
||||
pattern_lines = [line.lstrip() for line in pattern.split('\n')]
|
||||
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_stripped_lines,
|
||||
pattern, '\n'.join(pattern_lines)
|
||||
)
|
||||
|
||||
|
||||
def _strategy_escape_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 5: Convert escape sequences to actual characters.
|
||||
|
||||
|
||||
Handles \\n -> newline, \\t -> tab, etc.
|
||||
"""
|
||||
|
||||
def unescape(s):
|
||||
# Convert common escape sequences
|
||||
return s.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
|
||||
|
||||
return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r')
|
||||
|
||||
pattern_unescaped = unescape(pattern)
|
||||
|
||||
|
||||
if pattern_unescaped == pattern:
|
||||
# No escapes to convert, skip this strategy
|
||||
return []
|
||||
|
||||
|
||||
return _strategy_exact(content, pattern_unescaped)
|
||||
|
||||
|
||||
def _strategy_trimmed_boundary(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 6: Trim whitespace from first and last lines only.
|
||||
|
||||
|
||||
Useful when the pattern boundaries have whitespace differences.
|
||||
"""
|
||||
pattern_lines = pattern.split("\n")
|
||||
pattern_lines = pattern.split('\n')
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
|
||||
# Trim only first and last lines
|
||||
pattern_lines[0] = pattern_lines[0].strip()
|
||||
if len(pattern_lines) > 1:
|
||||
pattern_lines[-1] = pattern_lines[-1].strip()
|
||||
|
||||
modified_pattern = "\n".join(pattern_lines)
|
||||
|
||||
content_lines = content.split("\n")
|
||||
|
||||
|
||||
modified_pattern = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
|
||||
# Search through content for matching block
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i : i + pattern_line_count]
|
||||
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
# Trim first and last of this block
|
||||
check_lines = block_lines.copy()
|
||||
check_lines[0] = check_lines[0].strip()
|
||||
if len(check_lines) > 1:
|
||||
check_lines[-1] = check_lines[-1].strip()
|
||||
|
||||
if "\n".join(check_lines) == modified_pattern:
|
||||
|
||||
if '\n'.join(check_lines) == modified_pattern:
|
||||
# Found match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_block_anchor(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 7: Match by anchoring on first and last lines.
|
||||
|
||||
|
||||
If first and last lines match exactly, accept middle with 70% similarity.
|
||||
"""
|
||||
pattern_lines = pattern.split("\n")
|
||||
pattern_lines = pattern.split('\n')
|
||||
if len(pattern_lines) < 2:
|
||||
return [] # Need at least 2 lines for anchoring
|
||||
|
||||
|
||||
first_line = pattern_lines[0].strip()
|
||||
last_line = pattern_lines[-1].strip()
|
||||
|
||||
content_lines = content.split("\n")
|
||||
|
||||
content_lines = content.split('\n')
|
||||
matches = []
|
||||
|
||||
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
# Check if first and last lines match
|
||||
if content_lines[i].strip() == first_line and content_lines[i + pattern_line_count - 1].strip() == last_line:
|
||||
if (content_lines[i].strip() == first_line and
|
||||
content_lines[i + pattern_line_count - 1].strip() == last_line):
|
||||
|
||||
# Check middle similarity
|
||||
if pattern_line_count <= 2:
|
||||
# Only first and last, they match
|
||||
similarity = 1.0
|
||||
else:
|
||||
content_middle = "\n".join(content_lines[i + 1 : i + pattern_line_count - 1])
|
||||
pattern_middle = "\n".join(pattern_lines[1:-1])
|
||||
content_middle = '\n'.join(content_lines[i+1:i+pattern_line_count-1])
|
||||
pattern_middle = '\n'.join(pattern_lines[1:-1])
|
||||
similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio()
|
||||
|
||||
|
||||
if similarity >= 0.70:
|
||||
# Calculate positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_context_aware(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 8: Line-by-line similarity with 50% threshold.
|
||||
|
||||
|
||||
Finds blocks where at least 50% of lines have high similarity.
|
||||
"""
|
||||
pattern_lines = pattern.split("\n")
|
||||
content_lines = content.split("\n")
|
||||
|
||||
pattern_lines = pattern.split('\n')
|
||||
content_lines = content.split('\n')
|
||||
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i : i + pattern_line_count]
|
||||
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
# Calculate line-by-line similarity
|
||||
high_similarity_count = 0
|
||||
for p_line, c_line in zip(pattern_lines, block_lines):
|
||||
sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio()
|
||||
if sim >= 0.80:
|
||||
high_similarity_count += 1
|
||||
|
||||
|
||||
# Need at least 50% of lines to have high similarity
|
||||
if high_similarity_count >= len(pattern_lines) * 0.5:
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
@@ -334,76 +333,74 @@ def _strategy_context_aware(content: str, pattern: str) -> list[tuple[int, int]]
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _find_normalized_matches(
|
||||
content: str, content_lines: list[str], content_normalized_lines: list[str], pattern: str, pattern_normalized: str
|
||||
) -> list[tuple[int, int]]:
|
||||
def _find_normalized_matches(content: str, content_lines: List[str],
|
||||
content_normalized_lines: List[str],
|
||||
pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Find matches in normalized content and map back to original positions.
|
||||
|
||||
|
||||
Args:
|
||||
content: Original content string
|
||||
content_lines: Original content split by lines
|
||||
content_normalized_lines: Normalized content lines
|
||||
pattern: Original pattern
|
||||
pattern_normalized: Normalized pattern
|
||||
|
||||
|
||||
Returns:
|
||||
List of (start, end) positions in the original content
|
||||
"""
|
||||
pattern_norm_lines = pattern_normalized.split("\n")
|
||||
pattern_norm_lines = pattern_normalized.split('\n')
|
||||
num_pattern_lines = len(pattern_norm_lines)
|
||||
|
||||
|
||||
matches = []
|
||||
|
||||
|
||||
for i in range(len(content_normalized_lines) - num_pattern_lines + 1):
|
||||
# Check if this block matches
|
||||
block = "\n".join(content_normalized_lines[i : i + num_pattern_lines])
|
||||
|
||||
block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines])
|
||||
|
||||
if block == pattern_normalized:
|
||||
# Found a match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + num_pattern_lines]) - 1
|
||||
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1
|
||||
|
||||
# Handle case where end is past content
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
|
||||
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _map_normalized_positions(
|
||||
original: str, normalized: str, normalized_matches: list[tuple[int, int]]
|
||||
) -> list[tuple[int, int]]:
|
||||
def _map_normalized_positions(original: str, normalized: str,
|
||||
normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Map positions from normalized string back to original.
|
||||
|
||||
|
||||
This is a best-effort mapping that works for whitespace normalization.
|
||||
"""
|
||||
if not normalized_matches:
|
||||
return []
|
||||
|
||||
|
||||
# Build character mapping from normalized to original
|
||||
orig_to_norm = [] # orig_to_norm[i] = position in normalized
|
||||
|
||||
|
||||
orig_idx = 0
|
||||
norm_idx = 0
|
||||
|
||||
|
||||
while orig_idx < len(original) and norm_idx < len(normalized):
|
||||
if original[orig_idx] == normalized[norm_idx]:
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in " \t" and normalized[norm_idx] == " ":
|
||||
elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ':
|
||||
# Original has space/tab, normalized collapsed to space
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
# Don't advance norm_idx yet - wait until all whitespace consumed
|
||||
if orig_idx < len(original) and original[orig_idx] not in " \t":
|
||||
if orig_idx < len(original) and original[orig_idx] not in ' \t':
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in " \t":
|
||||
elif original[orig_idx] in ' \t':
|
||||
# Extra whitespace in original
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
@@ -411,21 +408,21 @@ def _map_normalized_positions(
|
||||
# Mismatch - shouldn't happen with our normalization
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
|
||||
|
||||
# Fill remaining
|
||||
while orig_idx < len(original):
|
||||
orig_to_norm.append(len(normalized))
|
||||
orig_idx += 1
|
||||
|
||||
|
||||
# Reverse mapping: for each normalized position, find original range
|
||||
norm_to_orig_start = {}
|
||||
norm_to_orig_end = {}
|
||||
|
||||
|
||||
for orig_pos, norm_pos in enumerate(orig_to_norm):
|
||||
if norm_pos not in norm_to_orig_start:
|
||||
norm_to_orig_start[norm_pos] = orig_pos
|
||||
norm_to_orig_end[norm_pos] = orig_pos
|
||||
|
||||
|
||||
# Map matches
|
||||
original_matches = []
|
||||
for norm_start, norm_end in normalized_matches:
|
||||
@@ -435,17 +432,17 @@ def _map_normalized_positions(
|
||||
else:
|
||||
# Find nearest
|
||||
orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start)
|
||||
|
||||
|
||||
# Find original end
|
||||
if norm_end - 1 in norm_to_orig_end:
|
||||
orig_end = norm_to_orig_end[norm_end - 1] + 1
|
||||
else:
|
||||
orig_end = orig_start + (norm_end - norm_start)
|
||||
|
||||
|
||||
# Expand to include trailing whitespace that was normalized
|
||||
while orig_end < len(original) and original[orig_end] in " \t":
|
||||
while orig_end < len(original) and original[orig_end] in ' \t':
|
||||
orig_end += 1
|
||||
|
||||
|
||||
original_matches.append((orig_start, min(orig_end, len(original))))
|
||||
|
||||
|
||||
return original_matches
|
||||
|
||||
@@ -15,7 +15,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,26 +35,23 @@ def _get_config():
|
||||
_HASS_TOKEN or os.getenv("HASS_TOKEN", ""),
|
||||
)
|
||||
|
||||
|
||||
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
|
||||
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
|
||||
|
||||
# Service domains blocked for security -- these allow arbitrary code/command
|
||||
# execution on the HA host or enable SSRF attacks on the local network.
|
||||
# HA provides zero service-level access control; all safety must be in our layer.
|
||||
_BLOCKED_DOMAINS = frozenset(
|
||||
{
|
||||
"shell_command", # arbitrary shell commands as root in HA container
|
||||
"command_line", # sensors/switches that execute shell commands
|
||||
"python_script", # sandboxed but can escalate via hass.services.call()
|
||||
"pyscript", # scripting integration with broader access
|
||||
"hassio", # addon control, host shutdown/reboot, stdin to containers
|
||||
"rest_command", # HTTP requests from HA server (SSRF vector)
|
||||
}
|
||||
)
|
||||
_BLOCKED_DOMAINS = frozenset({
|
||||
"shell_command", # arbitrary shell commands as root in HA container
|
||||
"command_line", # sensors/switches that execute shell commands
|
||||
"python_script", # sandboxed but can escalate via hass.services.call()
|
||||
"pyscript", # scripting integration with broader access
|
||||
"hassio", # addon control, host shutdown/reboot, stdin to containers
|
||||
"rest_command", # HTTP requests from HA server (SSRF vector)
|
||||
})
|
||||
|
||||
|
||||
def _get_headers(token: str = "") -> dict[str, str]:
|
||||
def _get_headers(token: str = "") -> Dict[str, str]:
|
||||
"""Return authorization headers for HA REST API."""
|
||||
if not token:
|
||||
_, token = _get_config()
|
||||
@@ -68,12 +65,11 @@ def _get_headers(token: str = "") -> dict[str, str]:
|
||||
# Async helpers (called from sync handlers via run_until_complete)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _filter_and_summarize(
|
||||
states: list,
|
||||
domain: str | None = None,
|
||||
area: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Filter raw HA states by domain/area and return a compact summary."""
|
||||
if domain:
|
||||
states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
|
||||
@@ -81,29 +77,26 @@ def _filter_and_summarize(
|
||||
if area:
|
||||
area_lower = area.lower()
|
||||
states = [
|
||||
s
|
||||
for s in states
|
||||
s for s in states
|
||||
if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
|
||||
or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
|
||||
]
|
||||
|
||||
entities = []
|
||||
for s in states:
|
||||
entities.append(
|
||||
{
|
||||
"entity_id": s["entity_id"],
|
||||
"state": s["state"],
|
||||
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
|
||||
}
|
||||
)
|
||||
entities.append({
|
||||
"entity_id": s["entity_id"],
|
||||
"state": s["state"],
|
||||
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
|
||||
})
|
||||
|
||||
return {"count": len(entities), "entities": entities}
|
||||
|
||||
|
||||
async def _async_list_entities(
|
||||
domain: str | None = None,
|
||||
area: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Fetch entity states from HA and optionally filter by domain/area."""
|
||||
import aiohttp
|
||||
|
||||
@@ -117,7 +110,7 @@ async def _async_list_entities(
|
||||
return _filter_and_summarize(states, domain, area)
|
||||
|
||||
|
||||
async def _async_get_state(entity_id: str) -> dict[str, Any]:
|
||||
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
|
||||
"""Fetch detailed state of a single entity."""
|
||||
import aiohttp
|
||||
|
||||
@@ -138,11 +131,11 @@ async def _async_get_state(entity_id: str) -> dict[str, Any]:
|
||||
|
||||
|
||||
def _build_service_payload(
|
||||
entity_id: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the JSON payload for a HA service call."""
|
||||
payload: dict[str, Any] = {}
|
||||
payload: Dict[str, Any] = {}
|
||||
if data:
|
||||
payload.update(data)
|
||||
# entity_id parameter takes precedence over data["entity_id"]
|
||||
@@ -155,17 +148,15 @@ def _parse_service_response(
|
||||
domain: str,
|
||||
service: str,
|
||||
result: Any,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
"""Parse HA service call response into a structured result."""
|
||||
affected = []
|
||||
if isinstance(result, list):
|
||||
for s in result:
|
||||
affected.append(
|
||||
{
|
||||
"entity_id": s.get("entity_id", ""),
|
||||
"state": s.get("state", ""),
|
||||
}
|
||||
)
|
||||
affected.append({
|
||||
"entity_id": s.get("entity_id", ""),
|
||||
"state": s.get("state", ""),
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -177,9 +168,9 @@ def _parse_service_response(
|
||||
async def _async_call_service(
|
||||
domain: str,
|
||||
service: str,
|
||||
entity_id: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Call a Home Assistant service."""
|
||||
import aiohttp
|
||||
|
||||
@@ -187,17 +178,15 @@ async def _async_call_service(
|
||||
url = f"{hass_url}/api/services/{domain}/{service}"
|
||||
payload = _build_service_payload(entity_id, data)
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
headers=_get_headers(hass_token),
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp,
|
||||
):
|
||||
resp.raise_for_status()
|
||||
result = await resp.json()
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
result = await resp.json()
|
||||
|
||||
return _parse_service_response(domain, service, result)
|
||||
|
||||
@@ -206,7 +195,6 @@ async def _async_call_service(
|
||||
# Sync wrappers (handler signature: (args, **kw) -> str)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync handler."""
|
||||
try:
|
||||
@@ -217,7 +205,6 @@ def _run_async(coro):
|
||||
if loop and loop.is_running():
|
||||
# Already inside an event loop -- create a new thread
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, coro)
|
||||
return future.result(timeout=30)
|
||||
@@ -260,12 +247,10 @@ def _handle_call_service(args: dict, **kw) -> str:
|
||||
return json.dumps({"error": "Missing required parameters: domain and service"})
|
||||
|
||||
if domain in _BLOCKED_DOMAINS:
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"Service domain '{domain}' is blocked for security. "
|
||||
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
|
||||
}
|
||||
)
|
||||
return json.dumps({
|
||||
"error": f"Service domain '{domain}' is blocked for security. "
|
||||
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
|
||||
})
|
||||
|
||||
entity_id = args.get("entity_id")
|
||||
if entity_id and not _ENTITY_ID_RE.match(entity_id):
|
||||
@@ -284,8 +269,7 @@ def _handle_call_service(args: dict, **kw) -> str:
|
||||
# List services
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _async_list_services(domain: str | None = None) -> dict[str, Any]:
|
||||
async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Fetch available services from HA and optionally filter by domain."""
|
||||
import aiohttp
|
||||
|
||||
@@ -306,10 +290,13 @@ async def _async_list_services(domain: str | None = None) -> dict[str, Any]:
|
||||
d = svc_domain.get("domain", "")
|
||||
domain_services = {}
|
||||
for svc_name, svc_info in svc_domain.get("services", {}).items():
|
||||
svc_entry: dict[str, Any] = {"description": svc_info.get("description", "")}
|
||||
svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
|
||||
fields = svc_info.get("fields", {})
|
||||
if fields:
|
||||
svc_entry["fields"] = {k: v.get("description", "") for k, v in fields.items() if isinstance(v, dict)}
|
||||
svc_entry["fields"] = {
|
||||
k: v.get("description", "") for k, v in fields.items()
|
||||
if isinstance(v, dict)
|
||||
}
|
||||
domain_services[svc_name] = svc_entry
|
||||
result.append({"domain": d, "services": domain_services})
|
||||
|
||||
@@ -331,7 +318,6 @@ def _handle_list_services(args: dict, **kw) -> str:
|
||||
# Availability check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_ha_available() -> bool:
|
||||
"""Tool is only available when HASS_TOKEN is set."""
|
||||
return bool(os.getenv("HASS_TOKEN"))
|
||||
@@ -383,7 +369,8 @@ HA_GET_STATE_SCHEMA = {
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The entity ID to query (e.g. 'light.living_room', 'climate.thermostat', 'sensor.temperature')."
|
||||
"The entity ID to query (e.g. 'light.living_room', "
|
||||
"'climate.thermostat', 'sensor.temperature')."
|
||||
),
|
||||
},
|
||||
},
|
||||
@@ -405,7 +392,8 @@ HA_LIST_SERVICES_SCHEMA = {
|
||||
"domain": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Filter by domain (e.g. 'light', 'climate', 'switch'). Omit to list services for all domains."
|
||||
"Filter by domain (e.g. 'light', 'climate', 'switch'). "
|
||||
"Omit to list services for all domains."
|
||||
),
|
||||
},
|
||||
},
|
||||
@@ -440,7 +428,8 @@ HA_CALL_SERVICE_SCHEMA = {
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Target entity ID (e.g. 'light.living_room'). Some services (like scene.turn_on) may not need this."
|
||||
"Target entity ID (e.g. 'light.living_room'). "
|
||||
"Some services (like scene.turn_on) may not need this."
|
||||
),
|
||||
},
|
||||
"data": {
|
||||
|
||||
@@ -65,7 +65,6 @@ HONCHO_TOOL_SCHEMA = {
|
||||
|
||||
# ── Tool handler ──
|
||||
|
||||
|
||||
def _handle_query_user_context(args: dict, **kw) -> str:
|
||||
"""Execute the Honcho context query."""
|
||||
query = args.get("query", "")
|
||||
@@ -85,7 +84,6 @@ def _handle_query_user_context(args: dict, **kw) -> str:
|
||||
|
||||
# ── Availability check ──
|
||||
|
||||
|
||||
def _check_honcho_available() -> bool:
|
||||
"""Tool is only available when Honcho is active."""
|
||||
return _session_manager is not None and _session_key is not None
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""
|
||||
Image Generation Tools Module
|
||||
|
||||
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
|
||||
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
|
||||
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
|
||||
|
||||
Available tools:
|
||||
@@ -19,7 +19,7 @@ Features:
|
||||
Usage:
|
||||
from image_generation_tool import image_generate_tool
|
||||
import asyncio
|
||||
|
||||
|
||||
# Generate and automatically upscale an image
|
||||
result = await image_generate_tool(
|
||||
prompt="A serene mountain landscape with cherry blossoms",
|
||||
@@ -28,14 +28,12 @@ Usage:
|
||||
)
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import datetime
|
||||
from typing import Dict, Any, Optional, Union
|
||||
import fal_client
|
||||
|
||||
from tools.debug_helpers import DebugSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -53,7 +51,11 @@ ENABLE_SAFETY_CHECKER = False
|
||||
SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive)
|
||||
|
||||
# Aspect ratio mapping - simplified choices for model to select
|
||||
ASPECT_RATIO_MAP = {"landscape": "landscape_16_9", "square": "square_hd", "portrait": "portrait_16_9"}
|
||||
ASPECT_RATIO_MAP = {
|
||||
"landscape": "landscape_16_9",
|
||||
"square": "square_hd",
|
||||
"portrait": "portrait_16_9"
|
||||
}
|
||||
VALID_ASPECT_RATIOS = list(ASPECT_RATIO_MAP.keys())
|
||||
|
||||
# Configuration for automatic upscaling
|
||||
@@ -68,7 +70,9 @@ UPSCALER_GUIDANCE_SCALE = 4
|
||||
UPSCALER_NUM_INFERENCE_STEPS = 18
|
||||
|
||||
# Valid parameter values for validation based on FLUX 2 Pro documentation
|
||||
VALID_IMAGE_SIZES = ["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"]
|
||||
VALID_IMAGE_SIZES = [
|
||||
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
|
||||
]
|
||||
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
|
||||
VALID_ACCELERATION_MODES = ["none", "regular", "high"]
|
||||
|
||||
@@ -76,16 +80,16 @@ _debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG")
|
||||
|
||||
|
||||
def _validate_parameters(
|
||||
image_size: str | dict[str, int],
|
||||
image_size: Union[str, Dict[str, int]],
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
num_images: int,
|
||||
output_format: str,
|
||||
acceleration: str = "none",
|
||||
) -> dict[str, Any]:
|
||||
acceleration: str = "none"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate and normalize image generation parameters for FLUX 2 Pro model.
|
||||
|
||||
|
||||
Args:
|
||||
image_size: Either a preset string or custom size dict
|
||||
num_inference_steps: Number of inference steps
|
||||
@@ -93,15 +97,15 @@ def _validate_parameters(
|
||||
num_images: Number of images to generate
|
||||
output_format: Output format for images
|
||||
acceleration: Acceleration mode for generation speed
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Validated and normalized parameters
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If any parameter is invalid
|
||||
"""
|
||||
validated = {}
|
||||
|
||||
|
||||
# Validate image_size
|
||||
if isinstance(image_size, str):
|
||||
if image_size not in VALID_IMAGE_SIZES:
|
||||
@@ -119,52 +123,52 @@ def _validate_parameters(
|
||||
validated["image_size"] = image_size
|
||||
else:
|
||||
raise ValueError("image_size must be either a preset string or a dict with width/height")
|
||||
|
||||
|
||||
# Validate num_inference_steps
|
||||
if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
|
||||
raise ValueError("num_inference_steps must be an integer between 1 and 100")
|
||||
validated["num_inference_steps"] = num_inference_steps
|
||||
|
||||
|
||||
# Validate guidance_scale (FLUX 2 Pro default is 4.5)
|
||||
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
|
||||
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
|
||||
validated["guidance_scale"] = float(guidance_scale)
|
||||
|
||||
|
||||
# Validate num_images
|
||||
if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
|
||||
raise ValueError("num_images must be an integer between 1 and 4")
|
||||
validated["num_images"] = num_images
|
||||
|
||||
|
||||
# Validate output_format
|
||||
if output_format not in VALID_OUTPUT_FORMATS:
|
||||
raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
|
||||
validated["output_format"] = output_format
|
||||
|
||||
|
||||
# Validate acceleration
|
||||
if acceleration not in VALID_ACCELERATION_MODES:
|
||||
raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
|
||||
validated["acceleration"] = acceleration
|
||||
|
||||
|
||||
return validated
|
||||
|
||||
|
||||
def _upscale_image(image_url: str, original_prompt: str) -> dict[str, Any]:
|
||||
def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Upscale an image using FAL.ai's Clarity Upscaler.
|
||||
|
||||
|
||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues
|
||||
when called from threaded contexts (e.g. gateway thread pool).
|
||||
|
||||
|
||||
Args:
|
||||
image_url (str): URL of the image to upscale
|
||||
original_prompt (str): Original prompt used to generate the image
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Upscaled image data or None if upscaling fails
|
||||
"""
|
||||
try:
|
||||
logger.info("Upscaling image with Clarity Upscaler...")
|
||||
|
||||
|
||||
# Prepare arguments for upscaler
|
||||
upscaler_arguments = {
|
||||
"image_url": image_url,
|
||||
@@ -175,36 +179,35 @@ def _upscale_image(image_url: str, original_prompt: str) -> dict[str, Any]:
|
||||
"resemblance": UPSCALER_RESEMBLANCE,
|
||||
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
|
||||
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
|
||||
"enable_safety_checker": UPSCALER_SAFETY_CHECKER,
|
||||
"enable_safety_checker": UPSCALER_SAFETY_CHECKER
|
||||
}
|
||||
|
||||
|
||||
# Use sync API — fal_client.submit() uses httpx.Client (no event loop).
|
||||
# The async API (submit_async) caches a global httpx.AsyncClient via
|
||||
# @cached_property, which breaks when asyncio.run() destroys the loop
|
||||
# between calls (gateway thread-pool pattern).
|
||||
handler = fal_client.submit(UPSCALER_MODEL, arguments=upscaler_arguments)
|
||||
|
||||
handler = fal_client.submit(
|
||||
UPSCALER_MODEL,
|
||||
arguments=upscaler_arguments
|
||||
)
|
||||
|
||||
# Get the upscaled result (sync — blocks until done)
|
||||
result = handler.get()
|
||||
|
||||
|
||||
if result and "image" in result:
|
||||
upscaled_image = result["image"]
|
||||
logger.info(
|
||||
"Image upscaled successfully to %sx%s",
|
||||
upscaled_image.get("width", "unknown"),
|
||||
upscaled_image.get("height", "unknown"),
|
||||
)
|
||||
logger.info("Image upscaled successfully to %sx%s", upscaled_image.get('width', 'unknown'), upscaled_image.get('height', 'unknown'))
|
||||
return {
|
||||
"url": upscaled_image["url"],
|
||||
"width": upscaled_image.get("width", 0),
|
||||
"height": upscaled_image.get("height", 0),
|
||||
"upscaled": True,
|
||||
"upscale_factor": UPSCALER_FACTOR,
|
||||
"upscale_factor": UPSCALER_FACTOR
|
||||
}
|
||||
else:
|
||||
logger.error("Upscaler returned invalid response")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error upscaling image: %s", e)
|
||||
return None
|
||||
@@ -217,16 +220,16 @@ def image_generate_tool(
|
||||
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
|
||||
num_images: int = DEFAULT_NUM_IMAGES,
|
||||
output_format: str = DEFAULT_OUTPUT_FORMAT,
|
||||
seed: int | None = None,
|
||||
seed: Optional[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.
|
||||
|
||||
|
||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues.
|
||||
The async API's global httpx.AsyncClient (cached via @cached_property) breaks
|
||||
when asyncio.run() destroys and recreates event loops between calls, which
|
||||
happens in the gateway's thread-pool pattern.
|
||||
|
||||
|
||||
Args:
|
||||
prompt (str): The text prompt describing the desired image
|
||||
aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape")
|
||||
@@ -235,7 +238,7 @@ def image_generate_tool(
|
||||
num_images (int): Number of images to generate (1-4, default: 1)
|
||||
output_format (str): Image format "jpeg" or "png" (default: "png")
|
||||
seed (Optional[int]): Random seed for reproducible results (optional)
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON string containing minimal generation results:
|
||||
{
|
||||
@@ -249,7 +252,7 @@ def image_generate_tool(
|
||||
logger.warning("Invalid aspect_ratio '%s', defaulting to '%s'", aspect_ratio, DEFAULT_ASPECT_RATIO)
|
||||
aspect_ratio_lower = DEFAULT_ASPECT_RATIO
|
||||
image_size = ASPECT_RATIO_MAP[aspect_ratio_lower]
|
||||
|
||||
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"prompt": prompt,
|
||||
@@ -259,32 +262,32 @@ def image_generate_tool(
|
||||
"guidance_scale": guidance_scale,
|
||||
"num_images": num_images,
|
||||
"output_format": output_format,
|
||||
"seed": seed,
|
||||
"seed": seed
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"images_generated": 0,
|
||||
"generation_time": 0,
|
||||
"generation_time": 0
|
||||
}
|
||||
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
|
||||
try:
|
||||
logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80])
|
||||
|
||||
|
||||
# Validate prompt
|
||||
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required and must be a non-empty string")
|
||||
|
||||
|
||||
# Check API key availability
|
||||
if not os.getenv("FAL_KEY"):
|
||||
raise ValueError("FAL_KEY environment variable not set")
|
||||
|
||||
|
||||
# Validate other parameters
|
||||
validated_params = _validate_parameters(
|
||||
image_size, num_inference_steps, guidance_scale, num_images, output_format, "none"
|
||||
)
|
||||
|
||||
|
||||
# Prepare arguments for FAL.ai FLUX 2 Pro API
|
||||
arguments = {
|
||||
"prompt": prompt.strip(),
|
||||
@@ -295,44 +298,51 @@ def image_generate_tool(
|
||||
"output_format": validated_params["output_format"],
|
||||
"enable_safety_checker": ENABLE_SAFETY_CHECKER,
|
||||
"safety_tolerance": SAFETY_TOLERANCE,
|
||||
"sync_mode": True, # Use sync mode for immediate results
|
||||
"sync_mode": True # Use sync mode for immediate results
|
||||
}
|
||||
|
||||
|
||||
# Add seed if provided
|
||||
if seed is not None and isinstance(seed, int):
|
||||
arguments["seed"] = seed
|
||||
|
||||
|
||||
logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...")
|
||||
logger.info(" Model: %s", DEFAULT_MODEL)
|
||||
logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size)
|
||||
logger.info(" Steps: %s", validated_params["num_inference_steps"])
|
||||
logger.info(" Guidance: %s", validated_params["guidance_scale"])
|
||||
|
||||
logger.info(" Steps: %s", validated_params['num_inference_steps'])
|
||||
logger.info(" Guidance: %s", validated_params['guidance_scale'])
|
||||
|
||||
# Submit request to FAL.ai using sync API (avoids cached event loop issues)
|
||||
handler = fal_client.submit(DEFAULT_MODEL, arguments=arguments)
|
||||
|
||||
handler = fal_client.submit(
|
||||
DEFAULT_MODEL,
|
||||
arguments=arguments
|
||||
)
|
||||
|
||||
# Get the result (sync — blocks until done)
|
||||
result = handler.get()
|
||||
|
||||
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
|
||||
|
||||
# Process the response
|
||||
if not result or "images" not in result:
|
||||
raise ValueError("Invalid response from FAL.ai API - no images returned")
|
||||
|
||||
|
||||
images = result.get("images", [])
|
||||
if not images:
|
||||
raise ValueError("No images were generated")
|
||||
|
||||
|
||||
# Format image data and upscale images
|
||||
formatted_images = []
|
||||
for img in images:
|
||||
if isinstance(img, dict) and "url" in img:
|
||||
original_image = {"url": img["url"], "width": img.get("width", 0), "height": img.get("height", 0)}
|
||||
|
||||
original_image = {
|
||||
"url": img["url"],
|
||||
"width": img.get("width", 0),
|
||||
"height": img.get("height", 0)
|
||||
}
|
||||
|
||||
# Attempt to upscale the image
|
||||
upscaled_image = _upscale_image(img["url"], prompt.strip())
|
||||
|
||||
|
||||
if upscaled_image:
|
||||
# Use upscaled image if successful
|
||||
formatted_images.append(upscaled_image)
|
||||
@@ -341,48 +351,52 @@ def image_generate_tool(
|
||||
logger.warning("Using original image as fallback")
|
||||
original_image["upscaled"] = False
|
||||
formatted_images.append(original_image)
|
||||
|
||||
|
||||
if not formatted_images:
|
||||
raise ValueError("No valid image URLs returned from API")
|
||||
|
||||
|
||||
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
|
||||
logger.info(
|
||||
"Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count
|
||||
)
|
||||
|
||||
logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count)
|
||||
|
||||
# Prepare successful response - minimal format
|
||||
response_data = {"success": True, "image": formatted_images[0]["url"] if formatted_images else None}
|
||||
|
||||
response_data = {
|
||||
"success": True,
|
||||
"image": formatted_images[0]["url"] if formatted_images else None
|
||||
}
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["images_generated"] = len(formatted_images)
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("image_generate_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
error_msg = f"Error generating image: {str(e)}"
|
||||
logger.error("%s", error_msg)
|
||||
|
||||
|
||||
# Prepare error response - minimal format
|
||||
response_data = {"success": False, "image": None}
|
||||
|
||||
response_data = {
|
||||
"success": False,
|
||||
"image": None
|
||||
}
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
_debug.log_call("image_generate_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_fal_api_key() -> bool:
|
||||
"""
|
||||
Check if the FAL.ai API key is available in environment variables.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if API key is set, False otherwise
|
||||
"""
|
||||
@@ -392,7 +406,7 @@ def check_fal_api_key() -> bool:
|
||||
def check_image_generation_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for image generation tools are met.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
@@ -400,20 +414,19 @@ def check_image_generation_requirements() -> bool:
|
||||
# Check API key
|
||||
if not check_fal_api_key():
|
||||
return False
|
||||
|
||||
|
||||
# Check if fal_client is available
|
||||
import fal_client
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_debug_session_info() -> dict[str, Any]:
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
@@ -426,10 +439,10 @@ if __name__ == "__main__":
|
||||
"""
|
||||
print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_fal_api_key()
|
||||
|
||||
|
||||
if not api_available:
|
||||
print("❌ FAL_KEY environment variable not set")
|
||||
print("Please set your API key: export FAL_KEY='your-key-here'")
|
||||
@@ -437,28 +450,27 @@ if __name__ == "__main__":
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ FAL.ai API key found")
|
||||
|
||||
|
||||
# Check if fal_client is available
|
||||
try:
|
||||
import fal_client
|
||||
|
||||
print("✅ fal_client library available")
|
||||
except ImportError:
|
||||
print("❌ fal_client library not found")
|
||||
print("Please install: pip install fal-client")
|
||||
exit(1)
|
||||
|
||||
|
||||
print("🛠️ Image generation tools ready for use!")
|
||||
print(f"🤖 Using model: {DEFAULT_MODEL}")
|
||||
print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")
|
||||
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/image_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from image_generation_tool import image_generate_tool")
|
||||
print(" import asyncio")
|
||||
@@ -472,23 +484,23 @@ if __name__ == "__main__":
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
|
||||
print("\nSupported image sizes:")
|
||||
for size in VALID_IMAGE_SIZES:
|
||||
print(f" - {size}")
|
||||
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
|
||||
|
||||
|
||||
print("\nAcceleration modes:")
|
||||
for mode in VALID_ACCELERATION_MODES:
|
||||
print(f" - {mode}")
|
||||
|
||||
|
||||
print("\nExample prompts:")
|
||||
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
|
||||
print(" - 'Modern architecture building with glass facade, sunset lighting'")
|
||||
print(" - 'Abstract art with vibrant colors and geometric patterns'")
|
||||
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
|
||||
print(" - 'Futuristic cityscape with flying cars and neon lights'")
|
||||
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export IMAGE_TOOLS_DEBUG=true")
|
||||
@@ -509,17 +521,17 @@ IMAGE_GENERATE_SCHEMA = {
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The text prompt describing the desired image. Be detailed and descriptive.",
|
||||
"description": "The text prompt describing the desired image. Be detailed and descriptive."
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["landscape", "square", "portrait"],
|
||||
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
|
||||
"default": "landscape",
|
||||
},
|
||||
"default": "landscape"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"],
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -91,11 +91,9 @@ _MCP_SAMPLING_TYPES = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
_MCP_AVAILABLE = True
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
_MCP_HTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
@@ -110,7 +108,6 @@ try:
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
|
||||
_MCP_SAMPLING_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
||||
@@ -121,36 +118,27 @@ except ImportError:
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
|
||||
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
|
||||
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
|
||||
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
|
||||
_MAX_RECONNECT_RETRIES = 5
|
||||
_MAX_BACKOFF_SECONDS = 60
|
||||
|
||||
# Environment variables that are safe to pass to stdio subprocesses
|
||||
_SAFE_ENV_KEYS = frozenset(
|
||||
{
|
||||
"PATH",
|
||||
"HOME",
|
||||
"USER",
|
||||
"LANG",
|
||||
"LC_ALL",
|
||||
"TERM",
|
||||
"SHELL",
|
||||
"TMPDIR",
|
||||
}
|
||||
)
|
||||
_SAFE_ENV_KEYS = frozenset({
|
||||
"PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
|
||||
})
|
||||
|
||||
# Regex for credential patterns to strip from error messages
|
||||
_CREDENTIAL_PATTERN = re.compile(
|
||||
r"(?:"
|
||||
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
|
||||
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
|
||||
r"|Bearer\s+\S+" # Bearer token
|
||||
r"|token=[^\s&,;\"']{1,255}" # token=...
|
||||
r"|key=[^\s&,;\"']{1,255}" # key=...
|
||||
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
|
||||
r"|password=[^\s&,;\"']{1,255}" # password=...
|
||||
r"|secret=[^\s&,;\"']{1,255}" # secret=...
|
||||
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
|
||||
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
|
||||
r"|Bearer\s+\S+" # Bearer token
|
||||
r"|token=[^\s&,;\"']{1,255}" # token=...
|
||||
r"|key=[^\s&,;\"']{1,255}" # key=...
|
||||
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
|
||||
r"|password=[^\s&,;\"']{1,255}" # password=...
|
||||
r"|secret=[^\s&,;\"']{1,255}" # secret=...
|
||||
r")",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
@@ -160,8 +148,7 @@ _CREDENTIAL_PATTERN = re.compile(
|
||||
# Security helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_safe_env(user_env: dict | None) -> dict:
|
||||
def _build_safe_env(user_env: Optional[dict]) -> dict:
|
||||
"""Build a filtered environment dict for stdio subprocesses.
|
||||
|
||||
Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_*
|
||||
@@ -193,7 +180,6 @@ def _sanitize_error(text: str) -> str:
|
||||
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _safe_numeric(value, default, coerce=int, minimum=1):
|
||||
"""Coerce a config value to a numeric type, returning *default* on failure.
|
||||
|
||||
@@ -230,22 +216,18 @@ class SamplingHandler:
|
||||
self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
|
||||
self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
|
||||
self.max_tool_rounds = _safe_numeric(
|
||||
config.get("max_tool_rounds", 5),
|
||||
5,
|
||||
int,
|
||||
minimum=0,
|
||||
config.get("max_tool_rounds", 5), 5, int, minimum=0,
|
||||
)
|
||||
self.model_override = config.get("model")
|
||||
self.allowed_models = config.get("allowed_models", [])
|
||||
|
||||
_log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
|
||||
self.audit_level = _log_levels.get(
|
||||
str(config.get("log_level", "info")).lower(),
|
||||
logging.INFO,
|
||||
str(config.get("log_level", "info")).lower(), logging.INFO,
|
||||
)
|
||||
|
||||
# Per-instance state
|
||||
self._rate_timestamps: list[float] = []
|
||||
self._rate_timestamps: List[float] = []
|
||||
self._tool_loop_count = 0
|
||||
self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
|
||||
|
||||
@@ -263,7 +245,7 @@ class SamplingHandler:
|
||||
|
||||
# -- Model resolution ----------------------------------------------------
|
||||
|
||||
def _resolve_model(self, preferences) -> str | None:
|
||||
def _resolve_model(self, preferences) -> Optional[str]:
|
||||
"""Config override > server hint > None (use default)."""
|
||||
if self.model_override:
|
||||
return self.model_override
|
||||
@@ -283,7 +265,7 @@ class SamplingHandler:
|
||||
items = block.content if isinstance(block.content, list) else [block.content]
|
||||
return "\n".join(item.text for item in items if hasattr(item, "text"))
|
||||
|
||||
def _convert_messages(self, params) -> list[dict]:
|
||||
def _convert_messages(self, params) -> List[dict]:
|
||||
"""Convert MCP SamplingMessages to OpenAI format.
|
||||
|
||||
Uses ``msg.content_as_list`` (SDK helper) so single-block and
|
||||
@@ -291,47 +273,37 @@ class SamplingHandler:
|
||||
with ``isinstance`` on real SDK types when available, falling back
|
||||
to duck-typing via ``hasattr`` for compatibility.
|
||||
"""
|
||||
messages: list[dict] = []
|
||||
messages: List[dict] = []
|
||||
for msg in params.messages:
|
||||
blocks = (
|
||||
msg.content_as_list
|
||||
if hasattr(msg, "content_as_list")
|
||||
else (msg.content if isinstance(msg.content, list) else [msg.content])
|
||||
blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
|
||||
msg.content if isinstance(msg.content, list) else [msg.content]
|
||||
)
|
||||
|
||||
# Separate blocks by kind
|
||||
tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
|
||||
tool_uses = [
|
||||
b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")
|
||||
]
|
||||
content_blocks = [
|
||||
b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))
|
||||
]
|
||||
tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
|
||||
content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]
|
||||
|
||||
# Emit tool result messages (role: tool)
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.toolUseId,
|
||||
"content": self._extract_tool_result_text(tr),
|
||||
}
|
||||
)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.toolUseId,
|
||||
"content": self._extract_tool_result_text(tr),
|
||||
})
|
||||
|
||||
# Emit assistant tool_calls message
|
||||
if tool_uses:
|
||||
tc_list = []
|
||||
for tu in tool_uses:
|
||||
tc_list.append(
|
||||
{
|
||||
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tu.name,
|
||||
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
|
||||
},
|
||||
}
|
||||
)
|
||||
tc_list.append({
|
||||
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tu.name,
|
||||
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
|
||||
},
|
||||
})
|
||||
msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
|
||||
# Include any accompanying text
|
||||
text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
|
||||
@@ -348,12 +320,10 @@ class SamplingHandler:
|
||||
if hasattr(block, "text"):
|
||||
parts.append({"type": "text", "text": block.text})
|
||||
elif hasattr(block, "data") and hasattr(block, "mimeType"):
|
||||
parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
|
||||
}
|
||||
)
|
||||
parts.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
|
||||
})
|
||||
else:
|
||||
logger.warning(
|
||||
"Unsupported sampling content block type: %s (skipped)",
|
||||
@@ -382,13 +352,16 @@ class SamplingHandler:
|
||||
# Tool loop governance
|
||||
if self.max_tool_rounds == 0:
|
||||
self._tool_loop_count = 0
|
||||
return self._error(f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)")
|
||||
return self._error(
|
||||
f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
|
||||
)
|
||||
|
||||
self._tool_loop_count += 1
|
||||
if self._tool_loop_count > self.max_tool_rounds:
|
||||
self._tool_loop_count = 0
|
||||
return self._error(
|
||||
f"Tool loop limit exceeded for server '{self.server_name}' (max {self.max_tool_rounds} rounds)"
|
||||
f"Tool loop limit exceeded for server '{self.server_name}' "
|
||||
f"(max {self.max_tool_rounds} rounds)"
|
||||
)
|
||||
|
||||
content_blocks = []
|
||||
@@ -399,28 +372,25 @@ class SamplingHandler:
|
||||
parsed = json.loads(args)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.warning(
|
||||
"MCP server '%s': malformed tool_calls arguments from LLM (wrapping as raw): %.100s",
|
||||
self.server_name,
|
||||
args,
|
||||
"MCP server '%s': malformed tool_calls arguments "
|
||||
"from LLM (wrapping as raw): %.100s",
|
||||
self.server_name, args,
|
||||
)
|
||||
parsed = {"_raw": args}
|
||||
else:
|
||||
parsed = args if isinstance(args, dict) else {"_raw": str(args)}
|
||||
|
||||
content_blocks.append(
|
||||
ToolUseContent(
|
||||
type="tool_use",
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
input=parsed,
|
||||
)
|
||||
)
|
||||
content_blocks.append(ToolUseContent(
|
||||
type="tool_use",
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
input=parsed,
|
||||
))
|
||||
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
|
||||
self.server_name,
|
||||
response.model,
|
||||
self.server_name, response.model,
|
||||
getattr(getattr(response, "usage", None), "total_tokens", "?"),
|
||||
len(content_blocks),
|
||||
)
|
||||
@@ -440,8 +410,7 @@ class SamplingHandler:
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling response: model=%s, tokens=%s",
|
||||
self.server_name,
|
||||
response.model,
|
||||
self.server_name, response.model,
|
||||
getattr(getattr(response, "usage", None), "total_tokens", "?"),
|
||||
)
|
||||
|
||||
@@ -476,12 +445,12 @@ class SamplingHandler:
|
||||
if not self._check_rate_limit():
|
||||
logger.warning(
|
||||
"MCP server '%s' sampling rate limit exceeded (%d/min)",
|
||||
self.server_name,
|
||||
self.max_rpm,
|
||||
self.server_name, self.max_rpm,
|
||||
)
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling rate limit exceeded for server '{self.server_name}' ({self.max_rpm} requests/minute)"
|
||||
f"Sampling rate limit exceeded for server '{self.server_name}' "
|
||||
f"({self.max_rpm} requests/minute)"
|
||||
)
|
||||
|
||||
# Resolve model
|
||||
@@ -489,7 +458,6 @@ class SamplingHandler:
|
||||
|
||||
# Get auxiliary LLM client
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
|
||||
client, default_model = get_text_auxiliary_client()
|
||||
if client is None:
|
||||
self.metrics["errors"] += 1
|
||||
@@ -501,8 +469,7 @@ class SamplingHandler:
|
||||
if self.allowed_models and resolved_model not in self.allowed_models:
|
||||
logger.warning(
|
||||
"MCP server '%s' requested model '%s' not in allowed_models",
|
||||
self.server_name,
|
||||
resolved_model,
|
||||
self.server_name, resolved_model,
|
||||
)
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
@@ -548,10 +515,7 @@ class SamplingHandler:
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
|
||||
self.server_name,
|
||||
resolved_model,
|
||||
max_tokens,
|
||||
len(messages),
|
||||
self.server_name, resolved_model, max_tokens, len(messages),
|
||||
)
|
||||
|
||||
# Offload sync LLM call to thread (non-blocking)
|
||||
@@ -560,15 +524,19 @@ class SamplingHandler:
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
asyncio.to_thread(_sync_call),
|
||||
timeout=self.timeout,
|
||||
asyncio.to_thread(_sync_call), timeout=self.timeout,
|
||||
)
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(f"Sampling LLM call timed out after {self.timeout}s for server '{self.server_name}'")
|
||||
return self._error(
|
||||
f"Sampling LLM call timed out after {self.timeout}s "
|
||||
f"for server '{self.server_name}'"
|
||||
)
|
||||
except Exception as exc:
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(f"Sampling LLM call failed: {_sanitize_error(str(exc))}")
|
||||
return self._error(
|
||||
f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
|
||||
)
|
||||
|
||||
# Track metrics
|
||||
choice = response.choices[0]
|
||||
@@ -578,7 +546,11 @@ class SamplingHandler:
|
||||
self.metrics["tokens_used"] += total_tokens
|
||||
|
||||
# Dispatch based on response type
|
||||
if choice.finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
|
||||
if (
|
||||
choice.finish_reason == "tool_calls"
|
||||
and hasattr(choice.message, "tool_calls")
|
||||
and choice.message.tool_calls
|
||||
):
|
||||
return self._build_tool_use_result(choice, response)
|
||||
|
||||
return self._build_text_result(choice, response)
|
||||
@@ -588,7 +560,6 @@ class SamplingHandler:
|
||||
# Server task -- each MCP server lives in one long-lived asyncio Task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MCPServerTask:
|
||||
"""Manages a single MCP server connection in a dedicated asyncio Task.
|
||||
|
||||
@@ -600,29 +571,22 @@ class MCPServerTask:
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"name",
|
||||
"session",
|
||||
"tool_timeout",
|
||||
"_task",
|
||||
"_ready",
|
||||
"_shutdown_event",
|
||||
"_tools",
|
||||
"_error",
|
||||
"_config",
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_sampling",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.session: Any | None = None
|
||||
self.session: Optional[Any] = None
|
||||
self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
|
||||
self._task: asyncio.Task | None = None
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._ready = asyncio.Event()
|
||||
self._shutdown_event = asyncio.Event()
|
||||
self._tools: list = []
|
||||
self._error: Exception | None = None
|
||||
self._error: Optional[Exception] = None
|
||||
self._config: dict = {}
|
||||
self._sampling: SamplingHandler | None = None
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
@@ -635,7 +599,9 @@ class MCPServerTask:
|
||||
user_env = config.get("env")
|
||||
|
||||
if not command:
|
||||
raise ValueError(f"MCP server '{self.name}' has no 'command' in config")
|
||||
raise ValueError(
|
||||
f"MCP server '{self.name}' has no 'command' in config"
|
||||
)
|
||||
|
||||
safe_env = _build_safe_env(user_env)
|
||||
server_params = StdioServerParameters(
|
||||
@@ -684,7 +650,11 @@ class MCPServerTask:
|
||||
if self.session is None:
|
||||
return
|
||||
tools_result = await self.session.list_tools()
|
||||
self._tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
||||
self._tools = (
|
||||
tools_result.tools
|
||||
if hasattr(tools_result, "tools")
|
||||
else []
|
||||
)
|
||||
|
||||
async def run(self, config: dict):
|
||||
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
|
||||
@@ -734,28 +704,24 @@ class MCPServerTask:
|
||||
if self._shutdown_event.is_set():
|
||||
logger.debug(
|
||||
"MCP server '%s' disconnected during shutdown: %s",
|
||||
self.name,
|
||||
exc,
|
||||
self.name, exc,
|
||||
)
|
||||
return
|
||||
|
||||
retries += 1
|
||||
if retries > _MAX_RECONNECT_RETRIES:
|
||||
logger.warning(
|
||||
"MCP server '%s' failed after %d reconnection attempts, giving up: %s",
|
||||
self.name,
|
||||
_MAX_RECONNECT_RETRIES,
|
||||
exc,
|
||||
"MCP server '%s' failed after %d reconnection attempts, "
|
||||
"giving up: %s",
|
||||
self.name, _MAX_RECONNECT_RETRIES, exc,
|
||||
)
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"MCP server '%s' connection lost (attempt %d/%d), reconnecting in %.0fs: %s",
|
||||
self.name,
|
||||
retries,
|
||||
_MAX_RECONNECT_RETRIES,
|
||||
backoff,
|
||||
exc,
|
||||
"MCP server '%s' connection lost (attempt %d/%d), "
|
||||
"reconnecting in %.0fs: %s",
|
||||
self.name, retries, _MAX_RECONNECT_RETRIES,
|
||||
backoff, exc,
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
|
||||
@@ -779,7 +745,7 @@ class MCPServerTask:
|
||||
if self._task and not self._task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self._task, timeout=10)
|
||||
except TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"MCP server '%s' shutdown timed out, cancelling task",
|
||||
self.name,
|
||||
@@ -796,11 +762,11 @@ class MCPServerTask:
|
||||
# Module-level state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_servers: dict[str, MCPServerTask] = {}
|
||||
_servers: Dict[str, MCPServerTask] = {}
|
||||
|
||||
# Dedicated event loop running in a background daemon thread.
|
||||
_mcp_loop: asyncio.AbstractEventLoop | None = None
|
||||
_mcp_thread: threading.Thread | None = None
|
||||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
|
||||
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
|
||||
_lock = threading.Lock()
|
||||
@@ -835,8 +801,7 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
|
||||
# Config loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_mcp_config() -> dict[str, dict]:
|
||||
def _load_mcp_config() -> Dict[str, dict]:
|
||||
"""Read ``mcp_servers`` from the Hermes config file.
|
||||
|
||||
Returns a dict of ``{server_name: server_config}`` or empty dict.
|
||||
@@ -846,7 +811,6 @@ def _load_mcp_config() -> dict[str, dict]:
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers")
|
||||
if not servers or not isinstance(servers, dict):
|
||||
@@ -861,7 +825,6 @@ def _load_mcp_config() -> dict[str, dict]:
|
||||
# Server connection helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _connect_server(name: str, config: dict) -> MCPServerTask:
|
||||
"""Create an MCPServerTask, start it, and return when ready.
|
||||
|
||||
@@ -882,7 +845,6 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask:
|
||||
# Handler / check-fn factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
"""Return a sync handler that calls an MCP tool via the background loop.
|
||||
|
||||
@@ -894,21 +856,27 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.call_tool(tool_name, arguments=args)
|
||||
# MCP CallToolResult has .content (list of content blocks) and .isError
|
||||
if result.isError:
|
||||
error_text = ""
|
||||
for block in result.content or []:
|
||||
for block in (result.content or []):
|
||||
if hasattr(block, "text"):
|
||||
error_text += block.text
|
||||
return json.dumps({"error": _sanitize_error(error_text or "MCP tool returned an error")})
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
error_text or "MCP tool returned an error"
|
||||
)
|
||||
})
|
||||
|
||||
# Collect text from content blocks
|
||||
parts: list[str] = []
|
||||
for block in result.content or []:
|
||||
parts: List[str] = []
|
||||
for block in (result.content or []):
|
||||
if hasattr(block, "text"):
|
||||
parts.append(block.text)
|
||||
return json.dumps({"result": "\n".join(parts) if parts else ""})
|
||||
@@ -918,11 +886,13 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP tool %s/%s call failed: %s",
|
||||
server_name,
|
||||
tool_name,
|
||||
exc,
|
||||
server_name, tool_name, exc,
|
||||
)
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -934,12 +904,14 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_resources()
|
||||
resources = []
|
||||
for r in result.resources if hasattr(result, "resources") else []:
|
||||
for r in (result.resources if hasattr(result, "resources") else []):
|
||||
entry = {}
|
||||
if hasattr(r, "uri"):
|
||||
entry["uri"] = str(r.uri)
|
||||
@@ -956,11 +928,13 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/list_resources failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
"MCP %s/list_resources failed: %s", server_name, exc,
|
||||
)
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -972,7 +946,9 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
uri = args.get("uri")
|
||||
if not uri:
|
||||
@@ -981,7 +957,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
||||
async def _call():
|
||||
result = await server.session.read_resource(uri)
|
||||
# read_resource returns ReadResourceResult with .contents list
|
||||
parts: list[str] = []
|
||||
parts: List[str] = []
|
||||
contents = result.contents if hasattr(result, "contents") else []
|
||||
for block in contents:
|
||||
if hasattr(block, "text"):
|
||||
@@ -994,11 +970,13 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/read_resource failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
"MCP %s/read_resource failed: %s", server_name, exc,
|
||||
)
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -1010,12 +988,14 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_prompts()
|
||||
prompts = []
|
||||
for p in result.prompts if hasattr(result, "prompts") else []:
|
||||
for p in (result.prompts if hasattr(result, "prompts") else []):
|
||||
entry = {}
|
||||
if hasattr(p, "name"):
|
||||
entry["name"] = p.name
|
||||
@@ -1037,11 +1017,13 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/list_prompts failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
"MCP %s/list_prompts failed: %s", server_name, exc,
|
||||
)
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -1053,7 +1035,9 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
name = args.get("name")
|
||||
if not name:
|
||||
@@ -1064,7 +1048,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
||||
result = await server.session.get_prompt(name, arguments=arguments)
|
||||
# GetPromptResult has .messages list
|
||||
messages = []
|
||||
for msg in result.messages if hasattr(result, "messages") else []:
|
||||
for msg in (result.messages if hasattr(result, "messages") else []):
|
||||
entry = {}
|
||||
if hasattr(msg, "role"):
|
||||
entry["role"] = msg.role
|
||||
@@ -1086,11 +1070,13 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/get_prompt failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
"MCP %s/get_prompt failed: %s", server_name, exc,
|
||||
)
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -1110,7 +1096,6 @@ def _make_check_fn(server_name: str):
|
||||
# Discovery & registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
"""Convert an MCP tool listing to the Hermes registry schema format.
|
||||
|
||||
@@ -1129,16 +1114,14 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
return {
|
||||
"name": prefixed_name,
|
||||
"description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
|
||||
"parameters": mcp_tool.inputSchema
|
||||
if mcp_tool.inputSchema
|
||||
else {
|
||||
"parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_utility_schemas(server_name: str) -> list[dict]:
|
||||
def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
"""Build schemas for the MCP utility tools (resources & prompts).
|
||||
|
||||
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
|
||||
@@ -1209,9 +1192,9 @@ def _build_utility_schemas(server_name: str) -> list[dict]:
|
||||
]
|
||||
|
||||
|
||||
def _existing_tool_names() -> list[str]:
|
||||
def _existing_tool_names() -> List[str]:
|
||||
"""Return tool names for all currently connected servers."""
|
||||
names: list[str] = []
|
||||
names: List[str] = []
|
||||
for sname, server in _servers.items():
|
||||
for mcp_tool in server._tools:
|
||||
schema = _convert_mcp_schema(sname, mcp_tool)
|
||||
@@ -1222,7 +1205,7 @@ def _existing_tool_names() -> list[str]:
|
||||
return names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> list[str]:
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
"""Connect to a single MCP server, discover tools, and register them.
|
||||
|
||||
Also registers utility tools for MCP Resources and Prompts support
|
||||
@@ -1241,7 +1224,7 @@ async def _discover_and_register_server(name: str, config: dict) -> list[str]:
|
||||
with _lock:
|
||||
_servers[name] = server
|
||||
|
||||
registered_names: list[str] = []
|
||||
registered_names: List[str] = []
|
||||
toolset_name = f"mcp-{name}"
|
||||
|
||||
for mcp_tool in server._tools:
|
||||
@@ -1294,9 +1277,7 @@ async def _discover_and_register_server(name: str, config: dict) -> list[str]:
|
||||
transport_type = "HTTP" if "url" in config else "stdio"
|
||||
logger.info(
|
||||
"MCP server '%s' (%s): registered %d tool(s): %s",
|
||||
name,
|
||||
transport_type,
|
||||
len(registered_names),
|
||||
name, transport_type, len(registered_names),
|
||||
", ".join(registered_names),
|
||||
)
|
||||
return registered_names
|
||||
@@ -1306,8 +1287,7 @@ async def _discover_and_register_server(name: str, config: dict) -> list[str]:
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def discover_mcp_tools() -> list[str]:
|
||||
def discover_mcp_tools() -> List[str]:
|
||||
"""Entry point: load config, connect to MCP servers, register tools.
|
||||
|
||||
Called from ``model_tools._discover_tools()``. Safe to call even when
|
||||
@@ -1338,12 +1318,12 @@ def discover_mcp_tools() -> list[str]:
|
||||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
all_tools: list[str] = []
|
||||
all_tools: List[str] = []
|
||||
failed_count = 0
|
||||
|
||||
async def _discover_one(name: str, cfg: dict) -> list[str]:
|
||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||
"""Connect to a single server and return its registered tool names."""
|
||||
transport_desc = cfg.get("url", f"{cfg.get('command', '?')} {' '.join(cfg.get('args', [])[:2])}")
|
||||
transport_desc = cfg.get("url", f'{cfg.get("command", "?")} {" ".join(cfg.get("args", [])[:2])}')
|
||||
try:
|
||||
registered = await _discover_and_register_server(name, cfg)
|
||||
transport_type = "HTTP" if "url" in cfg else "stdio"
|
||||
@@ -1351,8 +1331,7 @@ def discover_mcp_tools() -> list[str]:
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s': %s",
|
||||
name,
|
||||
exc,
|
||||
name, exc,
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -1379,7 +1358,6 @@ def discover_mcp_tools() -> list[str]:
|
||||
if all_tools:
|
||||
# Dynamically inject into all hermes-* platform toolsets
|
||||
from toolsets import TOOLSETS
|
||||
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if ts_name.startswith("hermes-"):
|
||||
for tool_name in all_tools:
|
||||
@@ -1399,13 +1377,13 @@ def discover_mcp_tools() -> list[str]:
|
||||
return _existing_tool_names()
|
||||
|
||||
|
||||
def get_mcp_status() -> list[dict]:
|
||||
def get_mcp_status() -> List[dict]:
|
||||
"""Return status of all configured MCP servers for banner display.
|
||||
|
||||
Returns a list of dicts with keys: name, transport, tools, connected.
|
||||
Includes both successfully connected servers and configured-but-failed ones.
|
||||
"""
|
||||
result: list[dict] = []
|
||||
result: List[dict] = []
|
||||
|
||||
# Get configured servers from config
|
||||
configured = _load_mcp_config()
|
||||
@@ -1429,14 +1407,12 @@ def get_mcp_status() -> list[dict]:
|
||||
entry["sampling"] = dict(server._sampling.metrics)
|
||||
result.append(entry)
|
||||
else:
|
||||
result.append(
|
||||
{
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": 0,
|
||||
"connected": False,
|
||||
}
|
||||
)
|
||||
result.append({
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": 0,
|
||||
"connected": False,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
@@ -1464,9 +1440,7 @@ def shutdown_mcp_servers():
|
||||
for server, result in zip(servers_snapshot, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.debug(
|
||||
"Error closing MCP server '%s': %s",
|
||||
server.name,
|
||||
result,
|
||||
"Error closing MCP server '%s': %s", server.name, result,
|
||||
)
|
||||
with _lock:
|
||||
_servers.clear()
|
||||
|
||||
@@ -29,7 +29,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,38 +46,30 @@ ENTRY_DELIMITER = "\n§\n"
|
||||
|
||||
_MEMORY_THREAT_PATTERNS = [
|
||||
# Prompt injection
|
||||
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
|
||||
(r"you\s+are\s+now\s+", "role_hijack"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
|
||||
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
|
||||
(r'you\s+are\s+now\s+', "role_hijack"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
|
||||
# Exfiltration via curl/wget with secrets
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)", "read_secrets"),
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)', "read_secrets"),
|
||||
# Persistence via shell rc
|
||||
(r"authorized_keys", "ssh_backdoor"),
|
||||
(r"\$HOME/\.ssh|\~/\.ssh", "ssh_access"),
|
||||
(r"\$HOME/\.hermes/\.env|\~/\.hermes/\.env", "hermes_env"),
|
||||
(r'authorized_keys', "ssh_backdoor"),
|
||||
(r'\$HOME/\.ssh|\~/\.ssh', "ssh_access"),
|
||||
(r'\$HOME/\.hermes/\.env|\~/\.hermes/\.env', "hermes_env"),
|
||||
]
|
||||
|
||||
# Subset of invisible chars for injection detection
|
||||
_INVISIBLE_CHARS = {
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
}
|
||||
|
||||
|
||||
def _scan_memory_content(content: str) -> str | None:
|
||||
def _scan_memory_content(content: str) -> Optional[str]:
|
||||
"""Scan memory content for injection/exfil patterns. Returns error string if blocked."""
|
||||
# Check invisible unicode
|
||||
for char in _INVISIBLE_CHARS:
|
||||
@@ -104,12 +96,12 @@ class MemoryStore:
|
||||
"""
|
||||
|
||||
def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375):
|
||||
self.memory_entries: list[str] = []
|
||||
self.user_entries: list[str] = []
|
||||
self.memory_entries: List[str] = []
|
||||
self.user_entries: List[str] = []
|
||||
self.memory_char_limit = memory_char_limit
|
||||
self.user_char_limit = user_char_limit
|
||||
# Frozen snapshot for system prompt -- set once at load_from_disk()
|
||||
self._system_prompt_snapshot: dict[str, str] = {"memory": "", "user": ""}
|
||||
self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": ""}
|
||||
|
||||
def load_from_disk(self):
|
||||
"""Load entries from MEMORY.md and USER.md, capture system prompt snapshot."""
|
||||
@@ -137,12 +129,12 @@ class MemoryStore:
|
||||
elif target == "user":
|
||||
self._write_file(MEMORY_DIR / "USER.md", self.user_entries)
|
||||
|
||||
def _entries_for(self, target: str) -> list[str]:
|
||||
def _entries_for(self, target: str) -> List[str]:
|
||||
if target == "user":
|
||||
return self.user_entries
|
||||
return self.memory_entries
|
||||
|
||||
def _set_entries(self, target: str, entries: list[str]):
|
||||
def _set_entries(self, target: str, entries: List[str]):
|
||||
if target == "user":
|
||||
self.user_entries = entries
|
||||
else:
|
||||
@@ -159,7 +151,7 @@ class MemoryStore:
|
||||
return self.user_char_limit
|
||||
return self.memory_char_limit
|
||||
|
||||
def add(self, target: str, content: str) -> dict[str, Any]:
|
||||
def add(self, target: str, content: str) -> Dict[str, Any]:
|
||||
"""Append a new entry. Returns error if it would exceed the char limit."""
|
||||
content = content.strip()
|
||||
if not content:
|
||||
@@ -200,7 +192,7 @@ class MemoryStore:
|
||||
|
||||
return self._success_response(target, "Entry added.")
|
||||
|
||||
def replace(self, target: str, old_text: str, new_content: str) -> dict[str, Any]:
|
||||
def replace(self, target: str, old_text: str, new_content: str) -> Dict[str, Any]:
|
||||
"""Find entry containing old_text substring, replace it with new_content."""
|
||||
old_text = old_text.strip()
|
||||
new_content = new_content.strip()
|
||||
@@ -255,7 +247,7 @@ class MemoryStore:
|
||||
|
||||
return self._success_response(target, "Entry replaced.")
|
||||
|
||||
def remove(self, target: str, old_text: str) -> dict[str, Any]:
|
||||
def remove(self, target: str, old_text: str) -> Dict[str, Any]:
|
||||
"""Remove the entry containing old_text substring."""
|
||||
old_text = old_text.strip()
|
||||
if not old_text:
|
||||
@@ -286,7 +278,7 @@ class MemoryStore:
|
||||
|
||||
return self._success_response(target, "Entry removed.")
|
||||
|
||||
def format_for_system_prompt(self, target: str) -> str | None:
|
||||
def format_for_system_prompt(self, target: str) -> Optional[str]:
|
||||
"""
|
||||
Return the frozen snapshot for system prompt injection.
|
||||
|
||||
@@ -301,7 +293,7 @@ class MemoryStore:
|
||||
|
||||
# -- Internal helpers --
|
||||
|
||||
def _success_response(self, target: str, message: str = None) -> dict[str, Any]:
|
||||
def _success_response(self, target: str, message: str = None) -> Dict[str, Any]:
|
||||
entries = self._entries_for(target)
|
||||
current = self._char_count(target)
|
||||
limit = self._char_limit(target)
|
||||
@@ -318,7 +310,7 @@ class MemoryStore:
|
||||
resp["message"] = message
|
||||
return resp
|
||||
|
||||
def _render_block(self, target: str, entries: list[str]) -> str:
|
||||
def _render_block(self, target: str, entries: List[str]) -> str:
|
||||
"""Render a system prompt block with header and usage indicator."""
|
||||
if not entries:
|
||||
return ""
|
||||
@@ -337,7 +329,7 @@ class MemoryStore:
|
||||
return f"{separator}\n{header}\n{separator}\n{content}"
|
||||
|
||||
@staticmethod
|
||||
def _read_file(path: Path) -> list[str]:
|
||||
def _read_file(path: Path) -> List[str]:
|
||||
"""Read a memory file and split into entries.
|
||||
|
||||
No file locking needed: _write_file uses atomic rename, so readers
|
||||
@@ -347,7 +339,7 @@ class MemoryStore:
|
||||
return []
|
||||
try:
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
except OSError:
|
||||
except (OSError, IOError):
|
||||
return []
|
||||
|
||||
if not raw.strip():
|
||||
@@ -359,7 +351,7 @@ class MemoryStore:
|
||||
return [e for e in entries if e]
|
||||
|
||||
@staticmethod
|
||||
def _write_file(path: Path, entries: list[str]):
|
||||
def _write_file(path: Path, entries: List[str]):
|
||||
"""Write entries to a memory file using atomic temp-file + rename.
|
||||
|
||||
Previous implementation used open("w") + flock, but "w" truncates the
|
||||
@@ -370,7 +362,9 @@ class MemoryStore:
|
||||
content = ENTRY_DELIMITER.join(entries) if entries else ""
|
||||
try:
|
||||
# Write to temp file in same directory (same filesystem for atomic rename)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".mem_")
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
dir=str(path.parent), suffix=".tmp", prefix=".mem_"
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
@@ -384,7 +378,7 @@ class MemoryStore:
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
except OSError as e:
|
||||
except (OSError, IOError) as e:
|
||||
raise RuntimeError(f"Failed to write memory file {path}: {e}")
|
||||
|
||||
|
||||
@@ -393,7 +387,7 @@ def memory_tool(
|
||||
target: str = "memory",
|
||||
content: str = None,
|
||||
old_text: str = None,
|
||||
store: MemoryStore | None = None,
|
||||
store: Optional[MemoryStore] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Single entry point for the memory tool. Dispatches to MemoryStore methods.
|
||||
@@ -401,15 +395,10 @@ def memory_tool(
|
||||
Returns JSON string with results.
|
||||
"""
|
||||
if store is None:
|
||||
return json.dumps(
|
||||
{"success": False, "error": "Memory is not available. It may be disabled in config or this environment."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
return json.dumps({"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, ensure_ascii=False)
|
||||
|
||||
if target not in ("memory", "user"):
|
||||
return json.dumps(
|
||||
{"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False
|
||||
)
|
||||
return json.dumps({"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False)
|
||||
|
||||
if action == "add":
|
||||
if not content:
|
||||
@@ -418,26 +407,18 @@ def memory_tool(
|
||||
|
||||
elif action == "replace":
|
||||
if not old_text:
|
||||
return json.dumps(
|
||||
{"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False
|
||||
)
|
||||
return json.dumps({"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False)
|
||||
if not content:
|
||||
return json.dumps(
|
||||
{"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False
|
||||
)
|
||||
return json.dumps({"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False)
|
||||
result = store.replace(target, old_text, content)
|
||||
|
||||
elif action == "remove":
|
||||
if not old_text:
|
||||
return json.dumps(
|
||||
{"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False
|
||||
)
|
||||
return json.dumps({"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False)
|
||||
result = store.remove(target, old_text)
|
||||
|
||||
else:
|
||||
return json.dumps(
|
||||
{"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False
|
||||
)
|
||||
return json.dumps({"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
@@ -476,16 +457,23 @@ MEMORY_SCHEMA = {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {"type": "string", "enum": ["add", "replace", "remove"], "description": "The action to perform."},
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["add", "replace", "remove"],
|
||||
"description": "The action to perform."
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["memory", "user"],
|
||||
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile.",
|
||||
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The entry content. Required for 'add' and 'replace'."
|
||||
},
|
||||
"content": {"type": "string", "description": "The entry content. Required for 'add' and 'replace'."},
|
||||
"old_text": {
|
||||
"type": "string",
|
||||
"description": "Short unique substring identifying the entry to replace or remove.",
|
||||
"description": "Short unique substring identifying the entry to replace or remove."
|
||||
},
|
||||
},
|
||||
"required": ["action", "target"],
|
||||
@@ -505,7 +493,10 @@ registry.register(
|
||||
target=args.get("target", "memory"),
|
||||
content=args.get("content"),
|
||||
old_text=args.get("old_text"),
|
||||
store=kw.get("store"),
|
||||
),
|
||||
store=kw.get("store")),
|
||||
check_fn=check_memory_requirements,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -38,27 +38,21 @@ Configuration:
|
||||
Usage:
|
||||
from mixture_of_agents_tool import mixture_of_agents_tool
|
||||
import asyncio
|
||||
|
||||
|
||||
# Process a complex query
|
||||
result = await mixture_of_agents_tool(
|
||||
user_prompt="Solve this complex mathematical proof..."
|
||||
)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key
|
||||
from tools.debug_helpers import DebugSession
|
||||
from tools.openrouter_client import (
|
||||
check_api_key as check_openrouter_api_key,
|
||||
)
|
||||
from tools.openrouter_client import (
|
||||
get_async_client as _get_openrouter_client,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -66,9 +60,9 @@ logger = logging.getLogger(__name__)
|
||||
# Reference models - these generate diverse initial responses in parallel (OpenRouter slugs)
|
||||
REFERENCE_MODELS = [
|
||||
"anthropic/claude-opus-4.5",
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-pro-preview",
|
||||
"openai/gpt-5.2-pro",
|
||||
"deepseek/deepseek-v3.2",
|
||||
"deepseek/deepseek-v3.2"
|
||||
]
|
||||
|
||||
# Aggregator model - synthesizes reference responses into final output
|
||||
@@ -89,18 +83,18 @@ Responses from models:"""
|
||||
_debug = DebugSession("moa_tools", env_var="MOA_TOOLS_DEBUG")
|
||||
|
||||
|
||||
def _construct_aggregator_prompt(system_prompt: str, responses: list[str]) -> str:
|
||||
def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str:
|
||||
"""
|
||||
Construct the final system prompt for the aggregator including all model responses.
|
||||
|
||||
|
||||
Args:
|
||||
system_prompt (str): Base system prompt for aggregation
|
||||
responses (List[str]): List of responses from reference models
|
||||
|
||||
|
||||
Returns:
|
||||
str: Complete system prompt with enumerated responses
|
||||
"""
|
||||
response_text = "\n".join([f"{i + 1}. {response}" for i, response in enumerate(responses)])
|
||||
response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)])
|
||||
return f"{system_prompt}\n\n{response_text}"
|
||||
|
||||
|
||||
@@ -109,43 +103,48 @@ async def _run_reference_model_safe(
|
||||
user_prompt: str,
|
||||
temperature: float = REFERENCE_TEMPERATURE,
|
||||
max_tokens: int = 32000,
|
||||
max_retries: int = 6,
|
||||
max_retries: int = 6
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
Run a single reference model with retry logic and graceful failure handling.
|
||||
|
||||
|
||||
Args:
|
||||
model (str): Model identifier to use
|
||||
user_prompt (str): The user's query
|
||||
temperature (float): Sampling temperature for response generation
|
||||
max_tokens (int): Maximum tokens in response
|
||||
max_retries (int): Maximum number of retry attempts
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[str, str, bool]: (model_name, response_content_or_error, success_flag)
|
||||
"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info("Querying %s (attempt %s/%s)", model, attempt + 1, max_retries)
|
||||
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": user_prompt}],
|
||||
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
|
||||
"extra_body": {
|
||||
"reasoning": {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not model.lower().startswith("gpt-"):
|
||||
if not model.lower().startswith('gpt-'):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("%s responded (%s characters)", model, len(content))
|
||||
return model, content, True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Log more detailed error information for debugging
|
||||
@@ -155,7 +154,7 @@ async def _run_reference_model_safe(
|
||||
logger.warning("%s rate limit error (attempt %s): %s", model, attempt + 1, error_str)
|
||||
else:
|
||||
logger.warning("%s unknown error (attempt %s): %s", model, attempt + 1, error_str)
|
||||
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
# Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s
|
||||
sleep_time = min(2 ** (attempt + 1), 60)
|
||||
@@ -168,47 +167,60 @@ async def _run_reference_model_safe(
|
||||
|
||||
|
||||
async def _run_aggregator_model(
|
||||
system_prompt: str, user_prompt: str, temperature: float = AGGREGATOR_TEMPERATURE, max_tokens: int = None
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float = AGGREGATOR_TEMPERATURE,
|
||||
max_tokens: int = None
|
||||
) -> str:
|
||||
"""
|
||||
Run the aggregator model to synthesize the final response.
|
||||
|
||||
|
||||
Args:
|
||||
system_prompt (str): System prompt with all reference responses
|
||||
user_prompt (str): Original user query
|
||||
temperature (float): Focused temperature for consistent aggregation
|
||||
max_tokens (int): Maximum tokens in final response
|
||||
|
||||
|
||||
Returns:
|
||||
str: Synthesized final response
|
||||
"""
|
||||
logger.info("Running aggregator model: %s", AGGREGATOR_MODEL)
|
||||
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": AGGREGATOR_MODEL,
|
||||
"messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
|
||||
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
"extra_body": {
|
||||
"reasoning": {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not AGGREGATOR_MODEL.lower().startswith("gpt-"):
|
||||
if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("Aggregation complete (%s characters)", len(content))
|
||||
return content
|
||||
|
||||
|
||||
async def mixture_of_agents_tool(
|
||||
user_prompt: str, reference_models: list[str] | None = None, aggregator_model: str | None = None
|
||||
user_prompt: str,
|
||||
reference_models: Optional[List[str]] = None,
|
||||
aggregator_model: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Process a complex query using the Mixture-of-Agents methodology.
|
||||
|
||||
|
||||
This tool leverages multiple frontier language models to collaboratively solve
|
||||
extremely difficult problems requiring intense reasoning. It's particularly
|
||||
effective for:
|
||||
@@ -217,16 +229,16 @@ async def mixture_of_agents_tool(
|
||||
- Multi-step analytical reasoning tasks
|
||||
- Problems requiring diverse domain expertise
|
||||
- Tasks where single models show limitations
|
||||
|
||||
|
||||
The MoA approach uses a fixed 2-layer architecture:
|
||||
1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6)
|
||||
2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4)
|
||||
|
||||
|
||||
Args:
|
||||
user_prompt (str): The complex query or problem to solve
|
||||
reference_models (Optional[List[str]]): Custom reference models to use
|
||||
aggregator_model (Optional[str]): Custom aggregator model to use
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON string containing the MoA results with the following structure:
|
||||
{
|
||||
@@ -238,12 +250,12 @@ async def mixture_of_agents_tool(
|
||||
},
|
||||
"processing_time": float
|
||||
}
|
||||
|
||||
|
||||
Raises:
|
||||
Exception: If MoA processing fails or API key is not set
|
||||
"""
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
|
||||
@@ -251,7 +263,7 @@ async def mixture_of_agents_tool(
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
|
||||
"reference_temperature": REFERENCE_TEMPERATURE,
|
||||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
@@ -260,152 +272,161 @@ async def mixture_of_agents_tool(
|
||||
"failed_models": [],
|
||||
"final_response_length": 0,
|
||||
"processing_time_seconds": 0,
|
||||
"models_used": {},
|
||||
"models_used": {}
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
logger.info("Starting Mixture-of-Agents processing...")
|
||||
logger.info("Query: %s", user_prompt[:100])
|
||||
|
||||
|
||||
# Validate API key availability
|
||||
if not os.getenv("OPENROUTER_API_KEY"):
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable not set")
|
||||
|
||||
|
||||
# Use provided models or defaults
|
||||
ref_models = reference_models or REFERENCE_MODELS
|
||||
agg_model = aggregator_model or AGGREGATOR_MODEL
|
||||
|
||||
|
||||
logger.info("Using %s reference models in 2-layer MoA architecture", len(ref_models))
|
||||
|
||||
|
||||
# Layer 1: Generate diverse responses from reference models (with failure handling)
|
||||
logger.info("Layer 1: Generating reference responses...")
|
||||
model_results = await asyncio.gather(
|
||||
*[_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) for model in ref_models]
|
||||
)
|
||||
|
||||
model_results = await asyncio.gather(*[
|
||||
_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE)
|
||||
for model in ref_models
|
||||
])
|
||||
|
||||
# Separate successful and failed responses
|
||||
successful_responses = []
|
||||
failed_models = []
|
||||
|
||||
|
||||
for model_name, content, success in model_results:
|
||||
if success:
|
||||
successful_responses.append(content)
|
||||
else:
|
||||
failed_models.append(model_name)
|
||||
|
||||
|
||||
successful_count = len(successful_responses)
|
||||
failed_count = len(failed_models)
|
||||
|
||||
|
||||
logger.info("Reference model results: %s successful, %s failed", successful_count, failed_count)
|
||||
|
||||
|
||||
if failed_models:
|
||||
logger.warning("Failed models: %s", ", ".join(failed_models))
|
||||
|
||||
logger.warning("Failed models: %s", ', '.join(failed_models))
|
||||
|
||||
# Check if we have enough successful responses to proceed
|
||||
if successful_count < MIN_SUCCESSFUL_REFERENCES:
|
||||
raise ValueError(
|
||||
f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses."
|
||||
)
|
||||
|
||||
raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")
|
||||
|
||||
debug_call_data["reference_responses_count"] = successful_count
|
||||
debug_call_data["failed_models_count"] = failed_count
|
||||
debug_call_data["failed_models"] = failed_models
|
||||
|
||||
|
||||
# Layer 2: Aggregate responses using the aggregator model
|
||||
logger.info("Layer 2: Synthesizing final response...")
|
||||
aggregator_system_prompt = _construct_aggregator_prompt(AGGREGATOR_SYSTEM_PROMPT, successful_responses)
|
||||
|
||||
final_response = await _run_aggregator_model(aggregator_system_prompt, user_prompt, AGGREGATOR_TEMPERATURE)
|
||||
|
||||
aggregator_system_prompt = _construct_aggregator_prompt(
|
||||
AGGREGATOR_SYSTEM_PROMPT,
|
||||
successful_responses
|
||||
)
|
||||
|
||||
final_response = await _run_aggregator_model(
|
||||
aggregator_system_prompt,
|
||||
user_prompt,
|
||||
AGGREGATOR_TEMPERATURE
|
||||
)
|
||||
|
||||
# Calculate processing time
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
logger.info("MoA processing completed in %.2f seconds", processing_time)
|
||||
|
||||
|
||||
# Prepare successful response (only final aggregated result, minimal fields)
|
||||
result = {
|
||||
"success": True,
|
||||
"response": final_response,
|
||||
"models_used": {"reference_models": ref_models, "aggregator_model": agg_model},
|
||||
"models_used": {
|
||||
"reference_models": ref_models,
|
||||
"aggregator_model": agg_model
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["final_response_length"] = len(final_response)
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
debug_call_data["models_used"] = result["models_used"]
|
||||
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("mixture_of_agents_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error in MoA processing: {str(e)}"
|
||||
logger.error("%s", error_msg)
|
||||
|
||||
|
||||
# Calculate processing time even for errors
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
# Prepare error response (minimal fields)
|
||||
result = {
|
||||
"success": False,
|
||||
"response": "MoA processing failed. Please try again or use a single model for this query.",
|
||||
"models_used": {
|
||||
"reference_models": reference_models or REFERENCE_MODELS,
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL
|
||||
},
|
||||
"error": error_msg,
|
||||
"error": error_msg
|
||||
}
|
||||
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
_debug.log_call("mixture_of_agents_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_moa_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for MoA tools are met.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
return check_openrouter_api_key()
|
||||
|
||||
|
||||
def get_debug_session_info() -> dict[str, Any]:
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
return _debug.get_session_info()
|
||||
|
||||
|
||||
def get_available_models() -> dict[str, list[str]]:
|
||||
def get_available_models() -> Dict[str, List[str]]:
|
||||
"""
|
||||
Get information about available models for MoA processing.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Dictionary with reference and aggregator models
|
||||
"""
|
||||
return {
|
||||
"reference_models": REFERENCE_MODELS,
|
||||
"aggregator_models": [AGGREGATOR_MODEL],
|
||||
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL],
|
||||
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL]
|
||||
}
|
||||
|
||||
|
||||
def get_moa_configuration() -> dict[str, Any]:
|
||||
def get_moa_configuration() -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current MoA configuration settings.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing all configuration parameters
|
||||
"""
|
||||
@@ -416,7 +437,7 @@ def get_moa_configuration() -> dict[str, Any]:
|
||||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
|
||||
"total_reference_models": len(REFERENCE_MODELS),
|
||||
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail",
|
||||
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail"
|
||||
}
|
||||
|
||||
|
||||
@@ -426,10 +447,10 @@ if __name__ == "__main__":
|
||||
"""
|
||||
print("🤖 Mixture-of-Agents Tool Module")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_openrouter_api_key()
|
||||
|
||||
|
||||
if not api_available:
|
||||
print("❌ OPENROUTER_API_KEY environment variable not set")
|
||||
print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'")
|
||||
@@ -437,26 +458,26 @@ if __name__ == "__main__":
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ OpenRouter API key found")
|
||||
|
||||
|
||||
print("🛠️ MoA tools ready for use!")
|
||||
|
||||
|
||||
# Show current configuration
|
||||
config = get_moa_configuration()
|
||||
print("\n⚙️ Current Configuration:")
|
||||
print(f"\n⚙️ Current Configuration:")
|
||||
print(f" 🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}")
|
||||
print(f" 🧠 Aggregator model: {config['aggregator_model']}")
|
||||
print(f" 🌡️ Reference temperature: {config['reference_temperature']}")
|
||||
print(f" 🌡️ Aggregator temperature: {config['aggregator_temperature']}")
|
||||
print(f" 🛡️ Failure tolerance: {config['failure_tolerance']}")
|
||||
print(f" 📊 Minimum successful models: {config['min_successful_references']}")
|
||||
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"\n🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from mixture_of_agents_tool import mixture_of_agents_tool")
|
||||
print(" import asyncio")
|
||||
@@ -467,26 +488,24 @@ if __name__ == "__main__":
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
|
||||
print("\nBest use cases:")
|
||||
print(" - Complex mathematical proofs and calculations")
|
||||
print(" - Advanced coding problems and algorithm design")
|
||||
print(" - Multi-step analytical reasoning tasks")
|
||||
print(" - Problems requiring diverse domain expertise")
|
||||
print(" - Tasks where single models show limitations")
|
||||
|
||||
|
||||
print("\nPerformance characteristics:")
|
||||
print(" - Higher latency due to multiple model calls")
|
||||
print(" - Significantly improved quality for complex tasks")
|
||||
print(" - Parallel processing for efficiency")
|
||||
print(
|
||||
f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation"
|
||||
)
|
||||
print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation")
|
||||
print(" - Token-efficient: only returns final aggregated response")
|
||||
print(" - Resilient: continues with partial model failures")
|
||||
print(" - Configurable: easy to modify models and settings at top of file")
|
||||
print(f" - Configurable: easy to modify models and settings at top of file")
|
||||
print(" - State-of-the-art results on challenging benchmarks")
|
||||
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export MOA_TOOLS_DEBUG=true")
|
||||
@@ -507,11 +526,11 @@ MOA_SCHEMA = {
|
||||
"properties": {
|
||||
"user_prompt": {
|
||||
"type": "string",
|
||||
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning.",
|
||||
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning."
|
||||
}
|
||||
},
|
||||
"required": ["user_prompt"],
|
||||
},
|
||||
"required": ["user_prompt"]
|
||||
}
|
||||
}
|
||||
|
||||
registry.register(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Shared OpenRouter API client for Hermes tools.
|
||||
|
||||
Provides a single lazy-initialized AsyncOpenAI client that all tool modules
|
||||
can share, eliminating the duplicated _get_openrouter_client() /
|
||||
can share, eliminating the duplicated _get_openrouter_client() /
|
||||
_get_summarizer_client() pattern previously copy-pasted across web_tools,
|
||||
vision_tools, mixture_of_agents_tool, and session_search_tool.
|
||||
"""
|
||||
@@ -9,7 +9,6 @@ vision_tools, mixture_of_agents_tool, and session_search_tool.
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
@@ -20,7 +20,7 @@ V4A Format:
|
||||
|
||||
Usage:
|
||||
from tools.patch_parser import parse_v4a_patch, apply_v4a_operations
|
||||
|
||||
|
||||
operations, error = parse_v4a_patch(patch_content)
|
||||
if error:
|
||||
print(f"Parse error: {error}")
|
||||
@@ -30,8 +30,8 @@ Usage:
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Any
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class OperationType(Enum):
|
||||
@@ -44,7 +44,6 @@ class OperationType(Enum):
|
||||
@dataclass
|
||||
class HunkLine:
|
||||
"""A single line in a patch hunk."""
|
||||
|
||||
prefix: str # ' ', '-', or '+'
|
||||
content: str
|
||||
|
||||
@@ -52,174 +51,182 @@ class HunkLine:
|
||||
@dataclass
|
||||
class Hunk:
|
||||
"""A group of changes within a file."""
|
||||
|
||||
context_hint: str | None = None
|
||||
lines: list[HunkLine] = field(default_factory=list)
|
||||
context_hint: Optional[str] = None
|
||||
lines: List[HunkLine] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchOperation:
|
||||
"""A single operation in a V4A patch."""
|
||||
|
||||
operation: OperationType
|
||||
file_path: str
|
||||
new_path: str | None = None # For move operations
|
||||
hunks: list[Hunk] = field(default_factory=list)
|
||||
content: str | None = None # For add file operations
|
||||
new_path: Optional[str] = None # For move operations
|
||||
hunks: List[Hunk] = field(default_factory=list)
|
||||
content: Optional[str] = None # For add file operations
|
||||
|
||||
|
||||
def parse_v4a_patch(patch_content: str) -> tuple[list[PatchOperation], str | None]:
|
||||
def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]:
|
||||
"""
|
||||
Parse a V4A format patch.
|
||||
|
||||
|
||||
Args:
|
||||
patch_content: The patch text in V4A format
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (operations, error_message)
|
||||
- If successful: (list_of_operations, None)
|
||||
- If failed: ([], error_description)
|
||||
"""
|
||||
lines = patch_content.split("\n")
|
||||
operations: list[PatchOperation] = []
|
||||
|
||||
lines = patch_content.split('\n')
|
||||
operations: List[PatchOperation] = []
|
||||
|
||||
# Find patch boundaries
|
||||
start_idx = None
|
||||
end_idx = None
|
||||
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if "*** Begin Patch" in line or "***Begin Patch" in line:
|
||||
if '*** Begin Patch' in line or '***Begin Patch' in line:
|
||||
start_idx = i
|
||||
elif "*** End Patch" in line or "***End Patch" in line:
|
||||
elif '*** End Patch' in line or '***End Patch' in line:
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
|
||||
if start_idx is None:
|
||||
# Try to parse without explicit begin marker
|
||||
start_idx = -1
|
||||
|
||||
|
||||
if end_idx is None:
|
||||
end_idx = len(lines)
|
||||
|
||||
|
||||
# Parse operations between boundaries
|
||||
i = start_idx + 1
|
||||
current_op: PatchOperation | None = None
|
||||
current_hunk: Hunk | None = None
|
||||
|
||||
current_op: Optional[PatchOperation] = None
|
||||
current_hunk: Optional[Hunk] = None
|
||||
|
||||
while i < end_idx:
|
||||
line = lines[i]
|
||||
|
||||
|
||||
# Check for file operation markers
|
||||
update_match = re.match(r"\*\*\*\s*Update\s+File:\s*(.+)", line)
|
||||
add_match = re.match(r"\*\*\*\s*Add\s+File:\s*(.+)", line)
|
||||
delete_match = re.match(r"\*\*\*\s*Delete\s+File:\s*(.+)", line)
|
||||
move_match = re.match(r"\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)", line)
|
||||
|
||||
update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line)
|
||||
add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line)
|
||||
delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line)
|
||||
move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line)
|
||||
|
||||
if update_match:
|
||||
# Save previous operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(operation=OperationType.UPDATE, file_path=update_match.group(1).strip())
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.UPDATE,
|
||||
file_path=update_match.group(1).strip()
|
||||
)
|
||||
current_hunk = None
|
||||
|
||||
|
||||
elif add_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(operation=OperationType.ADD, file_path=add_match.group(1).strip())
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.ADD,
|
||||
file_path=add_match.group(1).strip()
|
||||
)
|
||||
current_hunk = Hunk()
|
||||
|
||||
|
||||
elif delete_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(operation=OperationType.DELETE, file_path=delete_match.group(1).strip())
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.DELETE,
|
||||
file_path=delete_match.group(1).strip()
|
||||
)
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
|
||||
elif move_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.MOVE,
|
||||
file_path=move_match.group(1).strip(),
|
||||
new_path=move_match.group(2).strip(),
|
||||
new_path=move_match.group(2).strip()
|
||||
)
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
elif line.startswith("@@"):
|
||||
|
||||
elif line.startswith('@@'):
|
||||
# Context hint / hunk marker
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
|
||||
|
||||
# Extract context hint
|
||||
hint_match = re.match(r"@@\s*(.+?)\s*@@", line)
|
||||
hint_match = re.match(r'@@\s*(.+?)\s*@@', line)
|
||||
hint = hint_match.group(1) if hint_match else None
|
||||
current_hunk = Hunk(context_hint=hint)
|
||||
|
||||
|
||||
elif current_op and line:
|
||||
# Parse hunk line
|
||||
if current_hunk is None:
|
||||
current_hunk = Hunk()
|
||||
|
||||
if line.startswith("+"):
|
||||
current_hunk.lines.append(HunkLine("+", line[1:]))
|
||||
elif line.startswith("-"):
|
||||
current_hunk.lines.append(HunkLine("-", line[1:]))
|
||||
elif line.startswith(" "):
|
||||
current_hunk.lines.append(HunkLine(" ", line[1:]))
|
||||
elif line.startswith("\\"):
|
||||
|
||||
if line.startswith('+'):
|
||||
current_hunk.lines.append(HunkLine('+', line[1:]))
|
||||
elif line.startswith('-'):
|
||||
current_hunk.lines.append(HunkLine('-', line[1:]))
|
||||
elif line.startswith(' '):
|
||||
current_hunk.lines.append(HunkLine(' ', line[1:]))
|
||||
elif line.startswith('\\'):
|
||||
# "\ No newline at end of file" marker - skip
|
||||
pass
|
||||
else:
|
||||
# Treat as context line (implicit space prefix)
|
||||
current_hunk.lines.append(HunkLine(" ", line))
|
||||
|
||||
current_hunk.lines.append(HunkLine(' ', line))
|
||||
|
||||
i += 1
|
||||
|
||||
|
||||
# Don't forget the last operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
|
||||
return operations, None
|
||||
|
||||
|
||||
def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "PatchResult":
|
||||
def apply_v4a_operations(operations: List[PatchOperation],
|
||||
file_ops: Any) -> 'PatchResult':
|
||||
"""
|
||||
Apply V4A patch operations using a file operations interface.
|
||||
|
||||
|
||||
Args:
|
||||
operations: List of PatchOperation from parse_v4a_patch
|
||||
file_ops: Object with read_file, write_file methods
|
||||
|
||||
|
||||
Returns:
|
||||
PatchResult with results of all operations
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from tools.file_operations import PatchResult
|
||||
|
||||
|
||||
files_modified = []
|
||||
files_created = []
|
||||
files_deleted = []
|
||||
all_diffs = []
|
||||
errors = []
|
||||
|
||||
|
||||
for op in operations:
|
||||
try:
|
||||
if op.operation == OperationType.ADD:
|
||||
@@ -229,7 +236,7 @@ def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "Pa
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to add {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.DELETE:
|
||||
result = _apply_delete(op, file_ops)
|
||||
if result[0]:
|
||||
@@ -237,7 +244,7 @@ def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "Pa
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to delete {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.MOVE:
|
||||
result = _apply_move(op, file_ops)
|
||||
if result[0]:
|
||||
@@ -245,7 +252,7 @@ def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "Pa
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to move {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.UPDATE:
|
||||
result = _apply_update(op, file_ops)
|
||||
if result[0]:
|
||||
@@ -253,19 +260,19 @@ def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "Pa
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to update {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Error processing {op.file_path}: {str(e)}")
|
||||
|
||||
|
||||
# Run lint on all modified/created files
|
||||
lint_results = {}
|
||||
for f in files_modified + files_created:
|
||||
if hasattr(file_ops, "_check_lint"):
|
||||
if hasattr(file_ops, '_check_lint'):
|
||||
lint_result = file_ops._check_lint(f)
|
||||
lint_results[f] = lint_result.to_dict()
|
||||
|
||||
combined_diff = "\n".join(all_diffs)
|
||||
|
||||
|
||||
combined_diff = '\n'.join(all_diffs)
|
||||
|
||||
if errors:
|
||||
return PatchResult(
|
||||
success=False,
|
||||
@@ -274,124 +281,123 @@ def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "Pa
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None,
|
||||
error="; ".join(errors),
|
||||
error='; '.join(errors)
|
||||
)
|
||||
|
||||
|
||||
return PatchResult(
|
||||
success=True,
|
||||
diff=combined_diff,
|
||||
files_modified=files_modified,
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None,
|
||||
lint=lint_results if lint_results else None
|
||||
)
|
||||
|
||||
|
||||
def _apply_add(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply an add file operation."""
|
||||
# Extract content from hunks (all + lines)
|
||||
content_lines = []
|
||||
for hunk in op.hunks:
|
||||
for line in hunk.lines:
|
||||
if line.prefix == "+":
|
||||
if line.prefix == '+':
|
||||
content_lines.append(line.content)
|
||||
|
||||
content = "\n".join(content_lines)
|
||||
|
||||
|
||||
content = '\n'.join(content_lines)
|
||||
|
||||
result = file_ops.write_file(op.file_path, content)
|
||||
if result.error:
|
||||
return False, result.error
|
||||
|
||||
|
||||
diff = f"--- /dev/null\n+++ b/{op.file_path}\n"
|
||||
diff += "\n".join(f"+{line}" for line in content_lines)
|
||||
|
||||
diff += '\n'.join(f"+{line}" for line in content_lines)
|
||||
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_delete(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply a delete file operation."""
|
||||
# Read file first for diff
|
||||
read_result = file_ops.read_file(op.file_path)
|
||||
|
||||
|
||||
if read_result.error and "not found" in read_result.error.lower():
|
||||
# File doesn't exist, nothing to delete
|
||||
return True, f"# {op.file_path} already deleted or doesn't exist"
|
||||
|
||||
|
||||
# Delete directly via shell command using the underlying environment
|
||||
rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}")
|
||||
|
||||
|
||||
if rm_result.exit_code != 0:
|
||||
return False, rm_result.stdout
|
||||
|
||||
|
||||
diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_move(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply a move file operation."""
|
||||
# Use shell mv command
|
||||
mv_result = file_ops._exec(
|
||||
f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}"
|
||||
)
|
||||
|
||||
|
||||
if mv_result.exit_code != 0:
|
||||
return False, mv_result.stdout
|
||||
|
||||
|
||||
diff = f"# Moved: {op.file_path} -> {op.new_path}"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_update(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply an update file operation."""
|
||||
# Read current content
|
||||
read_result = file_ops.read_file(op.file_path, limit=10000)
|
||||
|
||||
|
||||
if read_result.error:
|
||||
return False, f"Cannot read file: {read_result.error}"
|
||||
|
||||
|
||||
# Parse content (remove line numbers)
|
||||
current_lines = []
|
||||
for line in read_result.content.split("\n"):
|
||||
if "|" in line:
|
||||
for line in read_result.content.split('\n'):
|
||||
if '|' in line:
|
||||
# Line format: " 123|content"
|
||||
parts = line.split("|", 1)
|
||||
parts = line.split('|', 1)
|
||||
if len(parts) == 2:
|
||||
current_lines.append(parts[1])
|
||||
else:
|
||||
current_lines.append(line)
|
||||
else:
|
||||
current_lines.append(line)
|
||||
|
||||
current_content = "\n".join(current_lines)
|
||||
|
||||
|
||||
current_content = '\n'.join(current_lines)
|
||||
|
||||
# Apply each hunk
|
||||
new_content = current_content
|
||||
|
||||
|
||||
for hunk in op.hunks:
|
||||
# Build search pattern from context and removed lines
|
||||
search_lines = []
|
||||
replace_lines = []
|
||||
|
||||
|
||||
for line in hunk.lines:
|
||||
if line.prefix == " ":
|
||||
if line.prefix == ' ':
|
||||
search_lines.append(line.content)
|
||||
replace_lines.append(line.content)
|
||||
elif line.prefix == "-":
|
||||
elif line.prefix == '-':
|
||||
search_lines.append(line.content)
|
||||
elif line.prefix == "+":
|
||||
elif line.prefix == '+':
|
||||
replace_lines.append(line.content)
|
||||
|
||||
|
||||
if search_lines:
|
||||
search_pattern = "\n".join(search_lines)
|
||||
replacement = "\n".join(replace_lines)
|
||||
|
||||
search_pattern = '\n'.join(search_lines)
|
||||
replacement = '\n'.join(replace_lines)
|
||||
|
||||
# Use fuzzy matching
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
new_content, count, error = fuzzy_find_and_replace(
|
||||
new_content, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
|
||||
if error and count == 0:
|
||||
# Try with context hint if available
|
||||
if hunk.context_hint:
|
||||
@@ -402,32 +408,31 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
window_start = max(0, hint_pos - 500)
|
||||
window_end = min(len(new_content), hint_pos + 2000)
|
||||
window = new_content[window_start:window_end]
|
||||
|
||||
|
||||
window_new, count, error = fuzzy_find_and_replace(
|
||||
window, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
|
||||
if count > 0:
|
||||
new_content = new_content[:window_start] + window_new + new_content[window_end:]
|
||||
error = None
|
||||
|
||||
|
||||
if error:
|
||||
return False, f"Could not apply hunk: {error}"
|
||||
|
||||
|
||||
# Write new content
|
||||
write_result = file_ops.write_file(op.file_path, new_content)
|
||||
if write_result.error:
|
||||
return False, write_result.error
|
||||
|
||||
|
||||
# Generate diff
|
||||
import difflib
|
||||
|
||||
diff_lines = difflib.unified_diff(
|
||||
current_content.splitlines(keepends=True),
|
||||
new_content.splitlines(keepends=True),
|
||||
fromfile=f"a/{op.file_path}",
|
||||
tofile=f"b/{op.file_path}",
|
||||
tofile=f"b/{op.file_path}"
|
||||
)
|
||||
diff = "".join(diff_lines)
|
||||
|
||||
diff = ''.join(diff_lines)
|
||||
|
||||
return True, diff
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user