Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2ce9edcb29 | |||
| 45c8d3da96 | |||
| 5ca6d681f0 | |||
| df806bdbaf | |||
| 0ef80c5f32 | |||
| c4cf20f564 | |||
| 68d5472810 | |||
| 252fbea005 | |||
| c774833667 | |||
| d5d22fe7ba | |||
| bf84cdfa5e | |||
| 38d694f559 | |||
| ed6427e0a7 | |||
| 0fd3b59ba1 | |||
| 6716e66e89 | |||
| d02561af85 | |||
| 8eb70a6885 | |||
| ee3d2941cc | |||
| 475205e30b | |||
| 612321631f | |||
| 83cbf7b5bb | |||
| 563101e2a9 | |||
| fe6a916284 | |||
| 57481c8ac5 | |||
| c62cadb73a | |||
| 442888a05b | |||
| b151d5f7a7 | |||
| f6db1b27ba | |||
| 0df4d1278e | |||
| 95f99ea4b9 | |||
| 811adca277 | |||
| aafe37012a | |||
| 909de72426 | |||
| ba1b600bce | |||
| fcd1645223 | |||
| 253a9adc72 | |||
| 300964178f | |||
| 7a3682ac3f | |||
| 9f01244137 | |||
| 0a80dd9c7a | |||
| 4764e06fde | |||
| 4c532c153b | |||
| a99c0478d0 | |||
| c6e3084baf | |||
| dcbdfdbb2b | |||
| 91b881f931 |
@@ -0,0 +1,13 @@
|
||||
# Git
|
||||
.git
|
||||
.gitignore
|
||||
.gitmodules
|
||||
|
||||
# Dependencies
|
||||
node_modules
|
||||
|
||||
# CI/CD
|
||||
.github
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
+1
-1
@@ -98,7 +98,7 @@ FAL_KEY=
|
||||
HONCHO_API_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# TERMINAL TOOL CONFIGURATION (mini-swe-agent backend)
|
||||
# TERMINAL TOOL CONFIGURATION
|
||||
# =============================================================================
|
||||
# Backend type: "local", "singularity", "docker", "modal", or "ssh"
|
||||
# Terminal backend is configured in ~/.hermes/config.yaml (terminal.backend).
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
name: Docker Build and Publish
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
concurrency:
|
||||
group: docker-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
load: true
|
||||
tags: nousresearch/hermes-agent:test
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Test image starts
|
||||
run: |
|
||||
docker run --rm \
|
||||
-v /tmp/hermes-test:/opt/data \
|
||||
--entrypoint /opt/hermes/docker/entrypoint.sh \
|
||||
nousresearch/hermes-agent:test --help
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Push image
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
push: true
|
||||
tags: |
|
||||
nousresearch/hermes-agent:latest
|
||||
nousresearch/hermes-agent:${{ github.sha }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
@@ -210,6 +210,10 @@ registry.register(
|
||||
|
||||
The registry handles schema collection, dispatch, availability checking, and error wrapping. All handlers MUST return a JSON string.
|
||||
|
||||
**Path references in tool schemas**: If the schema description mentions file paths (e.g. default output directories), use `display_hermes_home()` to make them profile-aware. The schema is generated at import time, which is after `_apply_profile_override()` sets `HERMES_HOME`.
|
||||
|
||||
**State files**: If a tool stores persistent state (caches, logs, checkpoints), use `get_hermes_home()` for the base directory — never `Path.home() / ".hermes"`. This ensures each profile gets its own state.
|
||||
|
||||
**Agent-level tools** (todo, memory): intercepted by `run_agent.py` before `handle_function_call()`. See `todo_tool.py` for the pattern.
|
||||
|
||||
---
|
||||
@@ -358,8 +362,69 @@ in config.yaml (or `HERMES_BACKGROUND_NOTIFICATIONS` env var):
|
||||
|
||||
---
|
||||
|
||||
## Profiles: Multi-Instance Support
|
||||
|
||||
Hermes supports **profiles** — multiple fully isolated instances, each with its own
|
||||
`HERMES_HOME` directory (config, API keys, memory, sessions, skills, gateway, etc.).
|
||||
|
||||
The core mechanism: `_apply_profile_override()` in `hermes_cli/main.py` sets
|
||||
`HERMES_HOME` before any module imports. All 119+ references to `get_hermes_home()`
|
||||
automatically scope to the active profile.
|
||||
|
||||
### Rules for profile-safe code
|
||||
|
||||
1. **Use `get_hermes_home()` for all HERMES_HOME paths.** Import from `hermes_constants`.
|
||||
NEVER hardcode `~/.hermes` or `Path.home() / ".hermes"` in code that reads/writes state.
|
||||
```python
|
||||
# GOOD
|
||||
from hermes_constants import get_hermes_home
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
|
||||
# BAD — breaks profiles
|
||||
config_path = Path.home() / ".hermes" / "config.yaml"
|
||||
```
|
||||
|
||||
2. **Use `display_hermes_home()` for user-facing messages.** Import from `hermes_constants`.
|
||||
This returns `~/.hermes` for default or `~/.hermes/profiles/<name>` for profiles.
|
||||
```python
|
||||
# GOOD
|
||||
from hermes_constants import display_hermes_home
|
||||
print(f"Config saved to {display_hermes_home()}/config.yaml")
|
||||
|
||||
# BAD — shows wrong path for profiles
|
||||
print("Config saved to ~/.hermes/config.yaml")
|
||||
```
|
||||
|
||||
3. **Module-level constants are fine** — they cache `get_hermes_home()` at import time,
|
||||
which is AFTER `_apply_profile_override()` sets the env var. Just use `get_hermes_home()`,
|
||||
not `Path.home() / ".hermes"`.
|
||||
|
||||
4. **Tests that mock `Path.home()` must also set `HERMES_HOME`** — since code now uses
|
||||
`get_hermes_home()` (reads env var), not `Path.home() / ".hermes"`:
|
||||
```python
|
||||
with patch.object(Path, "home", return_value=tmp_path), \
|
||||
patch.dict(os.environ, {"HERMES_HOME": str(tmp_path / ".hermes")}):
|
||||
...
|
||||
```
|
||||
|
||||
5. **Gateway platform adapters should use token locks** — if the adapter connects with
|
||||
a unique credential (bot token, API key), call `acquire_scoped_lock()` from
|
||||
`gateway.status` in the `connect()`/`start()` method and `release_scoped_lock()` in
|
||||
`disconnect()`/`stop()`. This prevents two profiles from using the same credential.
|
||||
See `gateway/platforms/telegram.py` for the canonical pattern.
|
||||
|
||||
6. **Profile operations are HOME-anchored, not HERMES_HOME-anchored** — `_get_profiles_root()`
|
||||
returns `Path.home() / ".hermes" / "profiles"`, NOT `get_hermes_home() / "profiles"`.
|
||||
This is intentional — it lets `hermes -p coder profile list` see all profiles regardless
|
||||
of which one is active.
|
||||
|
||||
## Known Pitfalls
|
||||
|
||||
### DO NOT hardcode `~/.hermes` paths
|
||||
Use `get_hermes_home()` from `hermes_constants` for code paths. Use `display_hermes_home()`
|
||||
for user-facing print/log messages. Hardcoding `~/.hermes` breaks profiles — each profile
|
||||
has its own `HERMES_HOME` directory. This was the source of 5 bugs fixed in PR #3575.
|
||||
|
||||
### DO NOT use `simple_term_menu` for interactive menus
|
||||
Rendering bugs in tmux/iTerm2 — ghosting on scroll. Use `curses` (stdlib) instead. See `hermes_cli/tools_config.py` for the pattern.
|
||||
|
||||
@@ -375,6 +440,19 @@ Tool schema descriptions must not mention tools from other toolsets by name (e.g
|
||||
### Tests must not write to `~/.hermes/`
|
||||
The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HERMES_HOME` to a temp dir. Never hardcode `~/.hermes/` paths in tests.
|
||||
|
||||
**Profile tests**: When testing profile features, also mock `Path.home()` so that
|
||||
`_get_profiles_root()` and `_get_default_hermes_home()` resolve within the temp dir.
|
||||
Use the pattern from `tests/hermes_cli/test_profiles.py`:
|
||||
```python
|
||||
@pytest.fixture
|
||||
def profile_env(tmp_path, monkeypatch):
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
return home
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
+20
@@ -0,0 +1,20 @@
|
||||
FROM debian:13.4
|
||||
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y nodejs npm python3 python3-pip ripgrep ffmpeg gcc python3-dev libffi-dev
|
||||
|
||||
COPY . /opt/hermes
|
||||
WORKDIR /opt/hermes
|
||||
|
||||
RUN pip install -e ".[all]" --break-system-packages
|
||||
RUN npm install
|
||||
RUN npx playwright install --with-deps chromium
|
||||
WORKDIR /opt/hermes/scripts/whatsapp-bridge
|
||||
RUN npm install
|
||||
|
||||
WORKDIR /opt/hermes
|
||||
RUN chmod +x /opt/hermes/docker/entrypoint.sh
|
||||
|
||||
ENV HERMES_HOME=/opt/data
|
||||
VOLUME [ "/opt/data" ]
|
||||
ENTRYPOINT [ "/opt/hermes/docker/entrypoint.sh" ]
|
||||
@@ -74,7 +74,7 @@ def main() -> None:
|
||||
|
||||
agent = HermesACPAgent()
|
||||
try:
|
||||
asyncio.run(acp.run_agent(agent))
|
||||
asyncio.run(acp.run_agent(agent, use_unstable_protocol=True))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutting down (KeyboardInterrupt)")
|
||||
except Exception:
|
||||
|
||||
+46
-3
@@ -25,6 +25,9 @@ from acp.schema import (
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SetSessionConfigOptionResponse,
|
||||
SetSessionModelResponse,
|
||||
SetSessionModeResponse,
|
||||
ResourceContentBlock,
|
||||
SessionCapabilities,
|
||||
SessionForkCapabilities,
|
||||
@@ -94,11 +97,14 @@ class HermesACPAgent(acp.Agent):
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
protocol_version: int,
|
||||
protocol_version: int | None = None,
|
||||
client_capabilities: ClientCapabilities | None = None,
|
||||
client_info: Implementation | None = None,
|
||||
**kwargs: Any,
|
||||
) -> InitializeResponse:
|
||||
resolved_protocol_version = (
|
||||
protocol_version if isinstance(protocol_version, int) else acp.PROTOCOL_VERSION
|
||||
)
|
||||
provider = detect_provider()
|
||||
auth_methods = None
|
||||
if provider:
|
||||
@@ -111,7 +117,11 @@ class HermesACPAgent(acp.Agent):
|
||||
]
|
||||
|
||||
client_name = client_info.name if client_info else "unknown"
|
||||
logger.info("Initialize from %s (protocol v%s)", client_name, protocol_version)
|
||||
logger.info(
|
||||
"Initialize from %s (protocol v%s)",
|
||||
client_name,
|
||||
resolved_protocol_version,
|
||||
)
|
||||
|
||||
return InitializeResponse(
|
||||
protocol_version=acp.PROTOCOL_VERSION,
|
||||
@@ -471,7 +481,7 @@ class HermesACPAgent(acp.Agent):
|
||||
|
||||
async def set_session_model(
|
||||
self, model_id: str, session_id: str, **kwargs: Any
|
||||
):
|
||||
) -> SetSessionModelResponse | None:
|
||||
"""Switch the model for a session (called by ACP protocol)."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
@@ -489,4 +499,37 @@ class HermesACPAgent(acp.Agent):
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
return SetSessionModelResponse()
|
||||
logger.warning("Session %s: model switch requested for missing session", session_id)
|
||||
return None
|
||||
|
||||
async def set_session_mode(
|
||||
self, mode_id: str, session_id: str, **kwargs: Any
|
||||
) -> SetSessionModeResponse | None:
|
||||
"""Persist the editor-requested mode so ACP clients do not fail on mode switches."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state is None:
|
||||
logger.warning("Session %s: mode switch requested for missing session", session_id)
|
||||
return None
|
||||
setattr(state, "mode", mode_id)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: mode switched to %s", session_id, mode_id)
|
||||
return SetSessionModeResponse()
|
||||
|
||||
async def set_config_option(
|
||||
self, config_id: str, session_id: str, value: str, **kwargs: Any
|
||||
) -> SetSessionConfigOptionResponse | None:
|
||||
"""Accept ACP config option updates even when Hermes has no typed ACP config surface yet."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state is None:
|
||||
logger.warning("Session %s: config update requested for missing session", session_id)
|
||||
return None
|
||||
|
||||
options = getattr(state, "config_options", None)
|
||||
if not isinstance(options, dict):
|
||||
options = {}
|
||||
options[str(config_id)] = value
|
||||
setattr(state, "config_options", options)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: config option %s updated", session_id, config_id)
|
||||
return SetSessionConfigOptionResponse(config_options=[])
|
||||
|
||||
+59
-1
@@ -18,6 +18,7 @@ from typing import Optional
|
||||
from agent.skill_utils import (
|
||||
extract_skill_conditions,
|
||||
extract_skill_description,
|
||||
get_all_skills_dirs,
|
||||
get_disabled_skill_names,
|
||||
iter_skill_index_files,
|
||||
parse_frontmatter,
|
||||
@@ -444,16 +445,23 @@ def build_skills_system_prompt(
|
||||
mtime/size manifest — survives process restarts
|
||||
|
||||
Falls back to a full filesystem scan when both layers miss.
|
||||
|
||||
External skill directories (``skills.external_dirs`` in config.yaml) are
|
||||
scanned alongside the local ``~/.hermes/skills/`` directory. External dirs
|
||||
are read-only — they appear in the index but new skills are always created
|
||||
in the local dir. Local skills take precedence when names collide.
|
||||
"""
|
||||
hermes_home = get_hermes_home()
|
||||
skills_dir = hermes_home / "skills"
|
||||
external_dirs = get_all_skills_dirs()[1:] # skip local (index 0)
|
||||
|
||||
if not skills_dir.exists():
|
||||
if not skills_dir.exists() and not external_dirs:
|
||||
return ""
|
||||
|
||||
# ── Layer 1: in-process LRU cache ─────────────────────────────────
|
||||
cache_key = (
|
||||
str(skills_dir.resolve()),
|
||||
tuple(str(d) for d in external_dirs),
|
||||
tuple(sorted(str(t) for t in (available_tools or set()))),
|
||||
tuple(sorted(str(ts) for ts in (available_toolsets or set()))),
|
||||
)
|
||||
@@ -540,6 +548,56 @@ def build_skills_system_prompt(
|
||||
category_descriptions,
|
||||
)
|
||||
|
||||
# ── External skill directories ─────────────────────────────────────
|
||||
# Scan external dirs directly (no snapshot caching — they're read-only
|
||||
# and typically small). Local skills already in skills_by_category take
|
||||
# precedence: we track seen names and skip duplicates from external dirs.
|
||||
seen_skill_names: set[str] = set()
|
||||
for cat_skills in skills_by_category.values():
|
||||
for name, _desc in cat_skills:
|
||||
seen_skill_names.add(name)
|
||||
|
||||
for ext_dir in external_dirs:
|
||||
if not ext_dir.exists():
|
||||
continue
|
||||
for skill_file in iter_skill_index_files(ext_dir, "SKILL.md"):
|
||||
try:
|
||||
is_compatible, frontmatter, desc = _parse_skill_file(skill_file)
|
||||
if not is_compatible:
|
||||
continue
|
||||
entry = _build_snapshot_entry(skill_file, ext_dir, frontmatter, desc)
|
||||
skill_name = entry["skill_name"]
|
||||
if skill_name in seen_skill_names:
|
||||
continue
|
||||
if entry["frontmatter_name"] in disabled or skill_name in disabled:
|
||||
continue
|
||||
if not _skill_should_show(
|
||||
extract_skill_conditions(frontmatter),
|
||||
available_tools,
|
||||
available_toolsets,
|
||||
):
|
||||
continue
|
||||
seen_skill_names.add(skill_name)
|
||||
skills_by_category.setdefault(entry["category"], []).append(
|
||||
(skill_name, entry["description"])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Error reading external skill %s: %s", skill_file, e)
|
||||
|
||||
# External category descriptions
|
||||
for desc_file in iter_skill_index_files(ext_dir, "DESCRIPTION.md"):
|
||||
try:
|
||||
content = desc_file.read_text(encoding="utf-8")
|
||||
fm, _ = parse_frontmatter(content)
|
||||
cat_desc = fm.get("description")
|
||||
if not cat_desc:
|
||||
continue
|
||||
rel = desc_file.relative_to(ext_dir)
|
||||
cat = "/".join(rel.parts[:-1]) if len(rel.parts) > 1 else "general"
|
||||
category_descriptions.setdefault(cat, str(cat_desc).strip().strip("'\""))
|
||||
except Exception as e:
|
||||
logger.debug("Could not read external skill description %s: %s", desc_file, e)
|
||||
|
||||
if not skills_by_category:
|
||||
result = ""
|
||||
else:
|
||||
|
||||
+45
-30
@@ -128,7 +128,11 @@ def _build_skill_message(
|
||||
supporting.append(rel)
|
||||
|
||||
if supporting and skill_dir:
|
||||
skill_view_target = str(skill_dir.relative_to(SKILLS_DIR))
|
||||
try:
|
||||
skill_view_target = str(skill_dir.relative_to(SKILLS_DIR))
|
||||
except ValueError:
|
||||
# Skill is from an external dir — use the skill name instead
|
||||
skill_view_target = skill_dir.name
|
||||
parts.append("")
|
||||
parts.append("[This skill has supporting files you can load with the skill_view tool:]")
|
||||
for sf in supporting:
|
||||
@@ -158,38 +162,49 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
_skill_commands = {}
|
||||
try:
|
||||
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform, _get_disabled_skill_names
|
||||
if not SKILLS_DIR.exists():
|
||||
return _skill_commands
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
disabled = _get_disabled_skill_names()
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
if any(part in ('.git', '.github', '.hub') for part in skill_md.parts):
|
||||
continue
|
||||
try:
|
||||
content = skill_md.read_text(encoding='utf-8')
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
# Skip skills incompatible with the current OS platform
|
||||
if not skill_matches_platform(frontmatter):
|
||||
seen_names: set = set()
|
||||
|
||||
# Scan local dir first, then external dirs
|
||||
dirs_to_scan = []
|
||||
if SKILLS_DIR.exists():
|
||||
dirs_to_scan.append(SKILLS_DIR)
|
||||
dirs_to_scan.extend(get_external_skills_dirs())
|
||||
|
||||
for scan_dir in dirs_to_scan:
|
||||
for skill_md in scan_dir.rglob("SKILL.md"):
|
||||
if any(part in ('.git', '.github', '.hub') for part in skill_md.parts):
|
||||
continue
|
||||
name = frontmatter.get('name', skill_md.parent.name)
|
||||
# Respect user's disabled skills config
|
||||
if name in disabled:
|
||||
try:
|
||||
content = skill_md.read_text(encoding='utf-8')
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
# Skip skills incompatible with the current OS platform
|
||||
if not skill_matches_platform(frontmatter):
|
||||
continue
|
||||
name = frontmatter.get('name', skill_md.parent.name)
|
||||
if name in seen_names:
|
||||
continue
|
||||
# Respect user's disabled skills config
|
||||
if name in disabled:
|
||||
continue
|
||||
description = frontmatter.get('description', '')
|
||||
if not description:
|
||||
for line in body.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
description = line[:80]
|
||||
break
|
||||
seen_names.add(name)
|
||||
cmd_name = name.lower().replace(' ', '-').replace('_', '-')
|
||||
_skill_commands[f"/{cmd_name}"] = {
|
||||
"name": name,
|
||||
"description": description or f"Invoke the {name} skill",
|
||||
"skill_md_path": str(skill_md),
|
||||
"skill_dir": str(skill_md.parent),
|
||||
}
|
||||
except Exception:
|
||||
continue
|
||||
description = frontmatter.get('description', '')
|
||||
if not description:
|
||||
for line in body.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
description = line[:80]
|
||||
break
|
||||
cmd_name = name.lower().replace(' ', '-').replace('_', '-')
|
||||
_skill_commands[f"/{cmd_name}"] = {
|
||||
"name": name,
|
||||
"description": description or f"Invoke the {name} skill",
|
||||
"skill_md_path": str(skill_md),
|
||||
"skill_dir": str(skill_md.parent),
|
||||
}
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
return _skill_commands
|
||||
|
||||
@@ -158,6 +158,73 @@ def _normalize_string_set(values) -> Set[str]:
|
||||
return {str(v).strip() for v in values if str(v).strip()}
|
||||
|
||||
|
||||
# ── External skills directories ──────────────────────────────────────────
|
||||
|
||||
|
||||
def get_external_skills_dirs() -> List[Path]:
|
||||
"""Read ``skills.external_dirs`` from config.yaml and return validated paths.
|
||||
|
||||
Each entry is expanded (``~`` and ``${VAR}``) and resolved to an absolute
|
||||
path. Only directories that actually exist are returned. Duplicates and
|
||||
paths that resolve to the local ``~/.hermes/skills/`` are silently skipped.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
if not config_path.exists():
|
||||
return []
|
||||
try:
|
||||
parsed = yaml_load(config_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return []
|
||||
if not isinstance(parsed, dict):
|
||||
return []
|
||||
|
||||
skills_cfg = parsed.get("skills")
|
||||
if not isinstance(skills_cfg, dict):
|
||||
return []
|
||||
|
||||
raw_dirs = skills_cfg.get("external_dirs")
|
||||
if not raw_dirs:
|
||||
return []
|
||||
if isinstance(raw_dirs, str):
|
||||
raw_dirs = [raw_dirs]
|
||||
if not isinstance(raw_dirs, list):
|
||||
return []
|
||||
|
||||
local_skills = (get_hermes_home() / "skills").resolve()
|
||||
seen: Set[Path] = set()
|
||||
result: List[Path] = []
|
||||
|
||||
for entry in raw_dirs:
|
||||
entry = str(entry).strip()
|
||||
if not entry:
|
||||
continue
|
||||
# Expand ~ and environment variables
|
||||
expanded = os.path.expanduser(os.path.expandvars(entry))
|
||||
p = Path(expanded).resolve()
|
||||
if p == local_skills:
|
||||
continue
|
||||
if p in seen:
|
||||
continue
|
||||
if p.is_dir():
|
||||
seen.add(p)
|
||||
result.append(p)
|
||||
else:
|
||||
logger.debug("External skills dir does not exist, skipping: %s", p)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_all_skills_dirs() -> List[Path]:
|
||||
"""Return all skill directories: local ``~/.hermes/skills/`` first, then external.
|
||||
|
||||
The local dir is always first (and always included even if it doesn't exist
|
||||
yet — callers handle that). External dirs follow in config order.
|
||||
"""
|
||||
dirs = [get_hermes_home() / "skills"]
|
||||
dirs.extend(get_external_skills_dirs())
|
||||
return dirs
|
||||
|
||||
|
||||
# ── Condition extraction ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -402,6 +402,15 @@ skills:
|
||||
# Set to 0 to disable.
|
||||
creation_nudge_interval: 15
|
||||
|
||||
# External skill directories — share skills across tools/agents without
|
||||
# copying them into ~/.hermes/skills/. Each path is expanded (~ and ${VAR})
|
||||
# and resolved to an absolute path. External dirs are read-only: skill
|
||||
# creation always writes to ~/.hermes/skills/. Local skills take precedence
|
||||
# when names collide.
|
||||
# external_dirs:
|
||||
# - ~/.agents/skills
|
||||
# - /home/shared/team-skills
|
||||
|
||||
# =============================================================================
|
||||
# Agent Behavior
|
||||
# =============================================================================
|
||||
|
||||
@@ -70,7 +70,7 @@ _COMMAND_SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
from hermes_constants import get_hermes_home, OPENROUTER_BASE_URL
|
||||
from hermes_constants import get_hermes_home, display_hermes_home, OPENROUTER_BASE_URL
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
_hermes_home = get_hermes_home()
|
||||
@@ -1182,9 +1182,13 @@ class HermesCLI:
|
||||
self._provider_require_params = pr.get("require_parameters", False)
|
||||
self._provider_data_collection = pr.get("data_collection")
|
||||
|
||||
# Fallback model config — tried when primary provider fails after retries
|
||||
fb = CLI_CONFIG.get("fallback_model") or {}
|
||||
self._fallback_model = fb if fb.get("provider") and fb.get("model") else None
|
||||
# Fallback provider chain — tried in order when primary fails after retries.
|
||||
# Supports new list format (fallback_providers) and legacy single-dict (fallback_model).
|
||||
fb = CLI_CONFIG.get("fallback_providers") or CLI_CONFIG.get("fallback_model") or []
|
||||
# Normalize legacy single-dict to a one-element list
|
||||
if isinstance(fb, dict):
|
||||
fb = [fb] if fb.get("provider") and fb.get("model") else []
|
||||
self._fallback_model = fb
|
||||
|
||||
# Optional cheap-vs-strong routing for simple turns
|
||||
self._smart_model_routing = CLI_CONFIG.get("smart_model_routing", {}) or {}
|
||||
@@ -3594,7 +3598,7 @@ class HermesCLI:
|
||||
print(" To start the gateway:")
|
||||
print(" python cli.py --gateway")
|
||||
print()
|
||||
print(" Configuration file: ~/.hermes/config.yaml")
|
||||
print(f" Configuration file: {display_hermes_home()}/config.yaml")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
@@ -3604,7 +3608,7 @@ class HermesCLI:
|
||||
print(" 1. Set environment variables:")
|
||||
print(" TELEGRAM_BOT_TOKEN=your_token")
|
||||
print(" DISCORD_BOT_TOKEN=your_token")
|
||||
print(" 2. Or configure settings in ~/.hermes/config.yaml")
|
||||
print(f" 2. Or configure settings in {display_hermes_home()}/config.yaml")
|
||||
print()
|
||||
|
||||
def process_command(self, command: str) -> bool:
|
||||
@@ -3811,7 +3815,7 @@ class HermesCLI:
|
||||
plugins = mgr.list_plugins()
|
||||
if not plugins:
|
||||
print("No plugins installed.")
|
||||
print("Drop plugin directories into ~/.hermes/plugins/ to get started.")
|
||||
print(f"Drop plugin directories into {display_hermes_home()}/plugins/ to get started.")
|
||||
else:
|
||||
print(f"Plugins ({len(plugins)}):")
|
||||
for p in plugins:
|
||||
@@ -4340,7 +4344,7 @@ class HermesCLI:
|
||||
source = f" ({s['source']})" if s["source"] == "user" else ""
|
||||
print(f" {marker} {s['name']}{source} — {s['description']}")
|
||||
print("\n Usage: /skin <name>")
|
||||
print(" Custom skins: drop a YAML file in ~/.hermes/skins/\n")
|
||||
print(f" Custom skins: drop a YAML file in {display_hermes_home()}/skins/\n")
|
||||
return
|
||||
|
||||
new_skin = parts[1].strip().lower()
|
||||
@@ -5944,6 +5948,9 @@ class HermesCLI:
|
||||
``normal_prompt`` is the full ``branding.prompt_symbol``.
|
||||
``state_suffix`` is what special states (sudo/secret/approval/agent)
|
||||
should render after their leading icon.
|
||||
|
||||
When a profile is active (not "default"), the profile name is
|
||||
prepended to the prompt symbol: ``coder ❯`` instead of ``❯``.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.skin_engine import get_active_prompt_symbol
|
||||
@@ -5952,6 +5959,15 @@ class HermesCLI:
|
||||
symbol = "❯ "
|
||||
|
||||
symbol = (symbol or "❯ ").rstrip() + " "
|
||||
|
||||
# Prepend profile name when not default
|
||||
try:
|
||||
from hermes_cli.profiles import get_active_profile_name
|
||||
profile = get_active_profile_name()
|
||||
if profile not in ("default", "custom"):
|
||||
symbol = f"{profile} {symbol}"
|
||||
except Exception:
|
||||
pass
|
||||
stripped = symbol.rstrip()
|
||||
if not stripped:
|
||||
return "❯ ", "❯ "
|
||||
@@ -6488,6 +6504,24 @@ class HermesCLI:
|
||||
self._should_exit = True
|
||||
event.app.exit()
|
||||
|
||||
@kb.add('c-z')
|
||||
def handle_ctrl_z(event):
|
||||
"""Handle Ctrl+Z - suspend process to background (Unix only)."""
|
||||
import sys
|
||||
if sys.platform == 'win32':
|
||||
_cprint(f"\n{_DIM}Suspend (Ctrl+Z) is not supported on Windows.{_RST}")
|
||||
event.app.invalidate()
|
||||
return
|
||||
import os, signal as _sig
|
||||
from prompt_toolkit.application import run_in_terminal
|
||||
from hermes_cli.skin_engine import get_active_skin
|
||||
agent_name = get_active_skin().get_branding("agent_name", "Hermes Agent")
|
||||
msg = f"\n{agent_name} has been suspended. Run `fg` to bring {agent_name} back."
|
||||
def _suspend():
|
||||
os.write(1, msg.encode())
|
||||
os.kill(0, _sig.SIGTSTP)
|
||||
run_in_terminal(_suspend)
|
||||
|
||||
# Voice push-to-talk key: configurable via config.yaml (voice.record_key)
|
||||
# Default: Ctrl+B (avoids conflict with Ctrl+R readline reverse-search)
|
||||
# Config uses "ctrl+b" format; prompt_toolkit expects "c-b" format.
|
||||
|
||||
+23
-11
@@ -26,6 +26,7 @@ except ImportError:
|
||||
msvcrt = None
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_cli.config import load_config
|
||||
from typing import Optional
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
@@ -164,18 +165,29 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
logger.warning("Job '%s': platform '%s' not configured/enabled", job["id"], platform_name)
|
||||
return
|
||||
|
||||
# Wrap the content so the user knows this is a cron delivery and that
|
||||
# the interactive agent has no visibility into it.
|
||||
task_name = job.get("name", job["id"])
|
||||
wrapped = (
|
||||
f"Cronjob Response: {task_name}\n"
|
||||
f"-------------\n\n"
|
||||
f"{content}\n\n"
|
||||
f"Note: The agent cannot see this message, and therefore cannot respond to it."
|
||||
)
|
||||
# Optionally wrap the content with a header/footer so the user knows this
|
||||
# is a cron delivery. Wrapping is on by default; set cron.wrap_response: false
|
||||
# in config.yaml for clean output.
|
||||
wrap_response = True
|
||||
try:
|
||||
user_cfg = load_config()
|
||||
wrap_response = user_cfg.get("cron", {}).get("wrap_response", True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if wrap_response:
|
||||
task_name = job.get("name", job["id"])
|
||||
delivery_content = (
|
||||
f"Cronjob Response: {task_name}\n"
|
||||
f"-------------\n\n"
|
||||
f"{content}\n\n"
|
||||
f"Note: The agent cannot see this message, and therefore cannot respond to it."
|
||||
)
|
||||
else:
|
||||
delivery_content = content
|
||||
|
||||
# Run the async send in a fresh event loop (safe from any thread)
|
||||
coro = _send_to_platform(platform, pconfig, chat_id, wrapped, thread_id=thread_id)
|
||||
coro = _send_to_platform(platform, pconfig, chat_id, delivery_content, thread_id=thread_id)
|
||||
try:
|
||||
result = asyncio.run(coro)
|
||||
except RuntimeError:
|
||||
@@ -186,7 +198,7 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
coro.close()
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, wrapped, thread_id=thread_id))
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, delivery_content, thread_id=thread_id))
|
||||
result = future.result(timeout=30)
|
||||
except Exception as e:
|
||||
logger.error("Job '%s': delivery to %s:%s failed: %s", job["id"], platform_name, chat_id, e)
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
# Hermes Agent Persona
|
||||
|
||||
<!--
|
||||
This file defines the agent's personality and tone.
|
||||
The agent will embody whatever you write here.
|
||||
Edit this to customize how Hermes communicates with you.
|
||||
|
||||
Examples:
|
||||
- "You are a warm, playful assistant who uses kaomoji occasionally."
|
||||
- "You are a concise technical expert. No fluff, just facts."
|
||||
- "You speak like a friendly coworker who happens to know everything."
|
||||
|
||||
This file is loaded fresh each message -- no restart needed.
|
||||
Delete the contents (or this file) to use the default personality.
|
||||
-->
|
||||
@@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
# Docker entrypoint: bootstrap config files into the mounted volume, then run hermes.
|
||||
set -e
|
||||
|
||||
HERMES_HOME="/opt/data"
|
||||
INSTALL_DIR="/opt/hermes"
|
||||
|
||||
# Create essential directory structure. Cache and platform directories
|
||||
# (cache/images, cache/audio, platforms/whatsapp, etc.) are created on
|
||||
# demand by the application — don't pre-create them here so new installs
|
||||
# get the consolidated layout from get_hermes_dir().
|
||||
mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills}
|
||||
|
||||
# .env
|
||||
if [ ! -f "$HERMES_HOME/.env" ]; then
|
||||
cp "$INSTALL_DIR/.env.example" "$HERMES_HOME/.env"
|
||||
fi
|
||||
|
||||
# config.yaml
|
||||
if [ ! -f "$HERMES_HOME/config.yaml" ]; then
|
||||
cp "$INSTALL_DIR/cli-config.yaml.example" "$HERMES_HOME/config.yaml"
|
||||
fi
|
||||
|
||||
# SOUL.md
|
||||
if [ ! -f "$HERMES_HOME/SOUL.md" ]; then
|
||||
cp "$INSTALL_DIR/docker/SOUL.md" "$HERMES_HOME/SOUL.md"
|
||||
fi
|
||||
|
||||
# Sync bundled skills (manifest-based so user edits are preserved)
|
||||
if [ -d "$INSTALL_DIR/skills" ]; then
|
||||
python3 "$INSTALL_DIR/tools/skills_sync.py"
|
||||
fi
|
||||
|
||||
exec hermes "$@"
|
||||
@@ -209,7 +209,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
|
||||
# Agent settings -- TB2 tasks are complex, need many turns
|
||||
max_agent_turns=60,
|
||||
max_token_length=***
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
@@ -233,7 +233,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousRe...1-8B",
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-bench-2",
|
||||
ensure_scores_are_not_same=False, # Binary rewards may all be 0 or 1
|
||||
@@ -245,7 +245,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4",
|
||||
server_type="openai",
|
||||
api_key=os.get...EY", ""),
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
@@ -513,3 +513,446 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
reward = 0.0
|
||||
else:
|
||||
# Run tests in a thread so the blocking ctx.terminal() calls
|
||||
# don't freeze the entire event loop (which would stall all
|
||||
# other tasks, tqdm updates, and timeout timers).
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
reward = await loop.run_in_executor(
|
||||
None, # default thread pool
|
||||
self._run_tests, eval_item, ctx, task_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Task %s: test verification failed: %s", task_name, e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
passed = reward == 1.0
|
||||
status = "PASS" if passed else "FAIL"
|
||||
elapsed = time.time() - task_start
|
||||
tqdm.write(f" [{status}] {task_name} (turns={result.turns_used}, {elapsed:.0f}s)")
|
||||
logger.info(
|
||||
"Task %s: reward=%.1f, turns=%d, finished=%s",
|
||||
task_name, reward, result.turns_used, result.finished_naturally,
|
||||
)
|
||||
|
||||
out = {
|
||||
"passed": passed,
|
||||
"reward": reward,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"turns_used": result.turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"messages": result.messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - task_start
|
||||
logger.error("Task %s: rollout failed: %s", task_name, e, exc_info=True)
|
||||
tqdm.write(f" [ERROR] {task_name}: {e} ({elapsed:.0f}s)")
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": str(e),
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
finally:
|
||||
# --- Cleanup: clear overrides, sandbox, and temp files ---
|
||||
clear_task_env_overrides(task_id)
|
||||
try:
|
||||
cleanup_vm(task_id)
|
||||
except Exception as e:
|
||||
logger.debug("VM cleanup for %s: %s", task_id[:8], e)
|
||||
if task_dir and task_dir.exists():
|
||||
shutil.rmtree(task_dir, ignore_errors=True)
|
||||
|
||||
def _run_tests(
|
||||
self, item: Dict[str, Any], ctx: ToolContext, task_name: str
|
||||
) -> float:
|
||||
"""
|
||||
Upload and execute the test suite in the agent's sandbox, then
|
||||
download the verifier output locally to read the reward.
|
||||
|
||||
Follows Harbor's verification pattern:
|
||||
1. Upload tests/ directory into the sandbox
|
||||
2. Execute test.sh inside the sandbox
|
||||
3. Download /logs/verifier/ directory to a local temp dir
|
||||
4. Read reward.txt locally with native Python I/O
|
||||
|
||||
Downloading locally avoids issues with the file_read tool on
|
||||
the Modal VM and matches how Harbor handles verification.
|
||||
|
||||
TB2 test scripts (test.sh) typically:
|
||||
1. Install pytest via uv/pip
|
||||
2. Run pytest against the test files in /tests/
|
||||
3. Write results to /logs/verifier/reward.txt
|
||||
|
||||
Args:
|
||||
item: The TB2 task dict (contains tests_tar, test_sh)
|
||||
ctx: ToolContext scoped to this task's sandbox
|
||||
task_name: For logging
|
||||
|
||||
Returns:
|
||||
1.0 if tests pass, 0.0 otherwise
|
||||
"""
|
||||
tests_tar = item.get("tests_tar", "")
|
||||
test_sh = item.get("test_sh", "")
|
||||
|
||||
if not test_sh:
|
||||
logger.warning("Task %s: no test_sh content, reward=0", task_name)
|
||||
return 0.0
|
||||
|
||||
# Create required directories in the sandbox
|
||||
ctx.terminal("mkdir -p /tests /logs/verifier")
|
||||
|
||||
# Upload test files into the sandbox (binary-safe via base64)
|
||||
if tests_tar:
|
||||
tests_temp = Path(tempfile.mkdtemp(prefix=f"tb2-tests-{task_name}-"))
|
||||
try:
|
||||
_extract_base64_tar(tests_tar, tests_temp)
|
||||
ctx.upload_dir(str(tests_temp), "/tests")
|
||||
except Exception as e:
|
||||
logger.warning("Task %s: failed to upload test files: %s", task_name, e)
|
||||
finally:
|
||||
shutil.rmtree(tests_temp, ignore_errors=True)
|
||||
|
||||
# Write the test runner script (test.sh)
|
||||
ctx.write_file("/tests/test.sh", test_sh)
|
||||
ctx.terminal("chmod +x /tests/test.sh")
|
||||
|
||||
# Execute the test suite
|
||||
logger.info(
|
||||
"Task %s: running test suite (timeout=%ds)",
|
||||
task_name, self.config.test_timeout,
|
||||
)
|
||||
test_result = ctx.terminal(
|
||||
"bash /tests/test.sh",
|
||||
timeout=self.config.test_timeout,
|
||||
)
|
||||
|
||||
exit_code = test_result.get("exit_code", -1)
|
||||
output = test_result.get("output", "")
|
||||
|
||||
# Download the verifier output directory locally, then read reward.txt
|
||||
# with native Python I/O. This avoids issues with file_read on the
|
||||
# Modal VM and matches Harbor's verification pattern.
|
||||
reward = 0.0
|
||||
local_verifier_dir = Path(tempfile.mkdtemp(prefix=f"tb2-verifier-{task_name}-"))
|
||||
try:
|
||||
ctx.download_dir("/logs/verifier", str(local_verifier_dir))
|
||||
|
||||
reward_file = local_verifier_dir / "reward.txt"
|
||||
if reward_file.exists() and reward_file.stat().st_size > 0:
|
||||
content = reward_file.read_text().strip()
|
||||
if content == "1":
|
||||
reward = 1.0
|
||||
elif content == "0":
|
||||
reward = 0.0
|
||||
else:
|
||||
# Unexpected content -- try parsing as float
|
||||
try:
|
||||
reward = float(content)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt content unexpected (%r), "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, content, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
else:
|
||||
# reward.txt not written -- fall back to exit code
|
||||
logger.warning(
|
||||
"Task %s: reward.txt not found after download, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Task %s: failed to download verifier dir: %s, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, e, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
finally:
|
||||
shutil.rmtree(local_verifier_dir, ignore_errors=True)
|
||||
|
||||
# Log test output for debugging failures
|
||||
if reward == 0.0:
|
||||
output_preview = output[-500:] if output else "(no output)"
|
||||
logger.info(
|
||||
"Task %s: FAIL (exit_code=%d)\n%s",
|
||||
task_name, exit_code, output_preview,
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
# =========================================================================
|
||||
# Evaluate -- main entry point for the eval subcommand
|
||||
# =========================================================================
|
||||
|
||||
async def _eval_with_timeout(self, item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Wrap rollout_and_score_eval with a per-task wall-clock timeout.
|
||||
|
||||
If the task exceeds task_timeout seconds, it's automatically scored
|
||||
as FAIL. This prevents any single task from hanging indefinitely.
|
||||
"""
|
||||
task_name = item.get("task_name", "unknown")
|
||||
category = item.get("category", "unknown")
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.rollout_and_score_eval(item),
|
||||
timeout=self.config.task_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
elapsed = self.config.task_timeout
|
||||
tqdm.write(f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)")
|
||||
logger.error("Task %s: wall-clock timeout after %ds", task_name, elapsed)
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": f"timeout ({elapsed}s)",
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Run Terminal-Bench 2.0 evaluation over all tasks.
|
||||
|
||||
This is the main entry point when invoked via:
|
||||
python environments/terminalbench2_env.py evaluate
|
||||
|
||||
Runs all tasks through rollout_and_score_eval() via asyncio.gather()
|
||||
(same pattern as GPQA and other Atropos eval envs). Each task is
|
||||
wrapped with a wall-clock timeout so hung tasks auto-fail.
|
||||
|
||||
Suppresses noisy Modal/terminal output (HERMES_QUIET) so the tqdm
|
||||
bar stays visible.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Route all logging through tqdm.write() so the progress bar stays
|
||||
# pinned at the bottom while log lines scroll above it.
|
||||
from tqdm import tqdm
|
||||
|
||||
class _TqdmHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
try:
|
||||
tqdm.write(self.format(record))
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
))
|
||||
root = logging.getLogger()
|
||||
root.handlers = [handler] # Replace any existing handlers
|
||||
root.setLevel(logging.INFO)
|
||||
|
||||
# Silence noisy third-party loggers that flood the output
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("rex_image_builder").setLevel(logging.WARNING) # Image builds
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("Starting Terminal-Bench 2.0 Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Total tasks: {len(self.all_eval_items)}")
|
||||
print(f" Max agent turns: {self.config.max_agent_turns}")
|
||||
print(f" Task timeout: {self.config.task_timeout}s")
|
||||
print(f" Terminal backend: {self.config.terminal_backend}")
|
||||
print(f" Tool thread pool: {self.config.tool_pool_size}")
|
||||
print(f" Terminal timeout: {self.config.terminal_timeout}s/cmd")
|
||||
print(f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)")
|
||||
print(f" Max concurrent tasks: {self.config.max_concurrent_tasks}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Semaphore to limit concurrent Modal sandbox creations.
|
||||
# Without this, all 86 tasks fire simultaneously, each creating a Modal
|
||||
# sandbox via asyncio.run() inside a thread pool worker. Modal's blocking
|
||||
# calls (App.lookup, etc.) deadlock when too many are created at once.
|
||||
semaphore = asyncio.Semaphore(self.config.max_concurrent_tasks)
|
||||
|
||||
async def _eval_with_semaphore(item):
|
||||
async with semaphore:
|
||||
return await self._eval_with_timeout(item)
|
||||
|
||||
# Fire all tasks with wall-clock timeout, track live accuracy on the bar
|
||||
total_tasks = len(self.all_eval_items)
|
||||
eval_tasks = [
|
||||
asyncio.ensure_future(_eval_with_semaphore(item))
|
||||
for item in self.all_eval_items
|
||||
]
|
||||
|
||||
results = []
|
||||
passed_count = 0
|
||||
pbar = tqdm(total=total_tasks, desc="Evaluating TB2", dynamic_ncols=True)
|
||||
try:
|
||||
for coro in asyncio.as_completed(eval_tasks):
|
||||
result = await coro
|
||||
results.append(result)
|
||||
if result and result.get("passed"):
|
||||
passed_count += 1
|
||||
done = len(results)
|
||||
pct = (passed_count / done * 100) if done else 0
|
||||
pbar.set_postfix_str(f"pass={passed_count}/{done} ({pct:.1f}%)")
|
||||
pbar.update(1)
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
pbar.close()
|
||||
print(f"\n\nInterrupted! Cleaning up {len(eval_tasks)} tasks...")
|
||||
# Cancel all pending tasks
|
||||
for task in eval_tasks:
|
||||
task.cancel()
|
||||
# Let cancellations propagate (finally blocks run cleanup_vm)
|
||||
await asyncio.gather(*eval_tasks, return_exceptions=True)
|
||||
# Belt-and-suspenders: clean up any remaining sandboxes
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
print("All sandboxes cleaned up.")
|
||||
return
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Filter out None results (shouldn't happen, but be safe)
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return
|
||||
|
||||
# ---- Compute metrics ----
|
||||
total = len(valid_results)
|
||||
passed = sum(1 for r in valid_results if r.get("passed"))
|
||||
overall_pass_rate = passed / total if total > 0 else 0.0
|
||||
|
||||
# Per-category breakdown
|
||||
cat_results: Dict[str, List[Dict]] = defaultdict(list)
|
||||
for r in valid_results:
|
||||
cat_results[r.get("category", "unknown")].append(r)
|
||||
|
||||
# Build metrics dict
|
||||
eval_metrics = {
|
||||
"eval/pass_rate": overall_pass_rate,
|
||||
"eval/total_tasks": total,
|
||||
"eval/passed_tasks": passed,
|
||||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
}
|
||||
|
||||
# Per-category metrics
|
||||
for category, cat_items in sorted(cat_results.items()):
|
||||
cat_passed = sum(1 for r in cat_items if r.get("passed"))
|
||||
cat_total = len(cat_items)
|
||||
cat_pass_rate = cat_passed / cat_total if cat_total > 0 else 0.0
|
||||
cat_key = category.replace(" ", "_").replace("-", "_").lower()
|
||||
eval_metrics[f"eval/pass_rate_{cat_key}"] = cat_pass_rate
|
||||
|
||||
# Store metrics for wandb_log
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
# ---- Print summary ----
|
||||
print(f"\n{'='*60}")
|
||||
print("Terminal-Bench 2.0 Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"Overall Pass Rate: {overall_pass_rate:.4f} ({passed}/{total})")
|
||||
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
|
||||
|
||||
print("\nCategory Breakdown:")
|
||||
for category, cat_items in sorted(cat_results.items()):
|
||||
cat_passed = sum(1 for r in cat_items if r.get("passed"))
|
||||
cat_total = len(cat_items)
|
||||
cat_rate = cat_passed / cat_total if cat_total > 0 else 0.0
|
||||
print(f" {category}: {cat_rate:.1%} ({cat_passed}/{cat_total})")
|
||||
|
||||
# Print individual task results
|
||||
print("\nTask Results:")
|
||||
for r in sorted(valid_results, key=lambda x: x.get("task_name", "")):
|
||||
status = "PASS" if r.get("passed") else "FAIL"
|
||||
turns = r.get("turns_used", "?")
|
||||
error = r.get("error", "")
|
||||
extra = f" (error: {error})" if error else ""
|
||||
print(f" [{status}] {r['task_name']} (turns={turns}){extra}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Build sample records for evaluate_log (includes full conversations)
|
||||
samples = [
|
||||
{
|
||||
"task_name": r.get("task_name"),
|
||||
"category": r.get("category"),
|
||||
"passed": r.get("passed"),
|
||||
"reward": r.get("reward"),
|
||||
"turns_used": r.get("turns_used"),
|
||||
"error": r.get("error"),
|
||||
"messages": r.get("messages"),
|
||||
}
|
||||
for r in valid_results
|
||||
]
|
||||
|
||||
# Log evaluation results
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error logging evaluation results: {e}")
|
||||
|
||||
# Close streaming file
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
print(f" Live results saved to: {self._streaming_path}")
|
||||
|
||||
# Kill all remaining sandboxes. Timed-out tasks leave orphaned thread
|
||||
# pool workers still executing commands -- cleanup_all stops them.
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
print("\nCleaning up all sandboxes...")
|
||||
cleanup_all_environments()
|
||||
|
||||
# Shut down the tool thread pool so orphaned workers from timed-out
|
||||
# tasks are killed immediately instead of retrying against dead
|
||||
# sandboxes and spamming the console with TimeoutError warnings.
|
||||
from environments.agent_loop import _tool_executor
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
print("Done.")
|
||||
|
||||
# =========================================================================
|
||||
# Wandb logging
|
||||
# =========================================================================
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log TB2-specific metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Add stored eval metrics
|
||||
for metric_name, metric_value in self.eval_metrics:
|
||||
wandb_metrics[metric_name] = metric_value
|
||||
self.eval_metrics = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TerminalBench2EvalEnv.cli()
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Built-in gateway hooks that are always registered."""
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Built-in boot-md hook — run ~/.hermes/BOOT.md on gateway startup.
|
||||
|
||||
This hook is always registered. It silently skips if no BOOT.md exists.
|
||||
To activate, create ``~/.hermes/BOOT.md`` with instructions for the
|
||||
agent to execute on every gateway restart.
|
||||
|
||||
Example BOOT.md::
|
||||
|
||||
# Startup Checklist
|
||||
|
||||
1. Check if any cron jobs failed overnight
|
||||
2. Send a status update to Discord #general
|
||||
3. If there are errors in /opt/app/deploy.log, summarize them
|
||||
|
||||
The agent runs in a background thread so it doesn't block gateway
|
||||
startup. If nothing needs attention, it replies with [SILENT] to
|
||||
suppress delivery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger("hooks.boot-md")
|
||||
|
||||
HERMES_HOME = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
BOOT_FILE = HERMES_HOME / "BOOT.md"
|
||||
|
||||
|
||||
def _build_boot_prompt(content: str) -> str:
|
||||
"""Wrap BOOT.md content in a system-level instruction."""
|
||||
return (
|
||||
"You are running a startup boot checklist. Follow the BOOT.md "
|
||||
"instructions below exactly.\n\n"
|
||||
"---\n"
|
||||
f"{content}\n"
|
||||
"---\n\n"
|
||||
"Execute each instruction. If you need to send a message to a "
|
||||
"platform, use the send_message tool.\n"
|
||||
"If nothing needs attention and there is nothing to report, "
|
||||
"reply with ONLY: [SILENT]"
|
||||
)
|
||||
|
||||
|
||||
def _run_boot_agent(content: str) -> None:
|
||||
"""Spawn a one-shot agent session to execute the boot instructions."""
|
||||
try:
|
||||
from run_agent import AIAgent
|
||||
|
||||
prompt = _build_boot_prompt(content)
|
||||
agent = AIAgent(
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
max_iterations=20,
|
||||
)
|
||||
result = agent.run_conversation(prompt)
|
||||
response = result.get("final_response", "")
|
||||
if response and "[SILENT]" not in response:
|
||||
logger.info("boot-md completed: %s", response[:200])
|
||||
else:
|
||||
logger.info("boot-md completed (nothing to report)")
|
||||
except Exception as e:
|
||||
logger.error("boot-md agent failed: %s", e)
|
||||
|
||||
|
||||
async def handle(event_type: str, context: dict) -> None:
|
||||
"""Gateway startup handler — run BOOT.md if it exists."""
|
||||
if not BOOT_FILE.exists():
|
||||
return
|
||||
|
||||
content = BOOT_FILE.read_text(encoding="utf-8").strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
logger.info("Running BOOT.md (%d chars)", len(content))
|
||||
|
||||
# Run in a background thread so we don't block gateway startup.
|
||||
thread = threading.Thread(
|
||||
target=_run_boot_agent,
|
||||
args=(content,),
|
||||
name="boot-md",
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
+42
-43
@@ -647,14 +647,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.SLACK] = PlatformConfig()
|
||||
config.platforms[Platform.SLACK].enabled = True
|
||||
config.platforms[Platform.SLACK].token = slack_token
|
||||
# Home channel
|
||||
slack_home = os.getenv("SLACK_HOME_CHANNEL")
|
||||
if slack_home:
|
||||
config.platforms[Platform.SLACK].home_channel = HomeChannel(
|
||||
platform=Platform.SLACK,
|
||||
chat_id=slack_home,
|
||||
name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""),
|
||||
)
|
||||
slack_home = os.getenv("SLACK_HOME_CHANNEL")
|
||||
if slack_home and Platform.SLACK in config.platforms:
|
||||
config.platforms[Platform.SLACK].home_channel = HomeChannel(
|
||||
platform=Platform.SLACK,
|
||||
chat_id=slack_home,
|
||||
name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""),
|
||||
)
|
||||
|
||||
# Signal
|
||||
signal_url = os.getenv("SIGNAL_HTTP_URL")
|
||||
@@ -668,13 +667,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"account": signal_account,
|
||||
"ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"),
|
||||
})
|
||||
signal_home = os.getenv("SIGNAL_HOME_CHANNEL")
|
||||
if signal_home:
|
||||
config.platforms[Platform.SIGNAL].home_channel = HomeChannel(
|
||||
platform=Platform.SIGNAL,
|
||||
chat_id=signal_home,
|
||||
name=os.getenv("SIGNAL_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
signal_home = os.getenv("SIGNAL_HOME_CHANNEL")
|
||||
if signal_home and Platform.SIGNAL in config.platforms:
|
||||
config.platforms[Platform.SIGNAL].home_channel = HomeChannel(
|
||||
platform=Platform.SIGNAL,
|
||||
chat_id=signal_home,
|
||||
name=os.getenv("SIGNAL_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Mattermost
|
||||
mattermost_token = os.getenv("MATTERMOST_TOKEN")
|
||||
@@ -687,13 +686,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.MATTERMOST].enabled = True
|
||||
config.platforms[Platform.MATTERMOST].token = mattermost_token
|
||||
config.platforms[Platform.MATTERMOST].extra["url"] = mattermost_url
|
||||
mattermost_home = os.getenv("MATTERMOST_HOME_CHANNEL")
|
||||
if mattermost_home:
|
||||
config.platforms[Platform.MATTERMOST].home_channel = HomeChannel(
|
||||
platform=Platform.MATTERMOST,
|
||||
chat_id=mattermost_home,
|
||||
name=os.getenv("MATTERMOST_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
mattermost_home = os.getenv("MATTERMOST_HOME_CHANNEL")
|
||||
if mattermost_home and Platform.MATTERMOST in config.platforms:
|
||||
config.platforms[Platform.MATTERMOST].home_channel = HomeChannel(
|
||||
platform=Platform.MATTERMOST,
|
||||
chat_id=mattermost_home,
|
||||
name=os.getenv("MATTERMOST_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Matrix
|
||||
matrix_token = os.getenv("MATRIX_ACCESS_TOKEN")
|
||||
@@ -715,13 +714,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.MATRIX].extra["password"] = matrix_password
|
||||
matrix_e2ee = os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes")
|
||||
config.platforms[Platform.MATRIX].extra["encryption"] = matrix_e2ee
|
||||
matrix_home = os.getenv("MATRIX_HOME_ROOM")
|
||||
if matrix_home:
|
||||
config.platforms[Platform.MATRIX].home_channel = HomeChannel(
|
||||
platform=Platform.MATRIX,
|
||||
chat_id=matrix_home,
|
||||
name=os.getenv("MATRIX_HOME_ROOM_NAME", "Home"),
|
||||
)
|
||||
matrix_home = os.getenv("MATRIX_HOME_ROOM")
|
||||
if matrix_home and Platform.MATRIX in config.platforms:
|
||||
config.platforms[Platform.MATRIX].home_channel = HomeChannel(
|
||||
platform=Platform.MATRIX,
|
||||
chat_id=matrix_home,
|
||||
name=os.getenv("MATRIX_HOME_ROOM_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Home Assistant
|
||||
hass_token = os.getenv("HASS_TOKEN")
|
||||
@@ -748,13 +747,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"imap_host": email_imap,
|
||||
"smtp_host": email_smtp,
|
||||
})
|
||||
email_home = os.getenv("EMAIL_HOME_ADDRESS")
|
||||
if email_home:
|
||||
config.platforms[Platform.EMAIL].home_channel = HomeChannel(
|
||||
platform=Platform.EMAIL,
|
||||
chat_id=email_home,
|
||||
name=os.getenv("EMAIL_HOME_ADDRESS_NAME", "Home"),
|
||||
)
|
||||
email_home = os.getenv("EMAIL_HOME_ADDRESS")
|
||||
if email_home and Platform.EMAIL in config.platforms:
|
||||
config.platforms[Platform.EMAIL].home_channel = HomeChannel(
|
||||
platform=Platform.EMAIL,
|
||||
chat_id=email_home,
|
||||
name=os.getenv("EMAIL_HOME_ADDRESS_NAME", "Home"),
|
||||
)
|
||||
|
||||
# SMS (Twilio)
|
||||
twilio_sid = os.getenv("TWILIO_ACCOUNT_SID")
|
||||
@@ -763,13 +762,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.SMS] = PlatformConfig()
|
||||
config.platforms[Platform.SMS].enabled = True
|
||||
config.platforms[Platform.SMS].api_key = os.getenv("TWILIO_AUTH_TOKEN", "")
|
||||
sms_home = os.getenv("SMS_HOME_CHANNEL")
|
||||
if sms_home:
|
||||
config.platforms[Platform.SMS].home_channel = HomeChannel(
|
||||
platform=Platform.SMS,
|
||||
chat_id=sms_home,
|
||||
name=os.getenv("SMS_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
sms_home = os.getenv("SMS_HOME_CHANNEL")
|
||||
if sms_home and Platform.SMS in config.platforms:
|
||||
config.platforms[Platform.SMS].home_channel = HomeChannel(
|
||||
platform=Platform.SMS,
|
||||
chat_id=sms_home,
|
||||
name=os.getenv("SMS_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# API Server
|
||||
api_server_enabled = os.getenv("API_SERVER_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
|
||||
@@ -51,14 +51,33 @@ class HookRegistry:
|
||||
"""Return metadata about all loaded hooks."""
|
||||
return list(self._loaded_hooks)
|
||||
|
||||
def _register_builtin_hooks(self) -> None:
|
||||
"""Register built-in hooks that are always active."""
|
||||
try:
|
||||
from gateway.builtin_hooks.boot_md import handle as boot_md_handle
|
||||
|
||||
self._handlers.setdefault("gateway:startup", []).append(boot_md_handle)
|
||||
self._loaded_hooks.append({
|
||||
"name": "boot-md",
|
||||
"description": "Run ~/.hermes/BOOT.md on gateway startup",
|
||||
"events": ["gateway:startup"],
|
||||
"path": "(builtin)",
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"[hooks] Could not load built-in boot-md hook: {e}", flush=True)
|
||||
|
||||
def discover_and_load(self) -> None:
|
||||
"""
|
||||
Scan the hooks directory for hook directories and load their handlers.
|
||||
|
||||
Also registers built-in hooks that are always active.
|
||||
|
||||
Each hook directory must contain:
|
||||
- HOOK.yaml with at least 'name' and 'events' keys
|
||||
- handler.py with a top-level 'handle' function (sync or async)
|
||||
"""
|
||||
self._register_builtin_hooks()
|
||||
|
||||
if not HOOKS_DIR.exists():
|
||||
return
|
||||
|
||||
|
||||
@@ -1261,6 +1261,17 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
self._app.router.add_post("/api/jobs/{job_id}/resume", self._handle_resume_job)
|
||||
self._app.router.add_post("/api/jobs/{job_id}/run", self._handle_run_job)
|
||||
|
||||
# Port conflict detection — fail fast if port is already in use
|
||||
import socket as _socket
|
||||
try:
|
||||
with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as _s:
|
||||
_s.settimeout(1)
|
||||
_s.connect(('127.0.0.1', self._port))
|
||||
logger.error('[%s] Port %d already in use. Set a different port in config.yaml: platforms.api_server.port', self.name, self._port)
|
||||
return False
|
||||
except (ConnectionRefusedError, OSError):
|
||||
pass # port is free
|
||||
|
||||
self._runner = web.AppRunner(self._app)
|
||||
await self._runner.setup()
|
||||
self._site = web.TCPSite(self._runner, self._host, self._port)
|
||||
|
||||
@@ -1005,7 +1005,7 @@ class BasePlatformAdapter(ABC):
|
||||
# simultaneous messages. Queue them without interrupting the active run,
|
||||
# then process them immediately after the current task finishes.
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
print(f"[{self.name}] 🖼️ Queuing photo follow-up for session {session_key} without interrupt")
|
||||
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
|
||||
existing = self._pending_messages.get(session_key)
|
||||
if existing and existing.message_type == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
@@ -1020,7 +1020,7 @@ class BasePlatformAdapter(ABC):
|
||||
return # Don't interrupt now - will run after current task completes
|
||||
|
||||
# Default behavior for non-photo follow-ups: interrupt the running agent
|
||||
print(f"[{self.name}] ⚡ New message while session {session_key} is active - triggering interrupt")
|
||||
logger.debug("[%s] New message while session %s is active — triggering interrupt", self.name, session_key)
|
||||
self._pending_messages[session_key] = event
|
||||
# Signal the interrupt (the processing task checks this)
|
||||
self._active_sessions[session_key].set()
|
||||
@@ -1206,9 +1206,9 @@ class BasePlatformAdapter(ABC):
|
||||
)
|
||||
|
||||
if not media_result.success:
|
||||
print(f"[{self.name}] Failed to send media ({ext}): {media_result.error}")
|
||||
logger.warning("[%s] Failed to send media (%s): %s", self.name, ext, media_result.error)
|
||||
except Exception as media_err:
|
||||
print(f"[{self.name}] Error sending media: {media_err}")
|
||||
logger.warning("[%s] Error sending media: %s", self.name, media_err)
|
||||
|
||||
# Send auto-detected local files as native attachments
|
||||
for file_path in local_files:
|
||||
@@ -1240,7 +1240,7 @@ class BasePlatformAdapter(ABC):
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
pending_event = self._pending_messages.pop(session_key)
|
||||
print(f"[{self.name}] 📨 Processing queued message from interrupt")
|
||||
logger.debug("[%s] Processing queued message from interrupt", self.name)
|
||||
# Clean up current session before processing pending
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
@@ -1254,9 +1254,7 @@ class BasePlatformAdapter(ABC):
|
||||
return # Already cleaned up
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error handling message: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True)
|
||||
# Send the error to the user so they aren't left with radio silence
|
||||
try:
|
||||
error_type = type(e).__name__
|
||||
|
||||
@@ -486,6 +486,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
try:
|
||||
# Acquire scoped lock to prevent duplicate bot token usage
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._token_lock_identity = self.config.token
|
||||
acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'})
|
||||
if not acquired:
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error('discord_token_lock', message, retryable=False)
|
||||
return False
|
||||
|
||||
# Set up intents -- members intent needed for username-to-ID resolution
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
@@ -638,6 +649,16 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._running = False
|
||||
self._client = None
|
||||
self._ready_event.clear()
|
||||
|
||||
# Release the token lock
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
async def send(
|
||||
@@ -1429,15 +1450,23 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
command_text: str,
|
||||
followup_msg: str | None = None,
|
||||
) -> None:
|
||||
"""Common handler for simple slash commands that dispatch a command string."""
|
||||
"""Common handler for simple slash commands that dispatch a command string.
|
||||
|
||||
Defers the interaction (shows "thinking..."), dispatches the command,
|
||||
then cleans up the deferred response. If *followup_msg* is provided
|
||||
the "thinking..." indicator is replaced with that text; otherwise it
|
||||
is deleted so the channel isn't cluttered.
|
||||
"""
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, command_text)
|
||||
await self.handle_message(event)
|
||||
if followup_msg:
|
||||
try:
|
||||
await interaction.followup.send(followup_msg, ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
try:
|
||||
if followup_msg:
|
||||
await interaction.edit_original_response(content=followup_msg)
|
||||
else:
|
||||
await interaction.delete_original_response()
|
||||
except Exception as e:
|
||||
logger.debug("Discord interaction cleanup failed: %s", e)
|
||||
|
||||
def _register_slash_commands(self) -> None:
|
||||
"""Register Discord slash commands on the command tree."""
|
||||
@@ -1462,9 +1491,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
@tree.command(name="reasoning", description="Show or change reasoning effort")
|
||||
@discord.app_commands.describe(effort="Reasoning effort: xhigh, high, medium, low, minimal, or none.")
|
||||
async def slash_reasoning(interaction: discord.Interaction, effort: str = ""):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/reasoning {effort}".strip())
|
||||
await self.handle_message(event)
|
||||
await self._run_simple_slash(interaction, f"/reasoning {effort}".strip())
|
||||
|
||||
@tree.command(name="personality", description="Set a personality")
|
||||
@discord.app_commands.describe(name="Personality name. Leave empty to list available.")
|
||||
@@ -1537,9 +1564,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
discord.app_commands.Choice(name="status — show current mode", value="status"),
|
||||
])
|
||||
async def slash_voice(interaction: discord.Interaction, mode: str = ""):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/voice {mode}".strip())
|
||||
await self.handle_message(event)
|
||||
await self._run_simple_slash(interaction, f"/voice {mode}".strip())
|
||||
|
||||
@tree.command(name="update", description="Update Hermes Agent to the latest version")
|
||||
async def slash_update(interaction: discord.Interaction):
|
||||
|
||||
+68
-55
@@ -337,60 +337,63 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
results = []
|
||||
try:
|
||||
imap = imaplib.IMAP4_SSL(self._imap_host, self._imap_port, timeout=30)
|
||||
imap.login(self._address, self._password)
|
||||
imap.select("INBOX")
|
||||
try:
|
||||
imap.login(self._address, self._password)
|
||||
imap.select("INBOX")
|
||||
|
||||
status, data = imap.uid("search", None, "UNSEEN")
|
||||
if status != "OK" or not data or not data[0]:
|
||||
imap.logout()
|
||||
return results
|
||||
status, data = imap.uid("search", None, "UNSEEN")
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return results
|
||||
|
||||
for uid in data[0].split():
|
||||
if uid in self._seen_uids:
|
||||
continue
|
||||
self._seen_uids.add(uid)
|
||||
# Trim periodically to prevent unbounded memory growth
|
||||
if len(self._seen_uids) > self._seen_uids_max:
|
||||
self._trim_seen_uids()
|
||||
for uid in data[0].split():
|
||||
if uid in self._seen_uids:
|
||||
continue
|
||||
self._seen_uids.add(uid)
|
||||
# Trim periodically to prevent unbounded memory growth
|
||||
if len(self._seen_uids) > self._seen_uids_max:
|
||||
self._trim_seen_uids()
|
||||
|
||||
status, msg_data = imap.uid("fetch", uid, "(RFC822)")
|
||||
if status != "OK":
|
||||
continue
|
||||
status, msg_data = imap.uid("fetch", uid, "(RFC822)")
|
||||
if status != "OK":
|
||||
continue
|
||||
|
||||
raw_email = msg_data[0][1]
|
||||
msg = email_lib.message_from_bytes(raw_email)
|
||||
raw_email = msg_data[0][1]
|
||||
msg = email_lib.message_from_bytes(raw_email)
|
||||
|
||||
sender_raw = msg.get("From", "")
|
||||
sender_addr = _extract_email_address(sender_raw)
|
||||
sender_name = _decode_header_value(sender_raw)
|
||||
# Remove email from name if present
|
||||
if "<" in sender_name:
|
||||
sender_name = sender_name.split("<")[0].strip().strip('"')
|
||||
sender_raw = msg.get("From", "")
|
||||
sender_addr = _extract_email_address(sender_raw)
|
||||
sender_name = _decode_header_value(sender_raw)
|
||||
# Remove email from name if present
|
||||
if "<" in sender_name:
|
||||
sender_name = sender_name.split("<")[0].strip().strip('"')
|
||||
|
||||
subject = _decode_header_value(msg.get("Subject", "(no subject)"))
|
||||
message_id = msg.get("Message-ID", "")
|
||||
in_reply_to = msg.get("In-Reply-To", "")
|
||||
# Skip automated/noreply senders before any processing
|
||||
msg_headers = dict(msg.items())
|
||||
if _is_automated_sender(sender_addr, msg_headers):
|
||||
logger.debug("[Email] Skipping automated sender: %s", sender_addr)
|
||||
continue
|
||||
body = _extract_text_body(msg)
|
||||
attachments = _extract_attachments(msg, skip_attachments=self._skip_attachments)
|
||||
subject = _decode_header_value(msg.get("Subject", "(no subject)"))
|
||||
message_id = msg.get("Message-ID", "")
|
||||
in_reply_to = msg.get("In-Reply-To", "")
|
||||
# Skip automated/noreply senders before any processing
|
||||
msg_headers = dict(msg.items())
|
||||
if _is_automated_sender(sender_addr, msg_headers):
|
||||
logger.debug("[Email] Skipping automated sender: %s", sender_addr)
|
||||
continue
|
||||
body = _extract_text_body(msg)
|
||||
attachments = _extract_attachments(msg, skip_attachments=self._skip_attachments)
|
||||
|
||||
results.append({
|
||||
"uid": uid,
|
||||
"sender_addr": sender_addr,
|
||||
"sender_name": sender_name,
|
||||
"subject": subject,
|
||||
"message_id": message_id,
|
||||
"in_reply_to": in_reply_to,
|
||||
"body": body,
|
||||
"attachments": attachments,
|
||||
"date": msg.get("Date", ""),
|
||||
})
|
||||
|
||||
imap.logout()
|
||||
results.append({
|
||||
"uid": uid,
|
||||
"sender_addr": sender_addr,
|
||||
"sender_name": sender_name,
|
||||
"subject": subject,
|
||||
"message_id": message_id,
|
||||
"in_reply_to": in_reply_to,
|
||||
"body": body,
|
||||
"attachments": attachments,
|
||||
"date": msg.get("Date", ""),
|
||||
})
|
||||
finally:
|
||||
try:
|
||||
imap.logout()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("[Email] IMAP fetch error: %s", e)
|
||||
return results
|
||||
@@ -503,10 +506,15 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
msg.attach(MIMEText(body, "plain", "utf-8"))
|
||||
|
||||
smtp = smtplib.SMTP(self._smtp_host, self._smtp_port, timeout=30)
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self._address, self._password)
|
||||
smtp.send_message(msg)
|
||||
smtp.quit()
|
||||
try:
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self._address, self._password)
|
||||
smtp.send_message(msg)
|
||||
finally:
|
||||
try:
|
||||
smtp.quit()
|
||||
except Exception:
|
||||
smtp.close()
|
||||
|
||||
logger.info("[Email] Sent reply to %s (subject: %s)", to_addr, subject)
|
||||
return msg_id
|
||||
@@ -590,10 +598,15 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
msg.attach(part)
|
||||
|
||||
smtp = smtplib.SMTP(self._smtp_host, self._smtp_port, timeout=30)
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self._address, self._password)
|
||||
smtp.send_message(msg)
|
||||
smtp.quit()
|
||||
try:
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self._address, self._password)
|
||||
smtp.send_message(msg)
|
||||
finally:
|
||||
try:
|
||||
smtp.quit()
|
||||
except Exception:
|
||||
smtp.close()
|
||||
|
||||
return msg_id
|
||||
|
||||
|
||||
@@ -603,9 +603,19 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
# For DMs, user_id is sufficient. For channels, check for @mention.
|
||||
message_text = post.get("message", "")
|
||||
|
||||
# Mention-only mode: skip channel messages that don't @mention the bot.
|
||||
# DMs (type "D") are always processed.
|
||||
# Mention-gating for non-DM channels.
|
||||
# Config (env vars):
|
||||
# MATTERMOST_REQUIRE_MENTION: Require @mention in channels (default: true)
|
||||
# MATTERMOST_FREE_RESPONSE_CHANNELS: Channel IDs where bot responds without mention
|
||||
if channel_type_raw != "D":
|
||||
require_mention = os.getenv(
|
||||
"MATTERMOST_REQUIRE_MENTION", "true"
|
||||
).lower() not in ("false", "0", "no")
|
||||
|
||||
free_channels_raw = os.getenv("MATTERMOST_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
is_free_channel = channel_id in free_channels
|
||||
|
||||
mention_patterns = [
|
||||
f"@{self._bot_username}",
|
||||
f"@{self._bot_user_id}",
|
||||
@@ -614,13 +624,21 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
pattern.lower() in message_text.lower()
|
||||
for pattern in mention_patterns
|
||||
)
|
||||
if not has_mention:
|
||||
|
||||
if require_mention and not is_free_channel and not has_mention:
|
||||
logger.debug(
|
||||
"Mattermost: skipping non-DM message without @mention (channel=%s)",
|
||||
channel_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Strip @mention from the message text so the agent sees clean input.
|
||||
if has_mention:
|
||||
for pattern in mention_patterns:
|
||||
message_text = re.sub(
|
||||
re.escape(pattern), "", message_text, flags=re.IGNORECASE
|
||||
).strip()
|
||||
|
||||
# Resolve sender info.
|
||||
sender_id = post.get("user_id", "")
|
||||
sender_name = data.get("sender_name", "").lstrip("@") or sender_id
|
||||
|
||||
@@ -22,7 +22,7 @@ import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -184,6 +184,8 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self._recent_sent_timestamps: set = set()
|
||||
self._max_recent_timestamps = 50
|
||||
|
||||
self._phone_lock_identity: Optional[str] = None
|
||||
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
@@ -198,6 +200,29 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
logger.error("Signal: SIGNAL_HTTP_URL and SIGNAL_ACCOUNT are required")
|
||||
return False
|
||||
|
||||
# Acquire scoped lock to prevent duplicate Signal listeners for the same phone
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._phone_lock_identity = self.account
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"signal-phone",
|
||||
self._phone_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Signal account"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Signal listener."
|
||||
)
|
||||
logger.error("Signal: %s", message)
|
||||
self._set_fatal_error("signal_phone_lock", message, retryable=False)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e)
|
||||
|
||||
self.client = httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
# Health check — verify signal-cli daemon is reachable
|
||||
@@ -245,6 +270,14 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
await self.client.aclose()
|
||||
self.client = None
|
||||
|
||||
if self._phone_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("signal-phone", self._phone_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True)
|
||||
self._phone_lock_identity = None
|
||||
|
||||
logger.info("Signal: disconnected")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -253,7 +286,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _sse_listener(self) -> None:
|
||||
"""Listen for SSE events from signal-cli daemon."""
|
||||
url = f"{self.http_url}/api/v1/events?account={self.account}"
|
||||
url = f"{self.http_url}/api/v1/events?account={quote(self.account, safe='')}"
|
||||
backoff = SSE_RETRY_DELAY_INITIAL
|
||||
|
||||
while self._running:
|
||||
@@ -521,7 +554,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
"""Fetch an attachment via JSON-RPC and cache it. Returns (path, ext)."""
|
||||
result = await self._rpc("getAttachment", {
|
||||
"account": self.account,
|
||||
"attachmentId": attachment_id,
|
||||
"id": attachment_id,
|
||||
})
|
||||
|
||||
if not result:
|
||||
|
||||
@@ -93,6 +93,17 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
try:
|
||||
# Acquire scoped lock to prevent duplicate app token usage
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._token_lock_identity = app_token
|
||||
acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'})
|
||||
if not acquired:
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error('slack_token_lock', message, retryable=False)
|
||||
return False
|
||||
|
||||
self._app = AsyncApp(token=bot_token)
|
||||
|
||||
# Get our own bot user ID for mention detection
|
||||
@@ -138,6 +149,16 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True)
|
||||
self._running = False
|
||||
|
||||
# Release the token lock (use stored identity, not re-read env)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('slack-app-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("[Slack] Disconnected")
|
||||
|
||||
async def send(
|
||||
|
||||
@@ -118,6 +118,17 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
app.router.add_get("/health", self._handle_health)
|
||||
app.router.add_post("/webhooks/{route_name}", self._handle_webhook)
|
||||
|
||||
# Port conflict detection — fail fast if port is already in use
|
||||
import socket as _socket
|
||||
try:
|
||||
with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as _s:
|
||||
_s.settimeout(1)
|
||||
_s.connect(('127.0.0.1', self._port))
|
||||
logger.error('[webhook] Port %d already in use. Set a different port in config.yaml: platforms.webhook.port', self._port)
|
||||
return False
|
||||
except (ConnectionRefusedError, OSError):
|
||||
pass # port is free
|
||||
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, self._host, self._port)
|
||||
|
||||
+146
-103
@@ -142,6 +142,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Optional[Path] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._session_lock_identity: Optional[str] = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
@@ -160,6 +162,29 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
|
||||
logger.info("[%s] Bridge found at %s", self.name, bridge_path)
|
||||
|
||||
# Acquire scoped lock to prevent duplicate sessions
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._session_lock_identity = str(self._session_path)
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"whatsapp-session",
|
||||
self._session_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this WhatsApp session"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second WhatsApp bridge."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("whatsapp_session_lock", message, retryable=False)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e)
|
||||
|
||||
# Auto-install npm dependencies if node_modules doesn't exist
|
||||
bridge_dir = bridge_path.parent
|
||||
if not (bridge_dir / "node_modules").exists():
|
||||
@@ -200,6 +225,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Using existing bridge (status: {bridge_status})")
|
||||
self._mark_connected()
|
||||
self._bridge_process = None # Not managed by us
|
||||
self._http_session = aiohttp.ClientSession()
|
||||
self._poll_task = asyncio.create_task(self._poll_messages())
|
||||
return True
|
||||
else:
|
||||
@@ -305,6 +331,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Bridge log: {self._bridge_log}")
|
||||
print(f"[{self.name}] If session expired, re-pair: hermes whatsapp")
|
||||
|
||||
# Create a persistent HTTP session for all bridge communication
|
||||
self._http_session = aiohttp.ClientSession()
|
||||
|
||||
# Start message polling task
|
||||
self._poll_task = asyncio.create_task(self._poll_messages())
|
||||
|
||||
@@ -313,6 +342,12 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
if self._session_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("whatsapp-session", self._session_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
@@ -370,10 +405,32 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
# Bridge was not started by us, don't kill it
|
||||
print(f"[{self.name}] Disconnecting (external bridge left running)")
|
||||
|
||||
|
||||
# Cancel the poll task explicitly
|
||||
if self._poll_task and not self._poll_task.done():
|
||||
self._poll_task.cancel()
|
||||
try:
|
||||
await self._poll_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
self._poll_task = None
|
||||
|
||||
# Close the persistent HTTP session
|
||||
if self._http_session and not self._http_session.closed:
|
||||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
if self._session_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("whatsapp-session", self._session_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True)
|
||||
|
||||
self._mark_disconnected()
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
self._session_lock_identity = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
@@ -384,7 +441,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> SendResult:
|
||||
"""Send a message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
if not self._running or not self._http_session:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
@@ -392,36 +449,29 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
payload = {
|
||||
"chatId": chat_id,
|
||||
"message": content,
|
||||
}
|
||||
if reply_to:
|
||||
payload["replyTo"] = reply_to
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
payload = {
|
||||
"chatId": chat_id,
|
||||
"message": content,
|
||||
}
|
||||
if reply_to:
|
||||
payload["replyTo"] = reply_to
|
||||
|
||||
async with session.post(
|
||||
f"http://127.0.0.1:{self._bridge_port}/send",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
except ImportError:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error="aiohttp not installed. Run: pip install aiohttp"
|
||||
)
|
||||
async with self._http_session.post(
|
||||
f"http://127.0.0.1:{self._bridge_port}/send",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -432,28 +482,27 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
content: str,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
if not self._running or not self._http_session:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://127.0.0.1:{self._bridge_port}/edit",
|
||||
json={
|
||||
"chatId": chat_id,
|
||||
"messageId": message_id,
|
||||
"message": content,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=15)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
async with self._http_session.post(
|
||||
f"http://127.0.0.1:{self._bridge_port}/edit",
|
||||
json={
|
||||
"chatId": chat_id,
|
||||
"messageId": message_id,
|
||||
"message": content,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=15)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -466,7 +515,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
file_name: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send any media file via bridge /send-media endpoint."""
|
||||
if not self._running:
|
||||
if not self._running or not self._http_session:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
@@ -487,22 +536,21 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
if file_name:
|
||||
payload["fileName"] = file_name
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://127.0.0.1:{self._bridge_port}/send-media",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=120),
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data,
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
async with self._http_session.post(
|
||||
f"http://127.0.0.1:{self._bridge_port}/send-media",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=120),
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data,
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
@@ -560,45 +608,43 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Send typing indicator via bridge."""
|
||||
if not self._running:
|
||||
if not self._running or not self._http_session:
|
||||
return
|
||||
if await self._check_managed_bridge_exit():
|
||||
return
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(
|
||||
f"http://127.0.0.1:{self._bridge_port}/typing",
|
||||
json={"chatId": chat_id},
|
||||
timeout=aiohttp.ClientTimeout(total=5)
|
||||
)
|
||||
|
||||
await self._http_session.post(
|
||||
f"http://127.0.0.1:{self._bridge_port}/typing",
|
||||
json={"chatId": chat_id},
|
||||
timeout=aiohttp.ClientTimeout(total=5)
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a WhatsApp chat."""
|
||||
if not self._running:
|
||||
if not self._running or not self._http_session:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
if await self._check_managed_bridge_exit():
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://127.0.0.1:{self._bridge_port}/chat/{chat_id}",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
|
||||
async with self._http_session.get(
|
||||
f"http://127.0.0.1:{self._bridge_port}/chat/{chat_id}",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug("Could not get WhatsApp chat info for %s: %s", chat_id, e)
|
||||
|
||||
@@ -606,29 +652,26 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _poll_messages(self) -> None:
|
||||
"""Poll the bridge for incoming messages."""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, message polling disabled")
|
||||
return
|
||||
|
||||
import aiohttp
|
||||
|
||||
while self._running:
|
||||
if not self._http_session:
|
||||
break
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
print(f"[{self.name}] {bridge_exit}")
|
||||
break
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://127.0.0.1:{self._bridge_port}/messages",
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
async with self._http_session.get(
|
||||
f"http://127.0.0.1:{self._bridge_port}/messages",
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
|
||||
+20
-16
@@ -77,6 +77,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
# Resolve Hermes home directory (respects HERMES_HOME override)
|
||||
from hermes_constants import get_hermes_home
|
||||
from utils import atomic_yaml_write
|
||||
_hermes_home = get_hermes_home()
|
||||
|
||||
# Load environment variables from ~/.hermes/.env first.
|
||||
@@ -918,11 +919,12 @@ class GatewayRunner:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _load_fallback_model() -> dict | None:
|
||||
"""Load fallback model config from config.yaml.
|
||||
def _load_fallback_model() -> list | dict | None:
|
||||
"""Load fallback provider chain from config.yaml.
|
||||
|
||||
Returns a dict with 'provider' and 'model' keys, or None if
|
||||
not configured / both fields empty.
|
||||
Returns a list of provider dicts (``fallback_providers``), a single
|
||||
dict (legacy ``fallback_model``), or None if not configured.
|
||||
AIAgent.__init__ normalizes both formats into a chain.
|
||||
"""
|
||||
try:
|
||||
import yaml as _y
|
||||
@@ -930,8 +932,8 @@ class GatewayRunner:
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path, encoding="utf-8") as _f:
|
||||
cfg = _y.safe_load(_f) or {}
|
||||
fb = cfg.get("fallback_model", {}) or {}
|
||||
if fb.get("provider") and fb.get("model"):
|
||||
fb = cfg.get("fallback_providers") or cfg.get("fallback_model") or None
|
||||
if fb:
|
||||
return fb
|
||||
except Exception:
|
||||
pass
|
||||
@@ -959,6 +961,13 @@ class GatewayRunner:
|
||||
"""
|
||||
logger.info("Starting Hermes Gateway...")
|
||||
logger.info("Session storage: %s", self.config.sessions_dir)
|
||||
try:
|
||||
from hermes_cli.profiles import get_active_profile_name
|
||||
_profile = get_active_profile_name()
|
||||
if _profile and _profile != "default":
|
||||
logger.info("Active profile: %s", _profile)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(gateway_state="starting", exit_reason=None)
|
||||
@@ -3088,8 +3097,7 @@ class GatewayRunner:
|
||||
if "agent" not in config or not isinstance(config.get("agent"), dict):
|
||||
config["agent"] = {}
|
||||
config["agent"]["system_prompt"] = ""
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
|
||||
atomic_yaml_write(config_path, config)
|
||||
except Exception as e:
|
||||
return f"⚠️ Failed to save personality change: {e}"
|
||||
self._ephemeral_system_prompt = ""
|
||||
@@ -3102,8 +3110,7 @@ class GatewayRunner:
|
||||
if "agent" not in config or not isinstance(config.get("agent"), dict):
|
||||
config["agent"] = {}
|
||||
config["agent"]["system_prompt"] = new_prompt
|
||||
with open(config_path, 'w', encoding="utf-8") as f:
|
||||
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
|
||||
atomic_yaml_write(config_path, config)
|
||||
except Exception as e:
|
||||
return f"⚠️ Failed to save personality change: {e}"
|
||||
|
||||
@@ -3193,8 +3200,7 @@ class GatewayRunner:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
user_config = yaml.safe_load(f) or {}
|
||||
user_config[env_key] = chat_id
|
||||
with open(config_path, 'w', encoding="utf-8") as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False)
|
||||
atomic_yaml_write(config_path, user_config)
|
||||
# Also set in the current environment so it takes effect immediately
|
||||
os.environ[env_key] = str(chat_id)
|
||||
except Exception as e:
|
||||
@@ -3862,8 +3868,7 @@ class GatewayRunner:
|
||||
current[k] = {}
|
||||
current = current[k]
|
||||
current[keys[-1]] = value
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||
atomic_yaml_write(config_path, user_config)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to save config key %s: %s", key_path, e)
|
||||
@@ -3971,8 +3976,7 @@ class GatewayRunner:
|
||||
if "display" not in user_config or not isinstance(user_config.get("display"), dict):
|
||||
user_config["display"] = {}
|
||||
user_config["display"]["tool_progress"] = new_mode
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||
atomic_yaml_write(config_path, user_config)
|
||||
return f"{descriptions[new_mode]}\n_(saved to config — takes effect on next message)_"
|
||||
except Exception as e:
|
||||
logger.warning("Failed to save tool_progress mode: %s", e)
|
||||
|
||||
+2
-1
@@ -2021,7 +2021,8 @@ def _login_openai_codex(args, pconfig: ProviderConfig) -> None:
|
||||
config_path = _update_config_for_provider("openai-codex", creds.get("base_url", DEFAULT_CODEX_BASE_URL))
|
||||
print()
|
||||
print("Login successful!")
|
||||
print(" Auth state: ~/.hermes/auth.json")
|
||||
from hermes_constants import display_hermes_home as _dhh
|
||||
print(f" Auth state: {_dhh()}/auth.json")
|
||||
print(f" Config updated: {config_path} (model.provider=openai-codex)")
|
||||
|
||||
|
||||
|
||||
+25
-2
@@ -258,7 +258,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
get_toolset_for_tool: Callable to map tool name -> toolset name.
|
||||
context_length: Model's context window size in tokens.
|
||||
"""
|
||||
from model_tools import check_tool_availability
|
||||
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
|
||||
if get_toolset_for_tool is None:
|
||||
from model_tools import get_toolset_for_tool
|
||||
|
||||
@@ -267,8 +267,18 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
|
||||
_, unavailable_toolsets = check_tool_availability(quiet=True)
|
||||
disabled_tools = set()
|
||||
# Tools whose toolset has a check_fn are lazy-initialized (e.g. honcho,
|
||||
# homeassistant) — they show as unavailable at banner time because the
|
||||
# check hasn't run yet, but they aren't misconfigured.
|
||||
lazy_tools = set()
|
||||
for item in unavailable_toolsets:
|
||||
disabled_tools.update(item.get("tools", []))
|
||||
toolset_name = item.get("name", "")
|
||||
ts_req = TOOLSET_REQUIREMENTS.get(toolset_name, {})
|
||||
tools_in_ts = item.get("tools", [])
|
||||
if ts_req.get("check_fn"):
|
||||
lazy_tools.update(tools_in_ts)
|
||||
else:
|
||||
disabled_tools.update(tools_in_ts)
|
||||
|
||||
layout_table = Table.grid(padding=(0, 2))
|
||||
layout_table.add_column("left", justify="center")
|
||||
@@ -328,6 +338,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
for name in sorted(tool_names):
|
||||
if name in disabled_tools:
|
||||
colored_names.append(f"[red]{name}[/]")
|
||||
elif name in lazy_tools:
|
||||
colored_names.append(f"[yellow]{name}[/]")
|
||||
else:
|
||||
colored_names.append(f"[{text}]{name}[/]")
|
||||
|
||||
@@ -347,6 +359,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
colored_names.append("[dim]...[/]")
|
||||
elif name in disabled_tools:
|
||||
colored_names.append(f"[red]{name}[/]")
|
||||
elif name in lazy_tools:
|
||||
colored_names.append(f"[yellow]{name}[/]")
|
||||
else:
|
||||
colored_names.append(f"[{text}]{name}[/]")
|
||||
tools_str = ", ".join(colored_names)
|
||||
@@ -403,6 +417,15 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
if mcp_connected:
|
||||
summary_parts.append(f"{mcp_connected} MCP servers")
|
||||
summary_parts.append("/help for commands")
|
||||
# Show active profile name when not 'default'
|
||||
try:
|
||||
from hermes_cli.profiles import get_active_profile_name
|
||||
_profile_name = get_active_profile_name()
|
||||
if _profile_name and _profile_name != "default":
|
||||
right_lines.append(f"[bold {accent}]Profile:[/] [{text}]{_profile_name}[/]")
|
||||
except Exception:
|
||||
pass # Never break the banner over a profiles.py bug
|
||||
|
||||
right_lines.append(f"[dim {dim}]{' · '.join(summary_parts)}[/]")
|
||||
|
||||
# Update check — use prefetched result if available
|
||||
|
||||
@@ -12,6 +12,7 @@ import getpass
|
||||
|
||||
from hermes_cli.banner import cprint, _DIM, _RST
|
||||
from hermes_cli.config import save_env_value_secure
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
|
||||
def clarify_callback(cli, question, choices):
|
||||
@@ -131,7 +132,8 @@ def prompt_for_secret(cli, var_name: str, prompt: str, metadata=None) -> dict:
|
||||
}
|
||||
|
||||
stored = save_env_value_secure(var_name, value)
|
||||
cprint(f"\n{_DIM} ✓ Stored secret in ~/.hermes/.env as {var_name}{_RST}")
|
||||
_dhh = display_hermes_home()
|
||||
cprint(f"\n{_DIM} ✓ Stored secret in {_dhh}/.env as {var_name}{_RST}")
|
||||
return {
|
||||
**stored,
|
||||
"skipped": False,
|
||||
@@ -183,7 +185,8 @@ def prompt_for_secret(cli, var_name: str, prompt: str, metadata=None) -> dict:
|
||||
}
|
||||
|
||||
stored = save_env_value_secure(var_name, value)
|
||||
cprint(f"\n{_DIM} ✓ Stored secret in ~/.hermes/.env as {var_name}{_RST}")
|
||||
_dhh = display_hermes_home()
|
||||
cprint(f"\n{_DIM} ✓ Stored secret in {_dhh}/.env as {var_name}{_RST}")
|
||||
return {
|
||||
**stored,
|
||||
"skipped": False,
|
||||
|
||||
@@ -135,6 +135,7 @@ def ensure_hermes_home():
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"model": "anthropic/claude-opus-4.6",
|
||||
"fallback_providers": [],
|
||||
"toolsets": ["hermes-cli"],
|
||||
"agent": {
|
||||
"max_turns": 90,
|
||||
@@ -366,6 +367,13 @@ DEFAULT_CONFIG = {
|
||||
# Never saved to sessions, logs, or trajectories.
|
||||
"prefill_messages_file": "",
|
||||
|
||||
# Skills — external skill directories for sharing skills across tools/agents.
|
||||
# Each path is expanded (~, ${VAR}) and resolved. Read-only — skill creation
|
||||
# always goes to ~/.hermes/skills/.
|
||||
"skills": {
|
||||
"external_dirs": [], # e.g. ["~/.agents/skills", "/shared/team-skills"]
|
||||
},
|
||||
|
||||
# Honcho AI-native memory -- reads ~/.honcho/config.json as single source of truth.
|
||||
# This section is only needed for hermes-specific overrides; everything else
|
||||
# (apiKey, workspace, peerName, sessions, enabled) comes from the global config.
|
||||
@@ -421,6 +429,12 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
},
|
||||
|
||||
"cron": {
|
||||
# Wrap delivered cron responses with a header (task name) and footer
|
||||
# ("The agent cannot see this message"). Set to false for clean output.
|
||||
"wrap_response": True,
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 10,
|
||||
}
|
||||
@@ -817,6 +831,20 @@ OPTIONAL_ENV_VARS = {
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATTERMOST_REQUIRE_MENTION": {
|
||||
"description": "Require @mention in Mattermost channels (default: true). Set to false to respond to all messages.",
|
||||
"prompt": "Require @mention in channels",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATTERMOST_FREE_RESPONSE_CHANNELS": {
|
||||
"description": "Comma-separated Mattermost channel IDs where bot responds without @mention",
|
||||
"prompt": "Free-response channel IDs (comma-separated)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"MATRIX_HOMESERVER": {
|
||||
"description": "Matrix homeserver URL (e.g. https://matrix.example.org)",
|
||||
"prompt": "Matrix homeserver URL",
|
||||
|
||||
+30
-4
@@ -4,7 +4,7 @@ Used by `hermes tools` and `hermes skills` for interactive checklists.
|
||||
Provides a curses multi-select with keyboard navigation, plus a
|
||||
text-based numbered fallback for terminals without curses support.
|
||||
"""
|
||||
from typing import List, Set
|
||||
from typing import Callable, List, Optional, Set
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
@@ -15,6 +15,7 @@ def curses_checklist(
|
||||
selected: Set[int],
|
||||
*,
|
||||
cancel_returns: Set[int] | None = None,
|
||||
status_fn: Optional[Callable[[Set[int]], str]] = None,
|
||||
) -> Set[int]:
|
||||
"""Curses multi-select checklist. Returns set of selected indices.
|
||||
|
||||
@@ -23,6 +24,9 @@ def curses_checklist(
|
||||
items: Display labels for each row.
|
||||
selected: Indices that start checked (pre-selected).
|
||||
cancel_returns: Returned on ESC/q. Defaults to the original *selected*.
|
||||
status_fn: Optional callback ``f(chosen_indices) -> str`` whose return
|
||||
value is rendered on the bottom row of the terminal. Use this for
|
||||
live aggregate info (e.g. estimated token counts).
|
||||
"""
|
||||
if cancel_returns is None:
|
||||
cancel_returns = set(selected)
|
||||
@@ -47,6 +51,9 @@ def curses_checklist(
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Reserve bottom row for status bar when status_fn provided
|
||||
footer_rows = 1 if status_fn else 0
|
||||
|
||||
# Header
|
||||
try:
|
||||
hattr = curses.A_BOLD
|
||||
@@ -62,7 +69,7 @@ def curses_checklist(
|
||||
pass
|
||||
|
||||
# Scrollable item list
|
||||
visible_rows = max_y - 3
|
||||
visible_rows = max_y - 3 - footer_rows
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible_rows:
|
||||
@@ -72,7 +79,7 @@ def curses_checklist(
|
||||
range(scroll_offset, min(len(items), scroll_offset + visible_rows))
|
||||
):
|
||||
y = draw_i + 3
|
||||
if y >= max_y - 1:
|
||||
if y >= max_y - 1 - footer_rows:
|
||||
break
|
||||
check = "✓" if i in chosen else " "
|
||||
arrow = "→" if i == cursor else " "
|
||||
@@ -87,6 +94,20 @@ def curses_checklist(
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
# Status bar (bottom row, right-aligned)
|
||||
if status_fn:
|
||||
try:
|
||||
status_text = status_fn(chosen)
|
||||
if status_text:
|
||||
# Right-align on the bottom row
|
||||
sx = max(0, max_x - len(status_text) - 1)
|
||||
sattr = curses.A_DIM
|
||||
if curses.has_colors():
|
||||
sattr |= curses.color_pair(3)
|
||||
stdscr.addnstr(max_y - 1, sx, status_text, max_x - sx - 1, sattr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
@@ -107,7 +128,7 @@ def curses_checklist(
|
||||
return result_holder[0] if result_holder[0] is not None else cancel_returns
|
||||
|
||||
except Exception:
|
||||
return _numbered_fallback(title, items, selected, cancel_returns)
|
||||
return _numbered_fallback(title, items, selected, cancel_returns, status_fn)
|
||||
|
||||
|
||||
def _numbered_fallback(
|
||||
@@ -115,6 +136,7 @@ def _numbered_fallback(
|
||||
items: List[str],
|
||||
selected: Set[int],
|
||||
cancel_returns: Set[int],
|
||||
status_fn: Optional[Callable[[Set[int]], str]] = None,
|
||||
) -> Set[int]:
|
||||
"""Text-based toggle fallback for terminals without curses."""
|
||||
chosen = set(selected)
|
||||
@@ -125,6 +147,10 @@ def _numbered_fallback(
|
||||
for i, label in enumerate(items):
|
||||
marker = color("[✓]", Colors.GREEN) if i in chosen else "[ ]"
|
||||
print(f" {marker} {i + 1:>2}. {label}")
|
||||
if status_fn:
|
||||
status_text = status_fn(chosen)
|
||||
if status_text:
|
||||
print(color(f"\n {status_text}", Colors.DIM))
|
||||
print()
|
||||
try:
|
||||
val = input(color(" Toggle # (or Enter to confirm): ", Colors.DIM)).strip()
|
||||
|
||||
+73
-24
@@ -10,9 +10,11 @@ import subprocess
|
||||
import shutil
|
||||
|
||||
from hermes_cli.config import get_project_root, get_hermes_home, get_env_path
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
PROJECT_ROOT = get_project_root()
|
||||
HERMES_HOME = get_hermes_home()
|
||||
_DHH = display_hermes_home() # user-facing display path (e.g. ~/.hermes or ~/.hermes/profiles/coder)
|
||||
|
||||
# Load environment variables from ~/.hermes/.env so API key checks work
|
||||
from dotenv import load_dotenv
|
||||
@@ -209,14 +211,14 @@ def run_doctor(args):
|
||||
# Check ~/.hermes/.env (primary location for user config)
|
||||
env_path = HERMES_HOME / '.env'
|
||||
if env_path.exists():
|
||||
check_ok("~/.hermes/.env file exists")
|
||||
check_ok(f"{_DHH}/.env file exists")
|
||||
|
||||
# Check for common issues
|
||||
content = env_path.read_text()
|
||||
if _has_provider_env_config(content):
|
||||
check_ok("API key or custom endpoint configured")
|
||||
else:
|
||||
check_warn("No API key found in ~/.hermes/.env")
|
||||
check_warn(f"No API key found in {_DHH}/.env")
|
||||
issues.append("Run 'hermes setup' to configure API keys")
|
||||
else:
|
||||
# Also check project root as fallback
|
||||
@@ -224,11 +226,11 @@ def run_doctor(args):
|
||||
if fallback_env.exists():
|
||||
check_ok(".env file exists (in project directory)")
|
||||
else:
|
||||
check_fail("~/.hermes/.env file missing")
|
||||
check_fail(f"{_DHH}/.env file missing")
|
||||
if should_fix:
|
||||
env_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
env_path.touch()
|
||||
check_ok("Created empty ~/.hermes/.env")
|
||||
check_ok(f"Created empty {_DHH}/.env")
|
||||
check_info("Run 'hermes setup' to configure API keys")
|
||||
fixed_count += 1
|
||||
else:
|
||||
@@ -238,7 +240,7 @@ def run_doctor(args):
|
||||
# Check ~/.hermes/config.yaml (primary) or project cli-config.yaml (fallback)
|
||||
config_path = HERMES_HOME / 'config.yaml'
|
||||
if config_path.exists():
|
||||
check_ok("~/.hermes/config.yaml exists")
|
||||
check_ok(f"{_DHH}/config.yaml exists")
|
||||
else:
|
||||
fallback_config = PROJECT_ROOT / 'cli-config.yaml'
|
||||
if fallback_config.exists():
|
||||
@@ -248,11 +250,11 @@ def run_doctor(args):
|
||||
if should_fix and example_config.exists():
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(str(example_config), str(config_path))
|
||||
check_ok("Created ~/.hermes/config.yaml from cli-config.yaml.example")
|
||||
check_ok(f"Created {_DHH}/config.yaml from cli-config.yaml.example")
|
||||
fixed_count += 1
|
||||
elif should_fix:
|
||||
check_warn("config.yaml not found and no example to copy from")
|
||||
manual_issues.append("Create ~/.hermes/config.yaml manually")
|
||||
manual_issues.append(f"Create {_DHH}/config.yaml manually")
|
||||
else:
|
||||
check_warn("config.yaml not found", "(using defaults)")
|
||||
|
||||
@@ -294,28 +296,28 @@ def run_doctor(args):
|
||||
|
||||
hermes_home = HERMES_HOME
|
||||
if hermes_home.exists():
|
||||
check_ok("~/.hermes directory exists")
|
||||
check_ok(f"{_DHH} directory exists")
|
||||
else:
|
||||
if should_fix:
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
check_ok("Created ~/.hermes directory")
|
||||
check_ok(f"Created {_DHH} directory")
|
||||
fixed_count += 1
|
||||
else:
|
||||
check_warn("~/.hermes not found", "(will be created on first use)")
|
||||
check_warn(f"{_DHH} not found", "(will be created on first use)")
|
||||
|
||||
# Check expected subdirectories
|
||||
expected_subdirs = ["cron", "sessions", "logs", "skills", "memories"]
|
||||
for subdir_name in expected_subdirs:
|
||||
subdir_path = hermes_home / subdir_name
|
||||
if subdir_path.exists():
|
||||
check_ok(f"~/.hermes/{subdir_name}/ exists")
|
||||
check_ok(f"{_DHH}/{subdir_name}/ exists")
|
||||
else:
|
||||
if should_fix:
|
||||
subdir_path.mkdir(parents=True, exist_ok=True)
|
||||
check_ok(f"Created ~/.hermes/{subdir_name}/")
|
||||
check_ok(f"Created {_DHH}/{subdir_name}/")
|
||||
fixed_count += 1
|
||||
else:
|
||||
check_warn(f"~/.hermes/{subdir_name}/ not found", "(will be created on first use)")
|
||||
check_warn(f"{_DHH}/{subdir_name}/ not found", "(will be created on first use)")
|
||||
|
||||
# Check for SOUL.md persona file
|
||||
soul_path = hermes_home / "SOUL.md"
|
||||
@@ -324,11 +326,11 @@ def run_doctor(args):
|
||||
# Check if it's just the template comments (no real content)
|
||||
lines = [l for l in content.splitlines() if l.strip() and not l.strip().startswith(("<!--", "-->", "#"))]
|
||||
if lines:
|
||||
check_ok("~/.hermes/SOUL.md exists (persona configured)")
|
||||
check_ok(f"{_DHH}/SOUL.md exists (persona configured)")
|
||||
else:
|
||||
check_info("~/.hermes/SOUL.md exists but is empty — edit it to customize personality")
|
||||
check_info(f"{_DHH}/SOUL.md exists but is empty — edit it to customize personality")
|
||||
else:
|
||||
check_warn("~/.hermes/SOUL.md not found", "(create it to give Hermes a custom personality)")
|
||||
check_warn(f"{_DHH}/SOUL.md not found", "(create it to give Hermes a custom personality)")
|
||||
if should_fix:
|
||||
soul_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
soul_path.write_text(
|
||||
@@ -337,13 +339,13 @@ def run_doctor(args):
|
||||
"You are Hermes, a helpful AI assistant.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
check_ok("Created ~/.hermes/SOUL.md with basic template")
|
||||
check_ok(f"Created {_DHH}/SOUL.md with basic template")
|
||||
fixed_count += 1
|
||||
|
||||
# Check memory directory
|
||||
memories_dir = hermes_home / "memories"
|
||||
if memories_dir.exists():
|
||||
check_ok("~/.hermes/memories/ directory exists")
|
||||
check_ok(f"{_DHH}/memories/ directory exists")
|
||||
memory_file = memories_dir / "MEMORY.md"
|
||||
user_file = memories_dir / "USER.md"
|
||||
if memory_file.exists():
|
||||
@@ -357,10 +359,10 @@ def run_doctor(args):
|
||||
else:
|
||||
check_info("USER.md not created yet (will be created when the agent first writes a memory)")
|
||||
else:
|
||||
check_warn("~/.hermes/memories/ not found", "(will be created on first use)")
|
||||
check_warn(f"{_DHH}/memories/ not found", "(will be created on first use)")
|
||||
if should_fix:
|
||||
memories_dir.mkdir(parents=True, exist_ok=True)
|
||||
check_ok("Created ~/.hermes/memories/")
|
||||
check_ok(f"Created {_DHH}/memories/")
|
||||
fixed_count += 1
|
||||
|
||||
# Check SQLite session store
|
||||
@@ -372,11 +374,11 @@ def run_doctor(args):
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
check_ok(f"~/.hermes/state.db exists ({count} sessions)")
|
||||
check_ok(f"{_DHH}/state.db exists ({count} sessions)")
|
||||
except Exception as e:
|
||||
check_warn(f"~/.hermes/state.db exists but has issues: {e}")
|
||||
check_warn(f"{_DHH}/state.db exists but has issues: {e}")
|
||||
else:
|
||||
check_info("~/.hermes/state.db not created yet (will be created on first session)")
|
||||
check_info(f"{_DHH}/state.db not created yet (will be created on first session)")
|
||||
|
||||
_check_gateway_service_linger(issues)
|
||||
|
||||
@@ -691,7 +693,7 @@ def run_doctor(args):
|
||||
if github_token:
|
||||
check_ok("GitHub token configured (authenticated API access)")
|
||||
else:
|
||||
check_warn("No GITHUB_TOKEN", "(60 req/hr rate limit — set in ~/.hermes/.env for better rates)")
|
||||
check_warn("No GITHUB_TOKEN", f"(60 req/hr rate limit — set in {_DHH}/.env for better rates)")
|
||||
|
||||
# =========================================================================
|
||||
# Honcho memory
|
||||
@@ -728,6 +730,53 @@ def run_doctor(args):
|
||||
except Exception as _e:
|
||||
check_warn("Honcho check failed", str(_e))
|
||||
|
||||
# =========================================================================
|
||||
# Profiles
|
||||
# =========================================================================
|
||||
try:
|
||||
from hermes_cli.profiles import list_profiles, _get_wrapper_dir, profile_exists
|
||||
import re as _re
|
||||
|
||||
named_profiles = [p for p in list_profiles() if not p.is_default]
|
||||
if named_profiles:
|
||||
print()
|
||||
print(color("◆ Profiles", Colors.CYAN, Colors.BOLD))
|
||||
check_ok(f"{len(named_profiles)} profile(s) found")
|
||||
wrapper_dir = _get_wrapper_dir()
|
||||
for p in named_profiles:
|
||||
parts = []
|
||||
if p.gateway_running:
|
||||
parts.append("gateway running")
|
||||
if p.model:
|
||||
parts.append(p.model[:30])
|
||||
if not (p.path / "config.yaml").exists():
|
||||
parts.append("⚠ missing config")
|
||||
if not (p.path / ".env").exists():
|
||||
parts.append("no .env")
|
||||
wrapper = wrapper_dir / p.name
|
||||
if not wrapper.exists():
|
||||
parts.append("no alias")
|
||||
status = ", ".join(parts) if parts else "configured"
|
||||
check_ok(f" {p.name}: {status}")
|
||||
|
||||
# Check for orphan wrappers
|
||||
if wrapper_dir.is_dir():
|
||||
for wrapper in wrapper_dir.iterdir():
|
||||
if not wrapper.is_file():
|
||||
continue
|
||||
try:
|
||||
content = wrapper.read_text()
|
||||
if "hermes -p" in content:
|
||||
_m = _re.search(r"hermes -p (\S+)", content)
|
||||
if _m and not profile_exists(_m.group(1)):
|
||||
check_warn(f"Orphan alias: {wrapper.name} → profile '{_m.group(1)}' no longer exists")
|
||||
except Exception:
|
||||
pass
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as _e:
|
||||
logger.debug("Profile health check failed: %s", _e)
|
||||
|
||||
# =========================================================================
|
||||
# Summary
|
||||
# =========================================================================
|
||||
|
||||
@@ -15,6 +15,8 @@ from pathlib import Path
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error
|
||||
# display_hermes_home is imported lazily at call sites to avoid ImportError
|
||||
# when hermes_constants is cached from a pre-update version during `hermes update`.
|
||||
from hermes_cli.setup import (
|
||||
print_header, print_info, print_success, print_warning, print_error,
|
||||
prompt, prompt_choice, prompt_yes_no,
|
||||
@@ -935,7 +937,8 @@ def launchd_install(force: bool = False):
|
||||
print()
|
||||
print("Next steps:")
|
||||
print(" hermes gateway status # Check status")
|
||||
print(" tail -f ~/.hermes/logs/gateway.log # View logs")
|
||||
from hermes_constants import display_hermes_home as _dhh
|
||||
print(f" tail -f {_dhh()}/logs/gateway.log # View logs")
|
||||
|
||||
def launchd_uninstall():
|
||||
plist_path = get_launchd_plist_path()
|
||||
|
||||
+492
-8
@@ -54,6 +54,71 @@ from typing import Optional
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Profile override — MUST happen before any hermes module import.
|
||||
#
|
||||
# Many modules cache HERMES_HOME at import time (module-level constants).
|
||||
# We intercept --profile/-p from sys.argv here and set the env var so that
|
||||
# every subsequent ``os.getenv("HERMES_HOME", ...)`` resolves correctly.
|
||||
# The flag is stripped from sys.argv so argparse never sees it.
|
||||
# Falls back to ~/.hermes/active_profile for sticky default.
|
||||
# ---------------------------------------------------------------------------
|
||||
def _apply_profile_override() -> None:
|
||||
"""Pre-parse --profile/-p and set HERMES_HOME before module imports."""
|
||||
argv = sys.argv[1:]
|
||||
profile_name = None
|
||||
consume = 0
|
||||
|
||||
# 1. Check for explicit -p / --profile flag
|
||||
for i, arg in enumerate(argv):
|
||||
if arg in ("--profile", "-p") and i + 1 < len(argv):
|
||||
profile_name = argv[i + 1]
|
||||
consume = 2
|
||||
break
|
||||
elif arg.startswith("--profile="):
|
||||
profile_name = arg.split("=", 1)[1]
|
||||
consume = 1
|
||||
break
|
||||
|
||||
# 2. If no flag, check ~/.hermes/active_profile
|
||||
if profile_name is None:
|
||||
try:
|
||||
active_path = Path.home() / ".hermes" / "active_profile"
|
||||
if active_path.exists():
|
||||
name = active_path.read_text().strip()
|
||||
if name and name != "default":
|
||||
profile_name = name
|
||||
consume = 0 # don't strip anything from argv
|
||||
except (UnicodeDecodeError, OSError):
|
||||
pass # corrupted file, skip
|
||||
|
||||
# 3. If we found a profile, resolve and set HERMES_HOME
|
||||
if profile_name is not None:
|
||||
try:
|
||||
from hermes_cli.profiles import resolve_profile_env
|
||||
hermes_home = resolve_profile_env(profile_name)
|
||||
except (ValueError, FileNotFoundError) as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except Exception as exc:
|
||||
# A bug in profiles.py must NEVER prevent hermes from starting
|
||||
print(f"Warning: profile override failed ({exc}), using default", file=sys.stderr)
|
||||
return
|
||||
os.environ["HERMES_HOME"] = hermes_home
|
||||
# Strip the flag from argv so argparse doesn't choke
|
||||
if consume > 0:
|
||||
for i, arg in enumerate(argv):
|
||||
if arg in ("--profile", "-p"):
|
||||
start = i + 1 # +1 because argv is sys.argv[1:]
|
||||
sys.argv = sys.argv[:start] + sys.argv[start + consume:]
|
||||
break
|
||||
elif arg.startswith("--profile="):
|
||||
start = i + 1
|
||||
sys.argv = sys.argv[:start] + sys.argv[start + 1:]
|
||||
break
|
||||
|
||||
_apply_profile_override()
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
from hermes_cli.config import get_hermes_home
|
||||
@@ -980,6 +1045,7 @@ def _model_flow_openrouter(config, current_model=""):
|
||||
cfg["model"] = model
|
||||
model["provider"] = "openrouter"
|
||||
model["base_url"] = OPENROUTER_BASE_URL
|
||||
model["api_mode"] = "chat_completions"
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
print(f"Default model set to: {selected} (via OpenRouter)")
|
||||
@@ -1203,6 +1269,7 @@ def _model_flow_custom(config):
|
||||
cfg["model"] = model
|
||||
model["provider"] = "custom"
|
||||
model["base_url"] = effective_url
|
||||
model["api_mode"] = "chat_completions"
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
|
||||
@@ -1984,6 +2051,7 @@ def _model_flow_kimi(config, current_model=""):
|
||||
cfg["model"] = model
|
||||
model["provider"] = provider_id
|
||||
model["base_url"] = effective_base
|
||||
model["api_mode"] = "chat_completions"
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
|
||||
@@ -2090,6 +2158,7 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""):
|
||||
cfg["model"] = model
|
||||
model["provider"] = provider_id
|
||||
model["base_url"] = effective_base
|
||||
model["api_mode"] = "chat_completions"
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
|
||||
@@ -2121,7 +2190,8 @@ def _run_anthropic_oauth_flow(save_env_value):
|
||||
):
|
||||
use_anthropic_claude_code_credentials(save_fn=save_env_value)
|
||||
print(" ✓ Claude Code credentials linked.")
|
||||
print(" Hermes will use Claude's credential store directly instead of copying a setup-token into ~/.hermes/.env.")
|
||||
from hermes_constants import display_hermes_home as _dhh_fn
|
||||
print(f" Hermes will use Claude's credential store directly instead of copying a setup-token into {_dhh_fn()}/.env.")
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -2391,6 +2461,34 @@ def cmd_uninstall(args):
|
||||
run_uninstall(args)
|
||||
|
||||
|
||||
def _clear_bytecode_cache(root: Path) -> int:
|
||||
"""Remove all __pycache__ directories under *root*.
|
||||
|
||||
Stale .pyc files can cause ImportError after code updates when Python
|
||||
loads a cached bytecode file that references names that no longer exist
|
||||
(or don't yet exist) in the updated source. Clearing them forces Python
|
||||
to recompile from the .py source on next import.
|
||||
|
||||
Returns the number of directories removed.
|
||||
"""
|
||||
removed = 0
|
||||
for dirpath, dirnames, _ in os.walk(root):
|
||||
# Skip venv / node_modules / .git entirely
|
||||
dirnames[:] = [
|
||||
d for d in dirnames
|
||||
if d not in ("venv", ".venv", "node_modules", ".git", ".worktrees")
|
||||
]
|
||||
if os.path.basename(dirpath) == "__pycache__":
|
||||
try:
|
||||
import shutil as _shutil
|
||||
_shutil.rmtree(dirpath)
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
dirnames.clear() # nothing left to recurse into
|
||||
return removed
|
||||
|
||||
|
||||
def _update_via_zip(args):
|
||||
"""Update Hermes Agent by downloading a ZIP archive.
|
||||
|
||||
@@ -2432,7 +2530,7 @@ def _update_via_zip(args):
|
||||
break
|
||||
|
||||
# Copy updated files over existing installation, preserving venv/node_modules/.git
|
||||
preserve = {'venv', 'node_modules', '.git', '__pycache__', '.env'}
|
||||
preserve = {'venv', 'node_modules', '.git', '.env'}
|
||||
update_count = 0
|
||||
for item in os.listdir(extracted):
|
||||
if item in preserve:
|
||||
@@ -2455,6 +2553,11 @@ def _update_via_zip(args):
|
||||
except Exception as e:
|
||||
print(f"✗ ZIP update failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Clear stale bytecode after ZIP extraction
|
||||
removed = _clear_bytecode_cache(PROJECT_ROOT)
|
||||
if removed:
|
||||
print(f" ✓ Cleared {removed} stale __pycache__ director{'y' if removed == 1 else 'ies'}")
|
||||
|
||||
# Reinstall Python dependencies (try .[all] first for optional extras,
|
||||
# fall back to . if extras fail — mirrors the install script behavior)
|
||||
@@ -2853,6 +2956,13 @@ def cmd_update(args):
|
||||
)
|
||||
|
||||
_invalidate_update_cache()
|
||||
|
||||
# Clear stale .pyc bytecode cache — prevents ImportError on gateway
|
||||
# restart when updated source references names that didn't exist in
|
||||
# the old bytecode (e.g. get_hermes_home added to hermes_constants).
|
||||
removed = _clear_bytecode_cache(PROJECT_ROOT)
|
||||
if removed:
|
||||
print(f" ✓ Cleared {removed} stale __pycache__ director{'y' if removed == 1 else 'ies'}")
|
||||
|
||||
# Reinstall Python dependencies (try .[all] first for optional extras,
|
||||
# fall back to . if extras fail — mirrors the install script behavior)
|
||||
@@ -2901,6 +3011,17 @@ def cmd_update(args):
|
||||
print()
|
||||
print("✓ Code updated!")
|
||||
|
||||
# After git pull, source files on disk are newer than cached Python
|
||||
# modules in this process. Reload hermes_constants so that any lazy
|
||||
# import executed below (skills sync, gateway restart) sees new
|
||||
# attributes like display_hermes_home() added since the last release.
|
||||
try:
|
||||
import importlib
|
||||
import hermes_constants as _hc
|
||||
importlib.reload(_hc)
|
||||
except Exception:
|
||||
pass # non-fatal — worst case a lazy import fails gracefully
|
||||
|
||||
# Sync bundled skills (copies new, updates changed, respects user deletions)
|
||||
try:
|
||||
from tools.skills_sync import sync_skills
|
||||
@@ -2919,7 +3040,35 @@ def cmd_update(args):
|
||||
print(" ✓ Skills are up to date")
|
||||
except Exception as e:
|
||||
logger.debug("Skills sync during update failed: %s", e)
|
||||
|
||||
|
||||
# Sync bundled skills to all other profiles
|
||||
try:
|
||||
from hermes_cli.profiles import list_profiles, get_active_profile_name, seed_profile_skills
|
||||
active = get_active_profile_name()
|
||||
other_profiles = [p for p in list_profiles() if not p.is_default and p.name != active]
|
||||
if other_profiles:
|
||||
print()
|
||||
print("→ Syncing bundled skills to other profiles...")
|
||||
for p in other_profiles:
|
||||
try:
|
||||
r = seed_profile_skills(p.path, quiet=True)
|
||||
if r:
|
||||
copied = len(r.get("copied", []))
|
||||
updated = len(r.get("updated", []))
|
||||
modified = len(r.get("user_modified", []))
|
||||
parts = []
|
||||
if copied: parts.append(f"+{copied} new")
|
||||
if updated: parts.append(f"↑{updated} updated")
|
||||
if modified: parts.append(f"~{modified} user-modified")
|
||||
status = ", ".join(parts) if parts else "up to date"
|
||||
else:
|
||||
status = "sync failed"
|
||||
print(f" {p.name}: {status}")
|
||||
except Exception as pe:
|
||||
print(f" {p.name}: error ({pe})")
|
||||
except Exception:
|
||||
pass # profiles module not available or no profiles
|
||||
|
||||
# Check for config migrations
|
||||
print()
|
||||
print("→ Checking configuration for new options...")
|
||||
@@ -3117,6 +3266,7 @@ def _coalesce_session_name_args(argv: list) -> list:
|
||||
"chat", "model", "gateway", "setup", "whatsapp", "login", "logout",
|
||||
"status", "cron", "doctor", "config", "pairing", "skills", "tools",
|
||||
"mcp", "sessions", "insights", "version", "update", "uninstall",
|
||||
"profile",
|
||||
}
|
||||
_SESSION_FLAGS = {"-c", "--continue", "-r", "--resume"}
|
||||
|
||||
@@ -3140,6 +3290,253 @@ def _coalesce_session_name_args(argv: list) -> list:
|
||||
return result
|
||||
|
||||
|
||||
def cmd_profile(args):
|
||||
"""Profile management — create, delete, list, switch, alias."""
|
||||
from hermes_cli.profiles import (
|
||||
list_profiles, create_profile, delete_profile, seed_profile_skills,
|
||||
get_active_profile, set_active_profile, get_active_profile_name,
|
||||
check_alias_collision, create_wrapper_script, remove_wrapper_script,
|
||||
_is_wrapper_dir_in_path, _get_wrapper_dir,
|
||||
)
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
action = getattr(args, "profile_action", None)
|
||||
|
||||
if action is None:
|
||||
# Bare `hermes profile` — show current profile status
|
||||
profile_name = get_active_profile_name()
|
||||
dhh = display_hermes_home()
|
||||
print(f"\nActive profile: {profile_name}")
|
||||
print(f"Path: {dhh}")
|
||||
|
||||
profiles = list_profiles()
|
||||
for p in profiles:
|
||||
if p.name == profile_name or (profile_name == "default" and p.is_default):
|
||||
if p.model:
|
||||
print(f"Model: {p.model}" + (f" ({p.provider})" if p.provider else ""))
|
||||
print(f"Gateway: {'running' if p.gateway_running else 'stopped'}")
|
||||
print(f"Skills: {p.skill_count} installed")
|
||||
if p.alias_path:
|
||||
print(f"Alias: {p.name} → hermes -p {p.name}")
|
||||
break
|
||||
print()
|
||||
return
|
||||
|
||||
if action == "list":
|
||||
profiles = list_profiles()
|
||||
active = get_active_profile_name()
|
||||
|
||||
if not profiles:
|
||||
print("No profiles found.")
|
||||
return
|
||||
|
||||
# Header
|
||||
print(f"\n {'Profile':<16} {'Model':<28} {'Gateway':<12} {'Alias'}")
|
||||
print(f" {'─' * 15} {'─' * 27} {'─' * 11} {'─' * 12}")
|
||||
|
||||
for p in profiles:
|
||||
marker = " ◆" if (p.name == active or (active == "default" and p.is_default)) else " "
|
||||
name = p.name
|
||||
model = (p.model or "—")[:26]
|
||||
gw = "running" if p.gateway_running else "stopped"
|
||||
alias = p.name if p.alias_path else "—"
|
||||
if p.is_default:
|
||||
alias = "—"
|
||||
print(f"{marker}{name:<15} {model:<28} {gw:<12} {alias}")
|
||||
print()
|
||||
|
||||
elif action == "use":
|
||||
name = args.profile_name
|
||||
try:
|
||||
set_active_profile(name)
|
||||
if name == "default":
|
||||
print(f"Switched to: default (~/.hermes)")
|
||||
else:
|
||||
print(f"Switched to: {name}")
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
elif action == "create":
|
||||
name = args.profile_name
|
||||
clone = getattr(args, "clone", False)
|
||||
clone_all = getattr(args, "clone_all", False)
|
||||
no_alias = getattr(args, "no_alias", False)
|
||||
|
||||
try:
|
||||
clone_from = getattr(args, "clone_from", None)
|
||||
|
||||
profile_dir = create_profile(
|
||||
name=name,
|
||||
clone_from=clone_from,
|
||||
clone_all=clone_all,
|
||||
clone_config=clone,
|
||||
no_alias=no_alias,
|
||||
)
|
||||
print(f"\nProfile '{name}' created at {profile_dir}")
|
||||
|
||||
if clone or clone_all:
|
||||
source_label = getattr(args, "clone_from", None) or get_active_profile_name()
|
||||
if clone_all:
|
||||
print(f"Full copy from {source_label}.")
|
||||
else:
|
||||
print(f"Cloned config, .env, SOUL.md from {source_label}.")
|
||||
|
||||
# Seed bundled skills (skip if --clone-all already copied them)
|
||||
if not clone_all:
|
||||
result = seed_profile_skills(profile_dir)
|
||||
if result:
|
||||
copied = len(result.get("copied", []))
|
||||
print(f"{copied} bundled skills synced.")
|
||||
else:
|
||||
print("⚠ Skills could not be seeded. Run `{} update` to retry.".format(name))
|
||||
|
||||
# Create wrapper alias
|
||||
if not no_alias:
|
||||
collision = check_alias_collision(name)
|
||||
if collision:
|
||||
print(f"\n⚠ Cannot create alias '{name}' — {collision}")
|
||||
print(f" Choose a custom alias: hermes profile alias {name} --name <custom>")
|
||||
print(f" Or access via flag: hermes -p {name} chat")
|
||||
else:
|
||||
wrapper_path = create_wrapper_script(name)
|
||||
if wrapper_path:
|
||||
print(f"Wrapper created: {wrapper_path}")
|
||||
if not _is_wrapper_dir_in_path():
|
||||
print(f"\n⚠ {_get_wrapper_dir()} is not in your PATH.")
|
||||
print(f' Add to your shell config (~/.bashrc or ~/.zshrc):')
|
||||
print(f' export PATH="$HOME/.local/bin:$PATH"')
|
||||
|
||||
# Next steps
|
||||
print(f"\nNext steps:")
|
||||
print(f" {name} setup Configure API keys and model")
|
||||
print(f" {name} chat Start chatting")
|
||||
print(f" {name} gateway start Start the messaging gateway")
|
||||
if clone or clone_all:
|
||||
from hermes_constants import get_hermes_home
|
||||
profile_dir_display = f"~/.hermes/profiles/{name}"
|
||||
print(f"\n Edit {profile_dir_display}/.env for different API keys")
|
||||
print(f" Edit {profile_dir_display}/SOUL.md for different personality")
|
||||
print()
|
||||
|
||||
except (ValueError, FileExistsError, FileNotFoundError) as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
elif action == "delete":
|
||||
name = args.profile_name
|
||||
yes = getattr(args, "yes", False)
|
||||
try:
|
||||
delete_profile(name, yes=yes)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
elif action == "show":
|
||||
name = args.profile_name
|
||||
from hermes_cli.profiles import get_profile_dir, profile_exists, _read_config_model, _check_gateway_running, _count_skills
|
||||
if not profile_exists(name):
|
||||
print(f"Error: Profile '{name}' does not exist.")
|
||||
sys.exit(1)
|
||||
profile_dir = get_profile_dir(name)
|
||||
model, provider = _read_config_model(profile_dir)
|
||||
gw = _check_gateway_running(profile_dir)
|
||||
skills = _count_skills(profile_dir)
|
||||
wrapper = _get_wrapper_dir() / name
|
||||
|
||||
print(f"\nProfile: {name}")
|
||||
print(f"Path: {profile_dir}")
|
||||
if model:
|
||||
print(f"Model: {model}" + (f" ({provider})" if provider else ""))
|
||||
print(f"Gateway: {'running' if gw else 'stopped'}")
|
||||
print(f"Skills: {skills}")
|
||||
print(f".env: {'exists' if (profile_dir / '.env').exists() else 'not configured'}")
|
||||
print(f"SOUL.md: {'exists' if (profile_dir / 'SOUL.md').exists() else 'not configured'}")
|
||||
if wrapper.exists():
|
||||
print(f"Alias: {wrapper}")
|
||||
print()
|
||||
|
||||
elif action == "alias":
|
||||
name = args.profile_name
|
||||
remove = getattr(args, "remove", False)
|
||||
custom_name = getattr(args, "alias_name", None)
|
||||
|
||||
from hermes_cli.profiles import profile_exists
|
||||
if not profile_exists(name):
|
||||
print(f"Error: Profile '{name}' does not exist.")
|
||||
sys.exit(1)
|
||||
|
||||
alias_name = custom_name or name
|
||||
|
||||
if remove:
|
||||
if remove_wrapper_script(alias_name):
|
||||
print(f"✓ Removed alias '{alias_name}'")
|
||||
else:
|
||||
print(f"No alias '{alias_name}' found to remove.")
|
||||
else:
|
||||
collision = check_alias_collision(alias_name)
|
||||
if collision:
|
||||
print(f"Error: {collision}")
|
||||
sys.exit(1)
|
||||
wrapper_path = create_wrapper_script(alias_name)
|
||||
if wrapper_path:
|
||||
# If custom name, write the profile name into the wrapper
|
||||
if custom_name:
|
||||
wrapper_path.write_text(f'#!/bin/sh\nexec hermes -p {name} "$@"\n')
|
||||
print(f"✓ Alias created: {wrapper_path}")
|
||||
if not _is_wrapper_dir_in_path():
|
||||
print(f"⚠ {_get_wrapper_dir()} is not in your PATH.")
|
||||
|
||||
elif action == "rename":
|
||||
from hermes_cli.profiles import rename_profile
|
||||
try:
|
||||
new_dir = rename_profile(args.old_name, args.new_name)
|
||||
print(f"\nProfile renamed: {args.old_name} → {args.new_name}")
|
||||
print(f"Path: {new_dir}\n")
|
||||
except (ValueError, FileExistsError, FileNotFoundError) as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
elif action == "export":
|
||||
from hermes_cli.profiles import export_profile
|
||||
name = args.profile_name
|
||||
output = args.output or f"{name}.tar.gz"
|
||||
try:
|
||||
result_path = export_profile(name, output)
|
||||
print(f"✓ Exported '{name}' to {result_path}")
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
elif action == "import":
|
||||
from hermes_cli.profiles import import_profile
|
||||
try:
|
||||
profile_dir = import_profile(args.archive, name=getattr(args, "import_name", None))
|
||||
name = profile_dir.name
|
||||
print(f"✓ Imported profile '{name}' at {profile_dir}")
|
||||
|
||||
# Offer to create alias
|
||||
collision = check_alias_collision(name)
|
||||
if not collision:
|
||||
wrapper_path = create_wrapper_script(name)
|
||||
if wrapper_path:
|
||||
print(f" Wrapper created: {wrapper_path}")
|
||||
print()
|
||||
except (ValueError, FileExistsError, FileNotFoundError) as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_completion(args):
|
||||
"""Print shell completion script."""
|
||||
from hermes_cli.profiles import generate_bash_completion, generate_zsh_completion
|
||||
shell = getattr(args, "shell", "bash")
|
||||
if shell == "zsh":
|
||||
print(generate_zsh_completion())
|
||||
else:
|
||||
print(generate_bash_completion())
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for hermes CLI."""
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -3774,6 +4171,16 @@ For more help on a command:
|
||||
|
||||
plugins_subparsers.add_parser("list", aliases=["ls"], help="List installed plugins")
|
||||
|
||||
plugins_enable = plugins_subparsers.add_parser(
|
||||
"enable", help="Enable a disabled plugin"
|
||||
)
|
||||
plugins_enable.add_argument("name", help="Plugin name to enable")
|
||||
|
||||
plugins_disable = plugins_subparsers.add_parser(
|
||||
"disable", help="Disable a plugin without removing it"
|
||||
)
|
||||
plugins_disable.add_argument("name", help="Plugin name to disable")
|
||||
|
||||
def cmd_plugins(args):
|
||||
from hermes_cli.plugins_cmd import plugins_command
|
||||
plugins_command(args)
|
||||
@@ -3941,16 +4348,25 @@ For more help on a command:
|
||||
# =========================================================================
|
||||
mcp_parser = subparsers.add_parser(
|
||||
"mcp",
|
||||
help="Manage MCP server connections",
|
||||
help="Manage MCP servers and run Hermes as an MCP server",
|
||||
description=(
|
||||
"Add, remove, list, test, and configure MCP server connections.\n\n"
|
||||
"Manage MCP server connections and run Hermes as an MCP server.\n\n"
|
||||
"MCP servers provide additional tools via the Model Context Protocol.\n"
|
||||
"Use 'hermes mcp add' to connect to a new server with interactive\n"
|
||||
"tool discovery. Run 'hermes mcp' with no subcommand to list servers."
|
||||
"Use 'hermes mcp add' to connect to a new server, or\n"
|
||||
"'hermes mcp serve' to expose Hermes conversations over MCP."
|
||||
),
|
||||
)
|
||||
mcp_sub = mcp_parser.add_subparsers(dest="mcp_action")
|
||||
|
||||
mcp_serve_p = mcp_sub.add_parser(
|
||||
"serve",
|
||||
help="Run Hermes as an MCP server (expose conversations to other agents)",
|
||||
)
|
||||
mcp_serve_p.add_argument(
|
||||
"-v", "--verbose", action="store_true",
|
||||
help="Enable verbose logging on stderr",
|
||||
)
|
||||
|
||||
mcp_add_p = mcp_sub.add_parser("add", help="Add an MCP server (discovery-first install)")
|
||||
mcp_add_p.add_argument("name", help="Server name (used as config key)")
|
||||
mcp_add_p.add_argument("--url", help="HTTP/SSE endpoint URL")
|
||||
@@ -4327,7 +4743,75 @@ For more help on a command:
|
||||
sys.exit(1)
|
||||
|
||||
acp_parser.set_defaults(func=cmd_acp)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# profile command
|
||||
# =========================================================================
|
||||
profile_parser = subparsers.add_parser(
|
||||
"profile",
|
||||
help="Manage profiles — multiple isolated Hermes instances",
|
||||
)
|
||||
profile_subparsers = profile_parser.add_subparsers(dest="profile_action")
|
||||
|
||||
profile_list = profile_subparsers.add_parser("list", help="List all profiles")
|
||||
profile_use = profile_subparsers.add_parser("use", help="Set sticky default profile")
|
||||
profile_use.add_argument("profile_name", help="Profile name (or 'default')")
|
||||
|
||||
profile_create = profile_subparsers.add_parser("create", help="Create a new profile")
|
||||
profile_create.add_argument("profile_name", help="Profile name (lowercase, alphanumeric)")
|
||||
profile_create.add_argument("--clone", action="store_true",
|
||||
help="Copy config.yaml, .env, SOUL.md from active profile")
|
||||
profile_create.add_argument("--clone-all", action="store_true",
|
||||
help="Full copy of active profile (all state)")
|
||||
profile_create.add_argument("--clone-from", metavar="SOURCE",
|
||||
help="Source profile to clone from (default: active)")
|
||||
profile_create.add_argument("--no-alias", action="store_true",
|
||||
help="Skip wrapper script creation")
|
||||
|
||||
profile_delete = profile_subparsers.add_parser("delete", help="Delete a profile")
|
||||
profile_delete.add_argument("profile_name", help="Profile to delete")
|
||||
profile_delete.add_argument("-y", "--yes", action="store_true",
|
||||
help="Skip confirmation prompt")
|
||||
|
||||
profile_show = profile_subparsers.add_parser("show", help="Show profile details")
|
||||
profile_show.add_argument("profile_name", help="Profile to show")
|
||||
|
||||
profile_alias = profile_subparsers.add_parser("alias", help="Manage wrapper scripts")
|
||||
profile_alias.add_argument("profile_name", help="Profile name")
|
||||
profile_alias.add_argument("--remove", action="store_true",
|
||||
help="Remove the wrapper script")
|
||||
profile_alias.add_argument("--name", dest="alias_name", metavar="NAME",
|
||||
help="Custom alias name (default: profile name)")
|
||||
|
||||
profile_rename = profile_subparsers.add_parser("rename", help="Rename a profile")
|
||||
profile_rename.add_argument("old_name", help="Current profile name")
|
||||
profile_rename.add_argument("new_name", help="New profile name")
|
||||
|
||||
profile_export = profile_subparsers.add_parser("export", help="Export a profile to archive")
|
||||
profile_export.add_argument("profile_name", help="Profile to export")
|
||||
profile_export.add_argument("-o", "--output", default=None,
|
||||
help="Output file (default: <name>.tar.gz)")
|
||||
|
||||
profile_import = profile_subparsers.add_parser("import", help="Import a profile from archive")
|
||||
profile_import.add_argument("archive", help="Path to .tar.gz archive")
|
||||
profile_import.add_argument("--name", dest="import_name", metavar="NAME",
|
||||
help="Profile name (default: inferred from archive)")
|
||||
|
||||
profile_parser.set_defaults(func=cmd_profile)
|
||||
|
||||
# =========================================================================
|
||||
# completion command
|
||||
# =========================================================================
|
||||
completion_parser = subparsers.add_parser(
|
||||
"completion",
|
||||
help="Print shell completion script (bash or zsh)",
|
||||
)
|
||||
completion_parser.add_argument(
|
||||
"shell", nargs="?", default="bash", choices=["bash", "zsh"],
|
||||
help="Shell type (default: bash)",
|
||||
)
|
||||
completion_parser.set_defaults(func=cmd_completion)
|
||||
|
||||
# =========================================================================
|
||||
# Parse and execute
|
||||
# =========================================================================
|
||||
|
||||
@@ -24,6 +24,7 @@ from hermes_cli.config import (
|
||||
get_hermes_home, # noqa: F401 — used by test mocks
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -244,7 +245,7 @@ def cmd_mcp_add(args):
|
||||
api_key = _prompt("API key / Bearer token", password=True)
|
||||
if api_key:
|
||||
save_env_value(env_key, api_key)
|
||||
_success(f"Saved to ~/.hermes/.env as {env_key}")
|
||||
_success(f"Saved to {display_hermes_home()}/.env as {env_key}")
|
||||
|
||||
# Set header with env var interpolation
|
||||
if api_key or existing_key:
|
||||
@@ -332,7 +333,7 @@ def cmd_mcp_add(args):
|
||||
_save_mcp_server(name, server_config)
|
||||
|
||||
print()
|
||||
_success(f"Saved '{name}' to ~/.hermes/config.yaml ({tool_count}/{total} tools enabled)")
|
||||
_success(f"Saved '{name}' to {display_hermes_home()}/config.yaml ({tool_count}/{total} tools enabled)")
|
||||
_info("Start a new session to use these tools.")
|
||||
|
||||
|
||||
@@ -607,6 +608,11 @@ def mcp_command(args):
|
||||
"""Main dispatcher for ``hermes mcp`` subcommands."""
|
||||
action = getattr(args, "mcp_action", None)
|
||||
|
||||
if action == "serve":
|
||||
from mcp_serve import run_mcp_server
|
||||
run_mcp_server(verbose=getattr(args, "verbose", False))
|
||||
return
|
||||
|
||||
handlers = {
|
||||
"add": cmd_mcp_add,
|
||||
"remove": cmd_mcp_remove,
|
||||
@@ -625,6 +631,7 @@ def mcp_command(args):
|
||||
# No subcommand — show list
|
||||
cmd_mcp_list()
|
||||
print(color(" Commands:", Colors.CYAN))
|
||||
_info("hermes mcp serve Run as MCP server")
|
||||
_info("hermes mcp add <name> --url <endpoint> Add an MCP server")
|
||||
_info("hermes mcp add <name> --command <cmd> Add a stdio server")
|
||||
_info("hermes mcp remove <name> Remove a server")
|
||||
|
||||
@@ -35,6 +35,8 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("google/gemini-3.1-pro-preview", ""),
|
||||
("google/gemini-3.1-flash-lite-preview", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("stepfun/step-3.5-flash", ""),
|
||||
@@ -62,6 +64,8 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"openai/gpt-5.3-codex",
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-flash-preview",
|
||||
"google/gemini-3.1-pro-preview",
|
||||
"google/gemini-3.1-flash-lite-preview",
|
||||
"qwen/qwen3.5-plus-02-15",
|
||||
"qwen/qwen3.5-35b-a3b",
|
||||
"stepfun/step-3.5-flash",
|
||||
|
||||
+19
-1
@@ -68,6 +68,17 @@ def _env_enabled(name: str) -> bool:
|
||||
return os.getenv(name, "").strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _get_disabled_plugins() -> set:
|
||||
"""Read the disabled plugins list from config.yaml."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
disabled = config.get("plugins", {}).get("disabled", [])
|
||||
return set(disabled) if isinstance(disabled, list) else set()
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -199,8 +210,15 @@ class PluginManager:
|
||||
# 3. Pip / entry-point plugins
|
||||
manifests.extend(self._scan_entry_points())
|
||||
|
||||
# Load each manifest
|
||||
# Load each manifest (skip user-disabled plugins)
|
||||
disabled = _get_disabled_plugins()
|
||||
for manifest in manifests:
|
||||
if manifest.name in disabled:
|
||||
loaded = LoadedPlugin(manifest=manifest, enabled=False)
|
||||
loaded.error = "disabled via config"
|
||||
self._plugins[manifest.name] = loaded
|
||||
logger.debug("Skipping disabled plugin '%s'", manifest.name)
|
||||
continue
|
||||
self._load_plugin(manifest)
|
||||
|
||||
if manifests:
|
||||
|
||||
+153
-2
@@ -374,6 +374,73 @@ def cmd_remove(name: str) -> None:
|
||||
_display_removed(name, plugins_dir)
|
||||
|
||||
|
||||
def _get_disabled_set() -> set:
|
||||
"""Read the disabled plugins set from config.yaml."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
disabled = config.get("plugins", {}).get("disabled", [])
|
||||
return set(disabled) if isinstance(disabled, list) else set()
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
def _save_disabled_set(disabled: set) -> None:
|
||||
"""Write the disabled plugins list to config.yaml."""
|
||||
from hermes_cli.config import load_config, save_config
|
||||
config = load_config()
|
||||
if "plugins" not in config:
|
||||
config["plugins"] = {}
|
||||
config["plugins"]["disabled"] = sorted(disabled)
|
||||
save_config(config)
|
||||
|
||||
|
||||
def cmd_enable(name: str) -> None:
|
||||
"""Enable a previously disabled plugin."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
# Verify the plugin exists
|
||||
target = plugins_dir / name
|
||||
if not target.is_dir():
|
||||
console.print(f"[red]Plugin '{name}' is not installed.[/red]")
|
||||
sys.exit(1)
|
||||
|
||||
disabled = _get_disabled_set()
|
||||
if name not in disabled:
|
||||
console.print(f"[dim]Plugin '{name}' is already enabled.[/dim]")
|
||||
return
|
||||
|
||||
disabled.discard(name)
|
||||
_save_disabled_set(disabled)
|
||||
console.print(f"[green]✓[/green] Plugin [bold]{name}[/bold] enabled. Takes effect on next session.")
|
||||
|
||||
|
||||
def cmd_disable(name: str) -> None:
|
||||
"""Disable a plugin without removing it."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
# Verify the plugin exists
|
||||
target = plugins_dir / name
|
||||
if not target.is_dir():
|
||||
console.print(f"[red]Plugin '{name}' is not installed.[/red]")
|
||||
sys.exit(1)
|
||||
|
||||
disabled = _get_disabled_set()
|
||||
if name in disabled:
|
||||
console.print(f"[dim]Plugin '{name}' is already disabled.[/dim]")
|
||||
return
|
||||
|
||||
disabled.add(name)
|
||||
_save_disabled_set(disabled)
|
||||
console.print(f"[yellow]⊘[/yellow] Plugin [bold]{name}[/bold] disabled. Takes effect on next session.")
|
||||
|
||||
|
||||
def cmd_list() -> None:
|
||||
"""List installed plugins."""
|
||||
from rich.console import Console
|
||||
@@ -393,8 +460,11 @@ def cmd_list() -> None:
|
||||
console.print("[dim]Install with:[/dim] hermes plugins install owner/repo")
|
||||
return
|
||||
|
||||
disabled = _get_disabled_set()
|
||||
|
||||
table = Table(title="Installed Plugins", show_lines=False)
|
||||
table.add_column("Name", style="bold")
|
||||
table.add_column("Status")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("Description")
|
||||
table.add_column("Source", style="dim")
|
||||
@@ -420,11 +490,86 @@ def cmd_list() -> None:
|
||||
if (d / ".git").exists():
|
||||
source = "git"
|
||||
|
||||
table.add_row(name, str(version), description, source)
|
||||
is_disabled = name in disabled or d.name in disabled
|
||||
status = "[red]disabled[/red]" if is_disabled else "[green]enabled[/green]"
|
||||
table.add_row(name, status, str(version), description, source)
|
||||
|
||||
console.print()
|
||||
console.print(table)
|
||||
console.print()
|
||||
console.print("[dim]Interactive toggle:[/dim] hermes plugins")
|
||||
console.print("[dim]Enable/disable:[/dim] hermes plugins enable/disable <name>")
|
||||
|
||||
|
||||
def cmd_toggle() -> None:
|
||||
"""Interactive curses checklist to enable/disable installed plugins."""
|
||||
from rich.console import Console
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
dirs = sorted(d for d in plugins_dir.iterdir() if d.is_dir())
|
||||
if not dirs:
|
||||
console.print("[dim]No plugins installed.[/dim]")
|
||||
console.print("[dim]Install with:[/dim] hermes plugins install owner/repo")
|
||||
return
|
||||
|
||||
disabled = _get_disabled_set()
|
||||
|
||||
# Build items list: "name — description" for display
|
||||
names = []
|
||||
labels = []
|
||||
selected = set()
|
||||
|
||||
for i, d in enumerate(dirs):
|
||||
manifest_file = d / "plugin.yaml"
|
||||
name = d.name
|
||||
description = ""
|
||||
|
||||
if manifest_file.exists() and yaml:
|
||||
try:
|
||||
with open(manifest_file) as f:
|
||||
manifest = yaml.safe_load(f) or {}
|
||||
name = manifest.get("name", d.name)
|
||||
description = manifest.get("description", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
names.append(name)
|
||||
label = f"{name} — {description}" if description else name
|
||||
labels.append(label)
|
||||
|
||||
if name not in disabled and d.name not in disabled:
|
||||
selected.add(i)
|
||||
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
result = curses_checklist(
|
||||
title="Plugins — toggle enabled/disabled",
|
||||
items=labels,
|
||||
selected=selected,
|
||||
)
|
||||
|
||||
# Compute new disabled set from deselected items
|
||||
new_disabled = set()
|
||||
for i, name in enumerate(names):
|
||||
if i not in result:
|
||||
new_disabled.add(name)
|
||||
|
||||
if new_disabled != disabled:
|
||||
_save_disabled_set(new_disabled)
|
||||
enabled_count = len(names) - len(new_disabled)
|
||||
console.print(
|
||||
f"\n[green]✓[/green] {enabled_count} enabled, {len(new_disabled)} disabled. "
|
||||
f"Takes effect on next session."
|
||||
)
|
||||
else:
|
||||
console.print("\n[dim]No changes.[/dim]")
|
||||
|
||||
|
||||
def plugins_command(args) -> None:
|
||||
@@ -437,8 +582,14 @@ def plugins_command(args) -> None:
|
||||
cmd_update(args.name)
|
||||
elif action in ("remove", "rm", "uninstall"):
|
||||
cmd_remove(args.name)
|
||||
elif action in ("list", "ls") or action is None:
|
||||
elif action == "enable":
|
||||
cmd_enable(args.name)
|
||||
elif action == "disable":
|
||||
cmd_disable(args.name)
|
||||
elif action in ("list", "ls"):
|
||||
cmd_list()
|
||||
elif action is None:
|
||||
cmd_toggle()
|
||||
else:
|
||||
from rich.console import Console
|
||||
|
||||
|
||||
@@ -0,0 +1,906 @@
|
||||
"""
|
||||
Profile management for multiple isolated Hermes instances.
|
||||
|
||||
Each profile is a fully independent HERMES_HOME directory with its own
|
||||
config.yaml, .env, memory, sessions, skills, gateway, cron, and logs.
|
||||
Profiles live under ``~/.hermes/profiles/<name>/`` by default.
|
||||
|
||||
The "default" profile is ``~/.hermes`` itself — backward compatible,
|
||||
zero migration needed.
|
||||
|
||||
Usage::
|
||||
|
||||
hermes profile create coder # fresh profile + bundled skills
|
||||
hermes profile create coder --clone # also copy config, .env, SOUL.md
|
||||
hermes profile create coder --clone-all # full copy of source profile
|
||||
coder chat # use via wrapper alias
|
||||
hermes -p coder chat # or via flag
|
||||
hermes profile use coder # set as sticky default
|
||||
hermes profile delete coder # remove profile + alias + service
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
_PROFILE_ID_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
# Directories bootstrapped inside every new profile
|
||||
_PROFILE_DIRS = [
|
||||
"memories",
|
||||
"sessions",
|
||||
"skills",
|
||||
"skins",
|
||||
"logs",
|
||||
"plans",
|
||||
"workspace",
|
||||
"cron",
|
||||
]
|
||||
|
||||
# Files copied during --clone (if they exist in the source)
|
||||
_CLONE_CONFIG_FILES = [
|
||||
"config.yaml",
|
||||
".env",
|
||||
"SOUL.md",
|
||||
]
|
||||
|
||||
# Runtime files stripped after --clone-all (shouldn't carry over)
|
||||
_CLONE_ALL_STRIP = [
|
||||
"gateway.pid",
|
||||
"gateway_state.json",
|
||||
"processes.json",
|
||||
]
|
||||
|
||||
# Names that cannot be used as profile aliases
|
||||
_RESERVED_NAMES = frozenset({
|
||||
"hermes", "default", "test", "tmp", "root", "sudo",
|
||||
})
|
||||
|
||||
# Hermes subcommands that cannot be used as profile names/aliases
|
||||
_HERMES_SUBCOMMANDS = frozenset({
|
||||
"chat", "model", "gateway", "setup", "whatsapp", "login", "logout",
|
||||
"status", "cron", "doctor", "config", "pairing", "skills", "tools",
|
||||
"mcp", "sessions", "insights", "version", "update", "uninstall",
|
||||
"profile", "plugins", "honcho", "acp",
|
||||
})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Path helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_profiles_root() -> Path:
|
||||
"""Return the directory where named profiles are stored.
|
||||
|
||||
Always ``~/.hermes/profiles/`` — anchored to the user's home,
|
||||
NOT to the current HERMES_HOME (which may itself be a profile).
|
||||
This ensures ``coder profile list`` can see all profiles.
|
||||
"""
|
||||
return Path.home() / ".hermes" / "profiles"
|
||||
|
||||
|
||||
def _get_default_hermes_home() -> Path:
|
||||
"""Return the default (pre-profile) HERMES_HOME path."""
|
||||
return Path.home() / ".hermes"
|
||||
|
||||
|
||||
def _get_active_profile_path() -> Path:
|
||||
"""Return the path to the sticky active_profile file."""
|
||||
return _get_default_hermes_home() / "active_profile"
|
||||
|
||||
|
||||
def _get_wrapper_dir() -> Path:
|
||||
"""Return the directory for wrapper scripts."""
|
||||
return Path.home() / ".local" / "bin"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def validate_profile_name(name: str) -> None:
|
||||
"""Raise ``ValueError`` if *name* is not a valid profile identifier."""
|
||||
if name == "default":
|
||||
return # special alias for ~/.hermes
|
||||
if not _PROFILE_ID_RE.match(name):
|
||||
raise ValueError(
|
||||
f"Invalid profile name {name!r}. Must match "
|
||||
f"[a-z0-9][a-z0-9_-]{{0,63}}"
|
||||
)
|
||||
|
||||
|
||||
def get_profile_dir(name: str) -> Path:
|
||||
"""Resolve a profile name to its HERMES_HOME directory."""
|
||||
if name == "default":
|
||||
return _get_default_hermes_home()
|
||||
return _get_profiles_root() / name
|
||||
|
||||
|
||||
def profile_exists(name: str) -> bool:
|
||||
"""Check whether a profile directory exists."""
|
||||
if name == "default":
|
||||
return True
|
||||
return get_profile_dir(name).is_dir()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Alias / wrapper script management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def check_alias_collision(name: str) -> Optional[str]:
|
||||
"""Return a human-readable collision message, or None if the name is safe.
|
||||
|
||||
Checks: reserved names, hermes subcommands, existing binaries in PATH.
|
||||
"""
|
||||
if name in _RESERVED_NAMES:
|
||||
return f"'{name}' is a reserved name"
|
||||
if name in _HERMES_SUBCOMMANDS:
|
||||
return f"'{name}' conflicts with a hermes subcommand"
|
||||
|
||||
# Check existing commands in PATH
|
||||
wrapper_dir = _get_wrapper_dir()
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["which", name], capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
existing_path = result.stdout.strip()
|
||||
# Allow overwriting our own wrappers
|
||||
if existing_path == str(wrapper_dir / name):
|
||||
try:
|
||||
content = (wrapper_dir / name).read_text()
|
||||
if "hermes -p" in content:
|
||||
return None # it's our wrapper, safe to overwrite
|
||||
except Exception:
|
||||
pass
|
||||
return f"'{name}' conflicts with an existing command ({existing_path})"
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
|
||||
return None # safe
|
||||
|
||||
|
||||
def _is_wrapper_dir_in_path() -> bool:
|
||||
"""Check if ~/.local/bin is in PATH."""
|
||||
wrapper_dir = str(_get_wrapper_dir())
|
||||
return wrapper_dir in os.environ.get("PATH", "").split(os.pathsep)
|
||||
|
||||
|
||||
def create_wrapper_script(name: str) -> Optional[Path]:
|
||||
"""Create a shell wrapper script at ~/.local/bin/<name>.
|
||||
|
||||
Returns the path to the created wrapper, or None if creation failed.
|
||||
"""
|
||||
wrapper_dir = _get_wrapper_dir()
|
||||
try:
|
||||
wrapper_dir.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as e:
|
||||
print(f"⚠ Could not create {wrapper_dir}: {e}")
|
||||
return None
|
||||
|
||||
wrapper_path = wrapper_dir / name
|
||||
try:
|
||||
wrapper_path.write_text(f'#!/bin/sh\nexec hermes -p {name} "$@"\n')
|
||||
wrapper_path.chmod(wrapper_path.stat().st_mode | stat.S_IEXEC | stat.S_IXGRP | stat.S_IXOTH)
|
||||
return wrapper_path
|
||||
except OSError as e:
|
||||
print(f"⚠ Could not create wrapper at {wrapper_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def remove_wrapper_script(name: str) -> bool:
|
||||
"""Remove the wrapper script for a profile. Returns True if removed."""
|
||||
wrapper_path = _get_wrapper_dir() / name
|
||||
if wrapper_path.exists():
|
||||
try:
|
||||
# Verify it's our wrapper before removing
|
||||
content = wrapper_path.read_text()
|
||||
if "hermes -p" in content:
|
||||
wrapper_path.unlink()
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ProfileInfo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class ProfileInfo:
|
||||
"""Summary information about a profile."""
|
||||
name: str
|
||||
path: Path
|
||||
is_default: bool
|
||||
gateway_running: bool
|
||||
model: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
has_env: bool = False
|
||||
skill_count: int = 0
|
||||
alias_path: Optional[Path] = None
|
||||
|
||||
|
||||
def _read_config_model(profile_dir: Path) -> tuple:
|
||||
"""Read model/provider from a profile's config.yaml. Returns (model, provider)."""
|
||||
config_path = profile_dir / "config.yaml"
|
||||
if not config_path.exists():
|
||||
return None, None
|
||||
try:
|
||||
import yaml
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
model_cfg = cfg.get("model", {})
|
||||
if isinstance(model_cfg, str):
|
||||
return model_cfg, None
|
||||
if isinstance(model_cfg, dict):
|
||||
return model_cfg.get("model"), model_cfg.get("provider")
|
||||
return None, None
|
||||
except Exception:
|
||||
return None, None
|
||||
|
||||
|
||||
def _check_gateway_running(profile_dir: Path) -> bool:
|
||||
"""Check if a gateway is running for a given profile directory."""
|
||||
pid_file = profile_dir / "gateway.pid"
|
||||
if not pid_file.exists():
|
||||
return False
|
||||
try:
|
||||
raw = pid_file.read_text().strip()
|
||||
if not raw:
|
||||
return False
|
||||
data = json.loads(raw) if raw.startswith("{") else {"pid": int(raw)}
|
||||
pid = int(data["pid"])
|
||||
os.kill(pid, 0) # existence check
|
||||
return True
|
||||
except (json.JSONDecodeError, KeyError, ValueError, TypeError,
|
||||
ProcessLookupError, PermissionError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def _count_skills(profile_dir: Path) -> int:
|
||||
"""Count installed skills in a profile."""
|
||||
skills_dir = profile_dir / "skills"
|
||||
if not skills_dir.is_dir():
|
||||
return 0
|
||||
count = 0
|
||||
for md in skills_dir.rglob("SKILL.md"):
|
||||
if "/.hub/" not in str(md) and "/.git/" not in str(md):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CRUD operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def list_profiles() -> List[ProfileInfo]:
|
||||
"""Return info for all profiles, including the default."""
|
||||
profiles = []
|
||||
wrapper_dir = _get_wrapper_dir()
|
||||
|
||||
# Default profile
|
||||
default_home = _get_default_hermes_home()
|
||||
if default_home.is_dir():
|
||||
model, provider = _read_config_model(default_home)
|
||||
profiles.append(ProfileInfo(
|
||||
name="default",
|
||||
path=default_home,
|
||||
is_default=True,
|
||||
gateway_running=_check_gateway_running(default_home),
|
||||
model=model,
|
||||
provider=provider,
|
||||
has_env=(default_home / ".env").exists(),
|
||||
skill_count=_count_skills(default_home),
|
||||
))
|
||||
|
||||
# Named profiles
|
||||
profiles_root = _get_profiles_root()
|
||||
if profiles_root.is_dir():
|
||||
for entry in sorted(profiles_root.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
name = entry.name
|
||||
if not _PROFILE_ID_RE.match(name):
|
||||
continue
|
||||
model, provider = _read_config_model(entry)
|
||||
alias_path = wrapper_dir / name
|
||||
profiles.append(ProfileInfo(
|
||||
name=name,
|
||||
path=entry,
|
||||
is_default=False,
|
||||
gateway_running=_check_gateway_running(entry),
|
||||
model=model,
|
||||
provider=provider,
|
||||
has_env=(entry / ".env").exists(),
|
||||
skill_count=_count_skills(entry),
|
||||
alias_path=alias_path if alias_path.exists() else None,
|
||||
))
|
||||
|
||||
return profiles
|
||||
|
||||
|
||||
def create_profile(
|
||||
name: str,
|
||||
clone_from: Optional[str] = None,
|
||||
clone_all: bool = False,
|
||||
clone_config: bool = False,
|
||||
no_alias: bool = False,
|
||||
) -> Path:
|
||||
"""Create a new profile directory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name:
|
||||
Profile identifier (lowercase, alphanumeric, hyphens, underscores).
|
||||
clone_from:
|
||||
Source profile to clone from. If ``None`` and clone_config/clone_all
|
||||
is True, defaults to the currently active profile.
|
||||
clone_all:
|
||||
If True, do a full copytree of the source (all state).
|
||||
clone_config:
|
||||
If True, copy only config files (config.yaml, .env, SOUL.md).
|
||||
no_alias:
|
||||
If True, skip wrapper script creation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
The newly created profile directory.
|
||||
"""
|
||||
validate_profile_name(name)
|
||||
|
||||
if name == "default":
|
||||
raise ValueError(
|
||||
"Cannot create a profile named 'default' — it is the built-in profile (~/.hermes)."
|
||||
)
|
||||
|
||||
profile_dir = get_profile_dir(name)
|
||||
if profile_dir.exists():
|
||||
raise FileExistsError(f"Profile '{name}' already exists at {profile_dir}")
|
||||
|
||||
# Resolve clone source
|
||||
source_dir = None
|
||||
if clone_from is not None or clone_all or clone_config:
|
||||
if clone_from is None:
|
||||
# Default: clone from active profile
|
||||
from hermes_constants import get_hermes_home
|
||||
source_dir = get_hermes_home()
|
||||
else:
|
||||
validate_profile_name(clone_from)
|
||||
source_dir = get_profile_dir(clone_from)
|
||||
if not source_dir.is_dir():
|
||||
raise FileNotFoundError(
|
||||
f"Source profile '{clone_from or 'active'}' does not exist at {source_dir}"
|
||||
)
|
||||
|
||||
if clone_all and source_dir:
|
||||
# Full copy of source profile
|
||||
shutil.copytree(source_dir, profile_dir)
|
||||
# Strip runtime files
|
||||
for stale in _CLONE_ALL_STRIP:
|
||||
(profile_dir / stale).unlink(missing_ok=True)
|
||||
else:
|
||||
# Bootstrap directory structure
|
||||
profile_dir.mkdir(parents=True, exist_ok=True)
|
||||
for subdir in _PROFILE_DIRS:
|
||||
(profile_dir / subdir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Clone config files from source
|
||||
if source_dir is not None:
|
||||
for filename in _CLONE_CONFIG_FILES:
|
||||
src = source_dir / filename
|
||||
if src.exists():
|
||||
shutil.copy2(src, profile_dir / filename)
|
||||
|
||||
return profile_dir
|
||||
|
||||
|
||||
def seed_profile_skills(profile_dir: Path, quiet: bool = False) -> Optional[dict]:
|
||||
"""Seed bundled skills into a profile via subprocess.
|
||||
|
||||
Uses subprocess because sync_skills() caches HERMES_HOME at module level.
|
||||
Returns the sync result dict, or None on failure.
|
||||
"""
|
||||
project_root = Path(__file__).parent.parent.resolve()
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-c",
|
||||
"import json; from tools.skills_sync import sync_skills; "
|
||||
"r = sync_skills(quiet=True); print(json.dumps(r))"],
|
||||
env={**os.environ, "HERMES_HOME": str(profile_dir)},
|
||||
cwd=str(project_root),
|
||||
capture_output=True, text=True, timeout=60,
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return json.loads(result.stdout.strip())
|
||||
if not quiet:
|
||||
print(f"⚠ Skill seeding returned exit code {result.returncode}")
|
||||
if result.stderr.strip():
|
||||
print(f" {result.stderr.strip()[:200]}")
|
||||
return None
|
||||
except subprocess.TimeoutExpired:
|
||||
if not quiet:
|
||||
print("⚠ Skill seeding timed out (60s)")
|
||||
return None
|
||||
except Exception as e:
|
||||
if not quiet:
|
||||
print(f"⚠ Skill seeding failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def delete_profile(name: str, yes: bool = False) -> Path:
|
||||
"""Delete a profile, its wrapper script, and its gateway service.
|
||||
|
||||
Stops the gateway if running. Disables systemd/launchd service first
|
||||
to prevent auto-restart.
|
||||
|
||||
Returns the path that was removed.
|
||||
"""
|
||||
validate_profile_name(name)
|
||||
|
||||
if name == "default":
|
||||
raise ValueError(
|
||||
"Cannot delete the default profile (~/.hermes).\n"
|
||||
"To remove everything, use: hermes uninstall"
|
||||
)
|
||||
|
||||
profile_dir = get_profile_dir(name)
|
||||
if not profile_dir.is_dir():
|
||||
raise FileNotFoundError(f"Profile '{name}' does not exist.")
|
||||
|
||||
# Show what will be deleted
|
||||
model, provider = _read_config_model(profile_dir)
|
||||
gw_running = _check_gateway_running(profile_dir)
|
||||
skill_count = _count_skills(profile_dir)
|
||||
|
||||
print(f"\nProfile: {name}")
|
||||
print(f"Path: {profile_dir}")
|
||||
if model:
|
||||
print(f"Model: {model}" + (f" ({provider})" if provider else ""))
|
||||
if skill_count:
|
||||
print(f"Skills: {skill_count}")
|
||||
|
||||
items = [
|
||||
"All config, API keys, memories, sessions, skills, cron jobs",
|
||||
]
|
||||
|
||||
# Check for service
|
||||
from hermes_cli.gateway import _profile_suffix, get_service_name
|
||||
wrapper_path = _get_wrapper_dir() / name
|
||||
has_wrapper = wrapper_path.exists()
|
||||
if has_wrapper:
|
||||
items.append(f"Command alias ({wrapper_path})")
|
||||
|
||||
print(f"\nThis will permanently delete:")
|
||||
for item in items:
|
||||
print(f" • {item}")
|
||||
if gw_running:
|
||||
print(f" ⚠ Gateway is running — it will be stopped.")
|
||||
|
||||
# Confirmation
|
||||
if not yes:
|
||||
print()
|
||||
try:
|
||||
confirm = input(f"Type '{name}' to confirm: ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nCancelled.")
|
||||
return profile_dir
|
||||
if confirm != name:
|
||||
print("Cancelled.")
|
||||
return profile_dir
|
||||
|
||||
# 1. Disable service (prevents auto-restart)
|
||||
_cleanup_gateway_service(name, profile_dir)
|
||||
|
||||
# 2. Stop running gateway
|
||||
if gw_running:
|
||||
_stop_gateway_process(profile_dir)
|
||||
|
||||
# 3. Remove wrapper script
|
||||
if has_wrapper:
|
||||
if remove_wrapper_script(name):
|
||||
print(f"✓ Removed {wrapper_path}")
|
||||
|
||||
# 4. Remove profile directory
|
||||
try:
|
||||
shutil.rmtree(profile_dir)
|
||||
print(f"✓ Removed {profile_dir}")
|
||||
except Exception as e:
|
||||
print(f"⚠ Could not remove {profile_dir}: {e}")
|
||||
|
||||
# 5. Clear active_profile if it pointed to this profile
|
||||
try:
|
||||
active = get_active_profile()
|
||||
if active == name:
|
||||
set_active_profile("default")
|
||||
print("✓ Active profile reset to default")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print(f"\nProfile '{name}' deleted.")
|
||||
return profile_dir
|
||||
|
||||
|
||||
def _cleanup_gateway_service(name: str, profile_dir: Path) -> None:
|
||||
"""Disable and remove systemd/launchd service for a profile."""
|
||||
import platform as _platform
|
||||
|
||||
# Derive service name for this profile
|
||||
# Temporarily set HERMES_HOME so _profile_suffix resolves correctly
|
||||
old_home = os.environ.get("HERMES_HOME")
|
||||
try:
|
||||
os.environ["HERMES_HOME"] = str(profile_dir)
|
||||
from hermes_cli.gateway import get_service_name, get_launchd_plist_path
|
||||
|
||||
if _platform.system() == "Linux":
|
||||
svc_name = get_service_name()
|
||||
svc_file = Path.home() / ".config" / "systemd" / "user" / f"{svc_name}.service"
|
||||
if svc_file.exists():
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "disable", svc_name],
|
||||
capture_output=True, check=False, timeout=10,
|
||||
)
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "stop", svc_name],
|
||||
capture_output=True, check=False, timeout=10,
|
||||
)
|
||||
svc_file.unlink(missing_ok=True)
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
capture_output=True, check=False, timeout=10,
|
||||
)
|
||||
print(f"✓ Service {svc_name} removed")
|
||||
|
||||
elif _platform.system() == "Darwin":
|
||||
plist_path = get_launchd_plist_path()
|
||||
if plist_path.exists():
|
||||
subprocess.run(
|
||||
["launchctl", "unload", str(plist_path)],
|
||||
capture_output=True, check=False, timeout=10,
|
||||
)
|
||||
plist_path.unlink(missing_ok=True)
|
||||
print(f"✓ Launchd service removed")
|
||||
except Exception as e:
|
||||
print(f"⚠ Service cleanup: {e}")
|
||||
finally:
|
||||
if old_home is not None:
|
||||
os.environ["HERMES_HOME"] = old_home
|
||||
elif "HERMES_HOME" in os.environ:
|
||||
del os.environ["HERMES_HOME"]
|
||||
|
||||
|
||||
def _stop_gateway_process(profile_dir: Path) -> None:
|
||||
"""Stop a running gateway process via its PID file."""
|
||||
import signal as _signal
|
||||
import time as _time
|
||||
|
||||
pid_file = profile_dir / "gateway.pid"
|
||||
if not pid_file.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
raw = pid_file.read_text().strip()
|
||||
data = json.loads(raw) if raw.startswith("{") else {"pid": int(raw)}
|
||||
pid = int(data["pid"])
|
||||
os.kill(pid, _signal.SIGTERM)
|
||||
# Wait up to 10s for graceful shutdown
|
||||
for _ in range(20):
|
||||
_time.sleep(0.5)
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except ProcessLookupError:
|
||||
print(f"✓ Gateway stopped (PID {pid})")
|
||||
return
|
||||
# Force kill
|
||||
try:
|
||||
os.kill(pid, _signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
print(f"✓ Gateway force-stopped (PID {pid})")
|
||||
except (ProcessLookupError, PermissionError):
|
||||
print("✓ Gateway already stopped")
|
||||
except Exception as e:
|
||||
print(f"⚠ Could not stop gateway: {e}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Active profile (sticky default)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_active_profile() -> str:
|
||||
"""Read the sticky active profile name.
|
||||
|
||||
Returns ``"default"`` if no active_profile file exists or it's empty.
|
||||
"""
|
||||
path = _get_active_profile_path()
|
||||
try:
|
||||
name = path.read_text().strip()
|
||||
if not name:
|
||||
return "default"
|
||||
return name
|
||||
except (FileNotFoundError, UnicodeDecodeError, OSError):
|
||||
return "default"
|
||||
|
||||
|
||||
def set_active_profile(name: str) -> None:
|
||||
"""Set the sticky active profile.
|
||||
|
||||
Writes to ``~/.hermes/active_profile``. Use ``"default"`` to clear.
|
||||
"""
|
||||
validate_profile_name(name)
|
||||
if name != "default" and not profile_exists(name):
|
||||
raise FileNotFoundError(
|
||||
f"Profile '{name}' does not exist. "
|
||||
f"Create it with: hermes profile create {name}"
|
||||
)
|
||||
|
||||
path = _get_active_profile_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if name == "default":
|
||||
# Remove the file to indicate default
|
||||
path.unlink(missing_ok=True)
|
||||
else:
|
||||
# Atomic write
|
||||
tmp = path.with_suffix(".tmp")
|
||||
tmp.write_text(name + "\n")
|
||||
tmp.replace(path)
|
||||
|
||||
|
||||
def get_active_profile_name() -> str:
|
||||
"""Infer the current profile name from HERMES_HOME.
|
||||
|
||||
Returns ``"default"`` if HERMES_HOME is not set or points to ``~/.hermes``.
|
||||
Returns the profile name if HERMES_HOME points into ``~/.hermes/profiles/<name>``.
|
||||
Returns ``"custom"`` if HERMES_HOME is set to an unrecognized path.
|
||||
"""
|
||||
from hermes_constants import get_hermes_home
|
||||
hermes_home = get_hermes_home()
|
||||
resolved = hermes_home.resolve()
|
||||
|
||||
default_resolved = _get_default_hermes_home().resolve()
|
||||
if resolved == default_resolved:
|
||||
return "default"
|
||||
|
||||
profiles_root = _get_profiles_root().resolve()
|
||||
try:
|
||||
rel = resolved.relative_to(profiles_root)
|
||||
parts = rel.parts
|
||||
if len(parts) == 1 and _PROFILE_ID_RE.match(parts[0]):
|
||||
return parts[0]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return "custom"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export / Import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def export_profile(name: str, output_path: str) -> Path:
|
||||
"""Export a profile to a tar.gz archive.
|
||||
|
||||
Returns the output file path.
|
||||
"""
|
||||
validate_profile_name(name)
|
||||
profile_dir = get_profile_dir(name)
|
||||
if not profile_dir.is_dir():
|
||||
raise FileNotFoundError(f"Profile '{name}' does not exist.")
|
||||
|
||||
output = Path(output_path)
|
||||
# shutil.make_archive wants the base name without extension
|
||||
base = str(output).removesuffix(".tar.gz").removesuffix(".tgz")
|
||||
result = shutil.make_archive(base, "gztar", str(profile_dir.parent), name)
|
||||
return Path(result)
|
||||
|
||||
|
||||
def import_profile(archive_path: str, name: Optional[str] = None) -> Path:
|
||||
"""Import a profile from a tar.gz archive.
|
||||
|
||||
If *name* is not given, infers it from the archive's top-level directory.
|
||||
Returns the imported profile directory.
|
||||
"""
|
||||
import tarfile
|
||||
|
||||
archive = Path(archive_path)
|
||||
if not archive.exists():
|
||||
raise FileNotFoundError(f"Archive not found: {archive}")
|
||||
|
||||
# Peek at the archive to find the top-level directory name
|
||||
with tarfile.open(archive, "r:gz") as tf:
|
||||
top_dirs = {m.name.split("/")[0] for m in tf.getmembers() if "/" in m.name}
|
||||
if not top_dirs:
|
||||
top_dirs = {m.name for m in tf.getmembers() if m.isdir()}
|
||||
|
||||
inferred_name = name or (top_dirs.pop() if len(top_dirs) == 1 else None)
|
||||
if not inferred_name:
|
||||
raise ValueError(
|
||||
"Cannot determine profile name from archive. "
|
||||
"Specify it explicitly: hermes profile import <archive> --name <name>"
|
||||
)
|
||||
|
||||
validate_profile_name(inferred_name)
|
||||
profile_dir = get_profile_dir(inferred_name)
|
||||
if profile_dir.exists():
|
||||
raise FileExistsError(f"Profile '{inferred_name}' already exists at {profile_dir}")
|
||||
|
||||
profiles_root = _get_profiles_root()
|
||||
profiles_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
shutil.unpack_archive(str(archive), str(profiles_root))
|
||||
|
||||
# If the archive extracted under a different name, rename
|
||||
extracted = profiles_root / (top_dirs.pop() if top_dirs else inferred_name)
|
||||
if extracted != profile_dir and extracted.exists():
|
||||
extracted.rename(profile_dir)
|
||||
|
||||
return profile_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rename
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def rename_profile(old_name: str, new_name: str) -> Path:
|
||||
"""Rename a profile: directory, wrapper script, service, active_profile.
|
||||
|
||||
Returns the new profile directory.
|
||||
"""
|
||||
validate_profile_name(old_name)
|
||||
validate_profile_name(new_name)
|
||||
|
||||
if old_name == "default":
|
||||
raise ValueError("Cannot rename the default profile.")
|
||||
if new_name == "default":
|
||||
raise ValueError("Cannot rename to 'default' — it is reserved.")
|
||||
|
||||
old_dir = get_profile_dir(old_name)
|
||||
new_dir = get_profile_dir(new_name)
|
||||
|
||||
if not old_dir.is_dir():
|
||||
raise FileNotFoundError(f"Profile '{old_name}' does not exist.")
|
||||
if new_dir.exists():
|
||||
raise FileExistsError(f"Profile '{new_name}' already exists.")
|
||||
|
||||
# 1. Stop gateway if running
|
||||
if _check_gateway_running(old_dir):
|
||||
_cleanup_gateway_service(old_name, old_dir)
|
||||
_stop_gateway_process(old_dir)
|
||||
|
||||
# 2. Rename directory
|
||||
old_dir.rename(new_dir)
|
||||
print(f"✓ Renamed {old_dir.name} → {new_dir.name}")
|
||||
|
||||
# 3. Update wrapper script
|
||||
remove_wrapper_script(old_name)
|
||||
collision = check_alias_collision(new_name)
|
||||
if not collision:
|
||||
create_wrapper_script(new_name)
|
||||
print(f"✓ Alias updated: {new_name}")
|
||||
else:
|
||||
print(f"⚠ Cannot create alias '{new_name}' — {collision}")
|
||||
|
||||
# 4. Update active_profile if it pointed to old name
|
||||
try:
|
||||
if get_active_profile() == old_name:
|
||||
set_active_profile(new_name)
|
||||
print(f"✓ Active profile updated: {new_name}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return new_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tab completion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_bash_completion() -> str:
|
||||
"""Generate a bash completion script for hermes profile names."""
|
||||
return '''# Hermes Agent profile completion
|
||||
# Add to ~/.bashrc: eval "$(hermes completion bash)"
|
||||
|
||||
_hermes_profiles() {
|
||||
local profiles_dir="$HOME/.hermes/profiles"
|
||||
local profiles="default"
|
||||
if [ -d "$profiles_dir" ]; then
|
||||
profiles="$profiles $(ls "$profiles_dir" 2>/dev/null)"
|
||||
fi
|
||||
echo "$profiles"
|
||||
}
|
||||
|
||||
_hermes_completion() {
|
||||
local cur prev
|
||||
cur="${COMP_WORDS[COMP_CWORD]}"
|
||||
prev="${COMP_WORDS[COMP_CWORD-1]}"
|
||||
|
||||
# Complete profile names after -p / --profile
|
||||
if [[ "$prev" == "-p" || "$prev" == "--profile" ]]; then
|
||||
COMPREPLY=($(compgen -W "$(_hermes_profiles)" -- "$cur"))
|
||||
return
|
||||
fi
|
||||
|
||||
# Complete profile subcommands
|
||||
if [[ "${COMP_WORDS[1]}" == "profile" ]]; then
|
||||
case "$prev" in
|
||||
profile)
|
||||
COMPREPLY=($(compgen -W "list use create delete show alias rename export import" -- "$cur"))
|
||||
return
|
||||
;;
|
||||
use|delete|show|alias|rename|export)
|
||||
COMPREPLY=($(compgen -W "$(_hermes_profiles)" -- "$cur"))
|
||||
return
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Top-level subcommands
|
||||
if [[ "$COMP_CWORD" == 1 ]]; then
|
||||
local commands="chat model gateway setup status cron doctor config skills tools mcp sessions profile update version"
|
||||
COMPREPLY=($(compgen -W "$commands" -- "$cur"))
|
||||
fi
|
||||
}
|
||||
|
||||
complete -F _hermes_completion hermes
|
||||
'''
|
||||
|
||||
|
||||
def generate_zsh_completion() -> str:
|
||||
"""Generate a zsh completion script for hermes profile names."""
|
||||
return '''#compdef hermes
|
||||
# Hermes Agent profile completion
|
||||
# Add to ~/.zshrc: eval "$(hermes completion zsh)"
|
||||
|
||||
_hermes() {
|
||||
local -a profiles
|
||||
profiles=(default)
|
||||
if [[ -d "$HOME/.hermes/profiles" ]]; then
|
||||
profiles+=("${(@f)$(ls $HOME/.hermes/profiles 2>/dev/null)}")
|
||||
fi
|
||||
|
||||
_arguments \\
|
||||
'-p[Profile name]:profile:($profiles)' \\
|
||||
'--profile[Profile name]:profile:($profiles)' \\
|
||||
'1:command:(chat model gateway setup status cron doctor config skills tools mcp sessions profile update version)' \\
|
||||
'*::arg:->args'
|
||||
|
||||
case $words[1] in
|
||||
profile)
|
||||
_arguments '1:action:(list use create delete show alias rename export import)' \\
|
||||
'2:profile:($profiles)'
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
_hermes "$@"
|
||||
'''
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Profile env resolution (called from _apply_profile_override)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def resolve_profile_env(profile_name: str) -> str:
|
||||
"""Resolve a profile name to a HERMES_HOME path string.
|
||||
|
||||
Called early in the CLI entry point, before any hermes modules
|
||||
are imported, to set the HERMES_HOME environment variable.
|
||||
"""
|
||||
validate_profile_name(profile_name)
|
||||
profile_dir = get_profile_dir(profile_name)
|
||||
|
||||
if profile_name != "default" and not profile_dir.is_dir():
|
||||
raise FileNotFoundError(
|
||||
f"Profile '{profile_name}' does not exist. "
|
||||
f"Create it with: hermes profile create {profile_name}"
|
||||
)
|
||||
|
||||
return str(profile_dir)
|
||||
+7
-3
@@ -289,6 +289,7 @@ from hermes_cli.config import (
|
||||
get_env_value,
|
||||
ensure_hermes_home,
|
||||
)
|
||||
# display_hermes_home imported lazily at call sites (stale-module safety during hermes update)
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
@@ -683,7 +684,8 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
print_warning(
|
||||
"Some tools are disabled. Run 'hermes setup tools' to configure them,"
|
||||
)
|
||||
print_warning("or edit ~/.hermes/.env directly to add the missing API keys.")
|
||||
from hermes_constants import display_hermes_home as _dhh
|
||||
print_warning(f"or edit {_dhh()}/.env directly to add the missing API keys.")
|
||||
print()
|
||||
|
||||
# Done banner
|
||||
@@ -706,7 +708,8 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
print()
|
||||
|
||||
# Show file locations prominently
|
||||
print(color("📁 All your files are in ~/.hermes/:", Colors.CYAN, Colors.BOLD))
|
||||
from hermes_constants import display_hermes_home as _dhh
|
||||
print(color(f"📁 All your files are in {_dhh()}/:", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
print(f" {color('Settings:', Colors.YELLOW)} {get_config_path()}")
|
||||
print(f" {color('API Keys:', Colors.YELLOW)} {get_env_path()}")
|
||||
@@ -2837,7 +2840,8 @@ def setup_gateway(config: dict):
|
||||
save_env_value("WEBHOOK_ENABLED", "true")
|
||||
print()
|
||||
print_success("Webhooks enabled! Next steps:")
|
||||
print_info(" 1. Define webhook routes in ~/.hermes/config.yaml")
|
||||
from hermes_constants import display_hermes_home as _dhh
|
||||
print_info(f" 1. Define webhook routes in {_dhh()}/config.yaml")
|
||||
print_info(" 2. Point your service (GitHub, GitLab, etc.) at:")
|
||||
print_info(" http://your-server:8644/webhooks/<route-name>")
|
||||
print()
|
||||
|
||||
@@ -21,6 +21,7 @@ from rich.table import Table
|
||||
|
||||
# Lazy imports to avoid circular dependencies and slow startup.
|
||||
# tools.skills_hub and tools.skills_guard are imported inside functions.
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
_console = Console()
|
||||
|
||||
@@ -388,7 +389,7 @@ def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
"[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
|
||||
"It ships with hermes-agent but is not activated by default.\n"
|
||||
"Installing will copy it to your skills directory where the agent can use it.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
f"Files will be at: [cyan]{display_hermes_home()}/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Official Skill",
|
||||
border_style="bright_cyan",
|
||||
))
|
||||
@@ -398,7 +399,7 @@ def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
"External skills can contain instructions that influence agent behavior,\n"
|
||||
"shell commands, and scripts. Even after automated scanning, you should\n"
|
||||
"review the installed files before use.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
f"Files will be at: [cyan]{display_hermes_home()}/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Disclaimer",
|
||||
border_style="yellow",
|
||||
))
|
||||
@@ -744,7 +745,7 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "",
|
||||
auth = GitHubAuth()
|
||||
if not auth.is_authenticated():
|
||||
c.print("[bold red]Error:[/] GitHub authentication required.\n"
|
||||
"Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n")
|
||||
f"Set GITHUB_TOKEN in {display_hermes_home()}/.env or run 'gh auth login'.\n")
|
||||
return
|
||||
|
||||
c.print(f"[bold]Publishing '{name}' to {repo}...[/]")
|
||||
|
||||
@@ -9,6 +9,8 @@ Saves per-platform tool configuration to ~/.hermes/config.yaml under
|
||||
the `platform_toolsets` key.
|
||||
"""
|
||||
|
||||
import json as _json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
@@ -19,6 +21,8 @@ from hermes_cli.config import (
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
@@ -326,7 +330,8 @@ def _run_post_setup(post_setup_key: str):
|
||||
if result.returncode == 0:
|
||||
_print_success(" Node.js dependencies installed")
|
||||
else:
|
||||
_print_warning(" npm install failed - run manually: cd ~/.hermes/hermes-agent && npm install")
|
||||
from hermes_constants import display_hermes_home
|
||||
_print_warning(f" npm install failed - run manually: cd {display_hermes_home()}/hermes-agent && npm install")
|
||||
elif not node_modules.exists():
|
||||
_print_warning(" Node.js not found - browser tools require: npm install (in hermes-agent directory)")
|
||||
|
||||
@@ -652,9 +657,61 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
return default
|
||||
|
||||
|
||||
# ─── Token Estimation ────────────────────────────────────────────────────────
|
||||
|
||||
# Module-level cache so discovery + tokenization runs at most once per process.
|
||||
_tool_token_cache: Optional[Dict[str, int]] = None
|
||||
|
||||
|
||||
def _estimate_tool_tokens() -> Dict[str, int]:
|
||||
"""Return estimated token counts per individual tool name.
|
||||
|
||||
Uses tiktoken (cl100k_base) to count tokens in the JSON-serialised
|
||||
OpenAI-format tool schema. Triggers tool discovery on first call,
|
||||
then caches the result for the rest of the process.
|
||||
|
||||
Returns an empty dict when tiktoken or the registry is unavailable.
|
||||
"""
|
||||
global _tool_token_cache
|
||||
if _tool_token_cache is not None:
|
||||
return _tool_token_cache
|
||||
|
||||
try:
|
||||
import tiktoken
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
except Exception:
|
||||
logger.debug("tiktoken unavailable; skipping tool token estimation")
|
||||
_tool_token_cache = {}
|
||||
return _tool_token_cache
|
||||
|
||||
try:
|
||||
# Trigger full tool discovery (imports all tool modules).
|
||||
import model_tools # noqa: F401
|
||||
from tools.registry import registry
|
||||
except Exception:
|
||||
logger.debug("Tool registry unavailable; skipping token estimation")
|
||||
_tool_token_cache = {}
|
||||
return _tool_token_cache
|
||||
|
||||
counts: Dict[str, int] = {}
|
||||
for name in registry.get_all_tool_names():
|
||||
schema = registry.get_schema(name)
|
||||
if schema:
|
||||
# Mirror what gets sent to the API:
|
||||
# {"type": "function", "function": <schema>}
|
||||
text = _json.dumps({"type": "function", "function": schema})
|
||||
counts[name] = len(enc.encode(text))
|
||||
_tool_token_cache = counts
|
||||
return _tool_token_cache
|
||||
|
||||
|
||||
def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
|
||||
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
# Pre-compute per-tool token counts (cached after first call).
|
||||
tool_tokens = _estimate_tool_tokens()
|
||||
|
||||
effective = _get_effective_configurable_toolsets()
|
||||
|
||||
@@ -670,11 +727,27 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
if ts_key in enabled
|
||||
}
|
||||
|
||||
# Build a live status function that shows deduplicated total token cost.
|
||||
status_fn = None
|
||||
if tool_tokens:
|
||||
ts_keys = [ts_key for ts_key, _, _ in effective]
|
||||
|
||||
def status_fn(chosen: set) -> str:
|
||||
# Collect unique tool names across all selected toolsets
|
||||
all_tools: set = set()
|
||||
for idx in chosen:
|
||||
all_tools.update(resolve_toolset(ts_keys[idx]))
|
||||
total = sum(tool_tokens.get(name, 0) for name in all_tools)
|
||||
if total >= 1000:
|
||||
return f"Est. tool context: ~{total / 1000:.1f}k tokens"
|
||||
return f"Est. tool context: ~{total} tokens"
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Tools for {platform_label}",
|
||||
labels,
|
||||
pre_selected,
|
||||
cancel_returns=pre_selected,
|
||||
status_fn=status_fn,
|
||||
)
|
||||
return {effective[i][0] for i in chosen}
|
||||
|
||||
@@ -1264,7 +1337,8 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
platform_choices[idx] = f"Configure {pinfo['label']} ({new_count}/{total} enabled)"
|
||||
|
||||
print()
|
||||
print(color(" Tool configuration saved to ~/.hermes/config.yaml", Colors.DIM))
|
||||
from hermes_constants import display_hermes_home
|
||||
print(color(f" Tool configuration saved to {display_hermes_home()}/config.yaml", Colors.DIM))
|
||||
print(color(" Changes take effect on next 'hermes' or gateway restart.", Colors.DIM))
|
||||
print()
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
|
||||
_SUBSCRIPTIONS_FILENAME = "webhook_subscriptions.json"
|
||||
|
||||
@@ -76,13 +78,15 @@ def _get_webhook_base_url() -> str:
|
||||
return f"http://{display_host}:{port}"
|
||||
|
||||
|
||||
_SETUP_HINT = """
|
||||
def _setup_hint() -> str:
|
||||
_dhh = display_hermes_home()
|
||||
return f"""
|
||||
Webhook platform is not enabled. To set it up:
|
||||
|
||||
1. Run the gateway setup wizard:
|
||||
hermes gateway setup
|
||||
|
||||
2. Or manually add to ~/.hermes/config.yaml:
|
||||
2. Or manually add to {_dhh}/config.yaml:
|
||||
platforms:
|
||||
webhook:
|
||||
enabled: true
|
||||
@@ -91,7 +95,7 @@ _SETUP_HINT = """
|
||||
port: 8644
|
||||
secret: "your-global-hmac-secret"
|
||||
|
||||
3. Or set environment variables in ~/.hermes/.env:
|
||||
3. Or set environment variables in {_dhh}/.env:
|
||||
WEBHOOK_ENABLED=true
|
||||
WEBHOOK_PORT=8644
|
||||
WEBHOOK_SECRET=your-global-secret
|
||||
@@ -104,7 +108,7 @@ def _require_webhook_enabled() -> bool:
|
||||
"""Check webhook is enabled. Print setup guide and return False if not."""
|
||||
if _is_webhook_enabled():
|
||||
return True
|
||||
print(_SETUP_HINT)
|
||||
print(_setup_hint())
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,26 @@ def get_hermes_dir(new_subpath: str, old_name: str) -> Path:
|
||||
return home / new_subpath
|
||||
|
||||
|
||||
def display_hermes_home() -> str:
|
||||
"""Return a user-friendly display string for the current HERMES_HOME.
|
||||
|
||||
Uses ``~/`` shorthand for readability::
|
||||
|
||||
default: ``~/.hermes``
|
||||
profile: ``~/.hermes/profiles/coder``
|
||||
custom: ``/opt/hermes-custom``
|
||||
|
||||
Use this in **user-facing** print/log messages instead of hardcoding
|
||||
``~/.hermes``. For code that needs a real ``Path``, use
|
||||
:func:`get_hermes_home` instead.
|
||||
"""
|
||||
home = get_hermes_home()
|
||||
try:
|
||||
return "~/" + str(home.relative_to(Path.home()))
|
||||
except ValueError:
|
||||
return str(home)
|
||||
|
||||
|
||||
VALID_REASONING_EFFORTS = ("xhigh", "high", "medium", "low", "minimal")
|
||||
|
||||
|
||||
|
||||
+868
@@ -0,0 +1,868 @@
|
||||
"""
|
||||
Hermes MCP Server — expose messaging conversations as MCP tools.
|
||||
|
||||
Starts a stdio MCP server that lets any MCP client (Claude Code, Cursor, Codex,
|
||||
etc.) list conversations, read message history, send messages, poll for live
|
||||
events, and manage approval requests across all connected platforms.
|
||||
|
||||
Matches OpenClaw's 9-tool MCP channel bridge surface:
|
||||
conversations_list, conversation_get, messages_read, attachments_fetch,
|
||||
events_poll, events_wait, messages_send, permissions_list_open,
|
||||
permissions_respond
|
||||
|
||||
Plus: channels_list (Hermes-specific extra)
|
||||
|
||||
Usage:
|
||||
hermes mcp serve
|
||||
hermes mcp serve --verbose
|
||||
|
||||
MCP client config (e.g. claude_desktop_config.json):
|
||||
{
|
||||
"mcpServers": {
|
||||
"hermes": {
|
||||
"command": "hermes",
|
||||
"args": ["mcp", "serve"]
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger("hermes.mcp_serve")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lazy MCP SDK import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MCP_SERVER_AVAILABLE = False
|
||||
try:
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
_MCP_SERVER_AVAILABLE = True
|
||||
except ImportError:
|
||||
FastMCP = None # type: ignore[assignment,misc]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_sessions_dir() -> Path:
|
||||
"""Return the sessions directory using HERMES_HOME."""
|
||||
try:
|
||||
from hermes_constants import get_hermes_home
|
||||
return get_hermes_home() / "sessions"
|
||||
except ImportError:
|
||||
return Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) / "sessions"
|
||||
|
||||
|
||||
def _get_session_db():
|
||||
"""Get a SessionDB instance for reading message transcripts."""
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
return SessionDB()
|
||||
except Exception as e:
|
||||
logger.debug("SessionDB unavailable: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _load_sessions_index() -> dict:
|
||||
"""Load the gateway sessions.json index directly.
|
||||
|
||||
Returns a dict of session_key -> entry_dict with platform routing info.
|
||||
This avoids importing the full SessionStore which needs GatewayConfig.
|
||||
"""
|
||||
sessions_file = _get_sessions_dir() / "sessions.json"
|
||||
if not sessions_file.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(sessions_file, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to load sessions.json: %s", e)
|
||||
return {}
|
||||
|
||||
|
||||
def _load_channel_directory() -> dict:
|
||||
"""Load the cached channel directory for available targets."""
|
||||
try:
|
||||
from hermes_constants import get_hermes_home
|
||||
directory_file = get_hermes_home() / "channel_directory.json"
|
||||
except ImportError:
|
||||
directory_file = Path(
|
||||
os.environ.get("HERMES_HOME", Path.home() / ".hermes")
|
||||
) / "channel_directory.json"
|
||||
|
||||
if not directory_file.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(directory_file, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to load channel_directory.json: %s", e)
|
||||
return {}
|
||||
|
||||
|
||||
def _extract_message_content(msg: dict) -> str:
|
||||
"""Extract text content from a message, handling multi-part content."""
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
text_parts = [
|
||||
p.get("text", "") for p in content
|
||||
if isinstance(p, dict) and p.get("type") == "text"
|
||||
]
|
||||
return "\n".join(text_parts)
|
||||
return str(content) if content else ""
|
||||
|
||||
|
||||
def _extract_attachments(msg: dict) -> List[dict]:
|
||||
"""Extract non-text attachments from a message.
|
||||
|
||||
Finds: multi-part image/file content blocks, MEDIA: tags in text,
|
||||
image URLs, and file references.
|
||||
"""
|
||||
attachments = []
|
||||
content = msg.get("content", "")
|
||||
|
||||
# Multi-part content blocks (image_url, file, etc.)
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
ptype = part.get("type", "")
|
||||
if ptype == "image_url":
|
||||
url = part.get("image_url", {}).get("url", "") if isinstance(part.get("image_url"), dict) else ""
|
||||
if url:
|
||||
attachments.append({"type": "image", "url": url})
|
||||
elif ptype == "image":
|
||||
url = part.get("url", part.get("source", {}).get("url", ""))
|
||||
if url:
|
||||
attachments.append({"type": "image", "url": url})
|
||||
elif ptype not in ("text",):
|
||||
# Unknown non-text content type
|
||||
attachments.append({"type": ptype, "data": part})
|
||||
|
||||
# MEDIA: tags in text content
|
||||
text = _extract_message_content(msg)
|
||||
if text:
|
||||
media_pattern = re.compile(r'MEDIA:\s*(\S+)')
|
||||
for match in media_pattern.finditer(text):
|
||||
path = match.group(1)
|
||||
attachments.append({"type": "media", "path": path})
|
||||
|
||||
return attachments
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event Bridge — polls SessionDB for new messages, maintains event queue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
QUEUE_LIMIT = 1000
|
||||
POLL_INTERVAL = 0.2 # seconds between DB polls (200ms)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueEvent:
|
||||
"""An event in the bridge's in-memory queue."""
|
||||
cursor: int
|
||||
type: str # "message", "approval_requested", "approval_resolved"
|
||||
session_key: str = ""
|
||||
data: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class EventBridge:
|
||||
"""Background poller that watches SessionDB for new messages and
|
||||
maintains an in-memory event queue with waiter support.
|
||||
|
||||
This is the Hermes equivalent of OpenClaw's WebSocket gateway bridge.
|
||||
Instead of WebSocket events, we poll the SQLite database for changes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue: List[QueueEvent] = []
|
||||
self._cursor = 0
|
||||
self._lock = threading.Lock()
|
||||
self._new_event = threading.Event()
|
||||
self._running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._last_poll_timestamps: Dict[str, float] = {} # session_key -> unix timestamp
|
||||
# In-memory approval tracking (populated from events)
|
||||
self._pending_approvals: Dict[str, dict] = {}
|
||||
# mtime cache — skip expensive work when files haven't changed
|
||||
self._sessions_json_mtime: float = 0.0
|
||||
self._state_db_mtime: float = 0.0
|
||||
self._cached_sessions_index: dict = {}
|
||||
|
||||
def start(self):
|
||||
"""Start the background polling thread."""
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._poll_loop, daemon=True)
|
||||
self._thread.start()
|
||||
logger.debug("EventBridge started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the background polling thread."""
|
||||
self._running = False
|
||||
self._new_event.set() # Wake any waiters
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
logger.debug("EventBridge stopped")
|
||||
|
||||
def poll_events(
|
||||
self,
|
||||
after_cursor: int = 0,
|
||||
session_key: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
) -> dict:
|
||||
"""Return events since after_cursor, optionally filtered by session_key."""
|
||||
with self._lock:
|
||||
events = [
|
||||
e for e in self._queue
|
||||
if e.cursor > after_cursor
|
||||
and (not session_key or e.session_key == session_key)
|
||||
][:limit]
|
||||
|
||||
next_cursor = events[-1].cursor if events else after_cursor
|
||||
return {
|
||||
"events": [
|
||||
{"cursor": e.cursor, "type": e.type,
|
||||
"session_key": e.session_key, **e.data}
|
||||
for e in events
|
||||
],
|
||||
"next_cursor": next_cursor,
|
||||
}
|
||||
|
||||
def wait_for_event(
|
||||
self,
|
||||
after_cursor: int = 0,
|
||||
session_key: Optional[str] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Optional[dict]:
|
||||
"""Block until a matching event arrives or timeout expires."""
|
||||
deadline = time.monotonic() + (timeout_ms / 1000.0)
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
with self._lock:
|
||||
for e in self._queue:
|
||||
if e.cursor > after_cursor and (
|
||||
not session_key or e.session_key == session_key
|
||||
):
|
||||
return {
|
||||
"cursor": e.cursor, "type": e.type,
|
||||
"session_key": e.session_key, **e.data,
|
||||
}
|
||||
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
self._new_event.clear()
|
||||
self._new_event.wait(timeout=min(remaining, POLL_INTERVAL))
|
||||
|
||||
return None
|
||||
|
||||
def list_pending_approvals(self) -> List[dict]:
|
||||
"""List approval requests observed during this bridge session."""
|
||||
with self._lock:
|
||||
return sorted(
|
||||
self._pending_approvals.values(),
|
||||
key=lambda a: a.get("created_at", ""),
|
||||
)
|
||||
|
||||
def respond_to_approval(self, approval_id: str, decision: str) -> dict:
|
||||
"""Resolve a pending approval (best-effort without gateway IPC)."""
|
||||
with self._lock:
|
||||
approval = self._pending_approvals.pop(approval_id, None)
|
||||
|
||||
if not approval:
|
||||
return {"error": f"Approval not found: {approval_id}"}
|
||||
|
||||
self._enqueue(QueueEvent(
|
||||
cursor=0, # Will be set by _enqueue
|
||||
type="approval_resolved",
|
||||
session_key=approval.get("session_key", ""),
|
||||
data={"approval_id": approval_id, "decision": decision},
|
||||
))
|
||||
|
||||
return {"resolved": True, "approval_id": approval_id, "decision": decision}
|
||||
|
||||
def _enqueue(self, event: QueueEvent) -> None:
|
||||
"""Add an event to the queue and wake any waiters."""
|
||||
with self._lock:
|
||||
self._cursor += 1
|
||||
event.cursor = self._cursor
|
||||
self._queue.append(event)
|
||||
# Trim queue to limit
|
||||
while len(self._queue) > QUEUE_LIMIT:
|
||||
self._queue.pop(0)
|
||||
self._new_event.set()
|
||||
|
||||
def _poll_loop(self):
|
||||
"""Background loop: poll SessionDB for new messages."""
|
||||
db = _get_session_db()
|
||||
if not db:
|
||||
logger.warning("EventBridge: SessionDB unavailable, event polling disabled")
|
||||
return
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
self._poll_once(db)
|
||||
except Exception as e:
|
||||
logger.debug("EventBridge poll error: %s", e)
|
||||
time.sleep(POLL_INTERVAL)
|
||||
|
||||
def _poll_once(self, db):
|
||||
"""Check for new messages across all sessions.
|
||||
|
||||
Uses mtime checks on sessions.json and state.db to skip work
|
||||
when nothing has changed — makes 200ms polling essentially free.
|
||||
"""
|
||||
# Check if sessions.json has changed (mtime check is ~1μs)
|
||||
sessions_file = _get_sessions_dir() / "sessions.json"
|
||||
try:
|
||||
sj_mtime = sessions_file.stat().st_mtime if sessions_file.exists() else 0.0
|
||||
except OSError:
|
||||
sj_mtime = 0.0
|
||||
|
||||
if sj_mtime != self._sessions_json_mtime:
|
||||
self._sessions_json_mtime = sj_mtime
|
||||
self._cached_sessions_index = _load_sessions_index()
|
||||
|
||||
# Check if state.db has changed
|
||||
try:
|
||||
from hermes_constants import get_hermes_home
|
||||
db_file = get_hermes_home() / "state.db"
|
||||
except ImportError:
|
||||
db_file = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) / "state.db"
|
||||
|
||||
try:
|
||||
db_mtime = db_file.stat().st_mtime if db_file.exists() else 0.0
|
||||
except OSError:
|
||||
db_mtime = 0.0
|
||||
|
||||
if db_mtime == self._state_db_mtime and sj_mtime == self._sessions_json_mtime:
|
||||
return # Nothing changed since last poll — skip entirely
|
||||
|
||||
self._state_db_mtime = db_mtime
|
||||
entries = self._cached_sessions_index
|
||||
|
||||
for session_key, entry in entries.items():
|
||||
session_id = entry.get("session_id", "")
|
||||
if not session_id:
|
||||
continue
|
||||
|
||||
last_seen = self._last_poll_timestamps.get(session_key, 0.0)
|
||||
|
||||
try:
|
||||
messages = db.get_messages(session_id)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
# Normalize timestamps to float for comparison
|
||||
def _ts_float(ts) -> float:
|
||||
if isinstance(ts, (int, float)):
|
||||
return float(ts)
|
||||
if isinstance(ts, str) and ts:
|
||||
try:
|
||||
return float(ts)
|
||||
except ValueError:
|
||||
# ISO string — parse to epoch
|
||||
try:
|
||||
from datetime import datetime
|
||||
return datetime.fromisoformat(ts).timestamp()
|
||||
except Exception:
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
# Find messages newer than our last seen timestamp
|
||||
new_messages = []
|
||||
for msg in messages:
|
||||
ts = _ts_float(msg.get("timestamp", 0))
|
||||
role = msg.get("role", "")
|
||||
if role not in ("user", "assistant"):
|
||||
continue
|
||||
if ts > last_seen:
|
||||
new_messages.append(msg)
|
||||
|
||||
for msg in new_messages:
|
||||
content = _extract_message_content(msg)
|
||||
if not content:
|
||||
continue
|
||||
self._enqueue(QueueEvent(
|
||||
cursor=0,
|
||||
type="message",
|
||||
session_key=session_key,
|
||||
data={
|
||||
"role": msg.get("role", ""),
|
||||
"content": content[:500],
|
||||
"timestamp": str(msg.get("timestamp", "")),
|
||||
"message_id": str(msg.get("id", "")),
|
||||
},
|
||||
))
|
||||
|
||||
# Update last seen to the most recent message timestamp
|
||||
all_ts = [_ts_float(m.get("timestamp", 0)) for m in messages]
|
||||
if all_ts:
|
||||
latest = max(all_ts)
|
||||
if latest > last_seen:
|
||||
self._last_poll_timestamps[session_key] = latest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP Server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP":
|
||||
"""Create and return the Hermes MCP server with all tools registered."""
|
||||
if not _MCP_SERVER_AVAILABLE:
|
||||
raise ImportError(
|
||||
"MCP server requires the 'mcp' package. "
|
||||
"Install with: pip install 'hermes-agent[mcp]'"
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
"hermes",
|
||||
instructions=(
|
||||
"Hermes Agent messaging bridge. Use these tools to interact with "
|
||||
"conversations across Telegram, Discord, Slack, WhatsApp, Signal, "
|
||||
"Matrix, and other connected platforms."
|
||||
),
|
||||
)
|
||||
|
||||
bridge = event_bridge or EventBridge()
|
||||
|
||||
# -- conversations_list ------------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def conversations_list(
|
||||
platform: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
search: Optional[str] = None,
|
||||
) -> str:
|
||||
"""List active messaging conversations across connected platforms.
|
||||
|
||||
Returns conversations with their session keys (needed for messages_read),
|
||||
platform, chat type, display name, and last activity time.
|
||||
|
||||
Args:
|
||||
platform: Filter by platform name (telegram, discord, slack, etc.)
|
||||
limit: Maximum number of conversations to return (default 50)
|
||||
search: Optional text to filter conversations by name
|
||||
"""
|
||||
entries = _load_sessions_index()
|
||||
conversations = []
|
||||
|
||||
for key, entry in entries.items():
|
||||
origin = entry.get("origin", {})
|
||||
entry_platform = entry.get("platform") or origin.get("platform", "")
|
||||
|
||||
if platform and entry_platform.lower() != platform.lower():
|
||||
continue
|
||||
|
||||
display_name = entry.get("display_name", "")
|
||||
chat_name = origin.get("chat_name", "")
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
if (search_lower not in display_name.lower()
|
||||
and search_lower not in chat_name.lower()
|
||||
and search_lower not in key.lower()):
|
||||
continue
|
||||
|
||||
conversations.append({
|
||||
"session_key": key,
|
||||
"session_id": entry.get("session_id", ""),
|
||||
"platform": entry_platform,
|
||||
"chat_type": entry.get("chat_type", origin.get("chat_type", "")),
|
||||
"display_name": display_name,
|
||||
"chat_name": chat_name,
|
||||
"user_name": origin.get("user_name", ""),
|
||||
"updated_at": entry.get("updated_at", ""),
|
||||
})
|
||||
|
||||
conversations.sort(key=lambda c: c.get("updated_at", ""), reverse=True)
|
||||
conversations = conversations[:limit]
|
||||
|
||||
return json.dumps({
|
||||
"count": len(conversations),
|
||||
"conversations": conversations,
|
||||
}, indent=2)
|
||||
|
||||
# -- conversation_get --------------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def conversation_get(session_key: str) -> str:
|
||||
"""Get detailed info about one conversation by its session key.
|
||||
|
||||
Args:
|
||||
session_key: The session key from conversations_list
|
||||
"""
|
||||
entries = _load_sessions_index()
|
||||
entry = entries.get(session_key)
|
||||
|
||||
if not entry:
|
||||
return json.dumps({"error": f"Conversation not found: {session_key}"})
|
||||
|
||||
origin = entry.get("origin", {})
|
||||
return json.dumps({
|
||||
"session_key": session_key,
|
||||
"session_id": entry.get("session_id", ""),
|
||||
"platform": entry.get("platform") or origin.get("platform", ""),
|
||||
"chat_type": entry.get("chat_type", origin.get("chat_type", "")),
|
||||
"display_name": entry.get("display_name", ""),
|
||||
"user_name": origin.get("user_name", ""),
|
||||
"chat_name": origin.get("chat_name", ""),
|
||||
"chat_id": origin.get("chat_id", ""),
|
||||
"thread_id": origin.get("thread_id"),
|
||||
"updated_at": entry.get("updated_at", ""),
|
||||
"created_at": entry.get("created_at", ""),
|
||||
"input_tokens": entry.get("input_tokens", 0),
|
||||
"output_tokens": entry.get("output_tokens", 0),
|
||||
"total_tokens": entry.get("total_tokens", 0),
|
||||
}, indent=2)
|
||||
|
||||
# -- messages_read -----------------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def messages_read(
|
||||
session_key: str,
|
||||
limit: int = 50,
|
||||
) -> str:
|
||||
"""Read recent messages from a conversation.
|
||||
|
||||
Returns the message history in chronological order with role, content,
|
||||
and timestamp for each message.
|
||||
|
||||
Args:
|
||||
session_key: The session key from conversations_list
|
||||
limit: Maximum number of messages to return (default 50, most recent)
|
||||
"""
|
||||
entries = _load_sessions_index()
|
||||
entry = entries.get(session_key)
|
||||
if not entry:
|
||||
return json.dumps({"error": f"Conversation not found: {session_key}"})
|
||||
|
||||
session_id = entry.get("session_id", "")
|
||||
if not session_id:
|
||||
return json.dumps({"error": "No session ID for this conversation"})
|
||||
|
||||
db = _get_session_db()
|
||||
if not db:
|
||||
return json.dumps({"error": "Session database unavailable"})
|
||||
|
||||
try:
|
||||
all_messages = db.get_messages(session_id)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Failed to read messages: {e}"})
|
||||
|
||||
filtered = []
|
||||
for msg in all_messages:
|
||||
role = msg.get("role", "")
|
||||
if role in ("user", "assistant"):
|
||||
content = _extract_message_content(msg)
|
||||
if content:
|
||||
filtered.append({
|
||||
"id": str(msg.get("id", "")),
|
||||
"role": role,
|
||||
"content": content[:2000],
|
||||
"timestamp": msg.get("timestamp", ""),
|
||||
})
|
||||
|
||||
messages = filtered[-limit:]
|
||||
|
||||
return json.dumps({
|
||||
"session_key": session_key,
|
||||
"count": len(messages),
|
||||
"total_in_session": len(filtered),
|
||||
"messages": messages,
|
||||
}, indent=2)
|
||||
|
||||
# -- attachments_fetch -------------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def attachments_fetch(
|
||||
session_key: str,
|
||||
message_id: str,
|
||||
) -> str:
|
||||
"""List non-text attachments for a message in a conversation.
|
||||
|
||||
Extracts images, media files, and other non-text content blocks
|
||||
from the specified message.
|
||||
|
||||
Args:
|
||||
session_key: The session key from conversations_list
|
||||
message_id: The message ID from messages_read
|
||||
"""
|
||||
entries = _load_sessions_index()
|
||||
entry = entries.get(session_key)
|
||||
if not entry:
|
||||
return json.dumps({"error": f"Conversation not found: {session_key}"})
|
||||
|
||||
session_id = entry.get("session_id", "")
|
||||
if not session_id:
|
||||
return json.dumps({"error": "No session ID for this conversation"})
|
||||
|
||||
db = _get_session_db()
|
||||
if not db:
|
||||
return json.dumps({"error": "Session database unavailable"})
|
||||
|
||||
try:
|
||||
all_messages = db.get_messages(session_id)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Failed to read messages: {e}"})
|
||||
|
||||
# Find the target message
|
||||
target_msg = None
|
||||
for msg in all_messages:
|
||||
if str(msg.get("id", "")) == message_id:
|
||||
target_msg = msg
|
||||
break
|
||||
|
||||
if not target_msg:
|
||||
return json.dumps({"error": f"Message not found: {message_id}"})
|
||||
|
||||
attachments = _extract_attachments(target_msg)
|
||||
|
||||
return json.dumps({
|
||||
"message_id": message_id,
|
||||
"count": len(attachments),
|
||||
"attachments": attachments,
|
||||
}, indent=2)
|
||||
|
||||
# -- events_poll -------------------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def events_poll(
|
||||
after_cursor: int = 0,
|
||||
session_key: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
) -> str:
|
||||
"""Poll for new conversation events since a cursor position.
|
||||
|
||||
Returns events that have occurred since the given cursor. Use the
|
||||
returned next_cursor value for subsequent polls.
|
||||
|
||||
Event types: message, approval_requested, approval_resolved
|
||||
|
||||
Args:
|
||||
after_cursor: Return events after this cursor (0 for all)
|
||||
session_key: Optional filter to one conversation
|
||||
limit: Maximum events to return (default 20)
|
||||
"""
|
||||
result = bridge.poll_events(
|
||||
after_cursor=after_cursor,
|
||||
session_key=session_key,
|
||||
limit=limit,
|
||||
)
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
# -- events_wait -------------------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def events_wait(
|
||||
after_cursor: int = 0,
|
||||
session_key: Optional[str] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> str:
|
||||
"""Wait for the next conversation event (long-poll).
|
||||
|
||||
Blocks until a matching event arrives or the timeout expires.
|
||||
Use this for near-real-time event delivery without polling.
|
||||
|
||||
Args:
|
||||
after_cursor: Wait for events after this cursor
|
||||
session_key: Optional filter to one conversation
|
||||
timeout_ms: Maximum wait time in milliseconds (default 30000)
|
||||
"""
|
||||
event = bridge.wait_for_event(
|
||||
after_cursor=after_cursor,
|
||||
session_key=session_key,
|
||||
timeout_ms=min(timeout_ms, 300000), # Cap at 5 minutes
|
||||
)
|
||||
if event:
|
||||
return json.dumps({"event": event}, indent=2)
|
||||
return json.dumps({"event": None, "reason": "timeout"}, indent=2)
|
||||
|
||||
# -- messages_send -----------------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def messages_send(
|
||||
target: str,
|
||||
message: str,
|
||||
) -> str:
|
||||
"""Send a message to a platform conversation.
|
||||
|
||||
The target format is "platform:chat_id" — same format used by the
|
||||
channels_list tool. You can also use human-friendly channel names
|
||||
that will be resolved automatically.
|
||||
|
||||
Examples:
|
||||
target="telegram:6308981865"
|
||||
target="discord:#general"
|
||||
target="slack:#engineering"
|
||||
|
||||
Args:
|
||||
target: Platform target in "platform:identifier" format
|
||||
message: The message text to send
|
||||
"""
|
||||
if not target or not message:
|
||||
return json.dumps({"error": "Both target and message are required"})
|
||||
|
||||
try:
|
||||
from tools.send_message_tool import send_message_tool
|
||||
result_str = send_message_tool(
|
||||
{"action": "send", "target": target, "message": message}
|
||||
)
|
||||
return result_str
|
||||
except ImportError:
|
||||
return json.dumps({"error": "Send message tool not available"})
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Send failed: {e}"})
|
||||
|
||||
# -- channels_list -----------------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def channels_list(platform: Optional[str] = None) -> str:
|
||||
"""List available messaging channels and targets across platforms.
|
||||
|
||||
Returns channels that you can send messages to. The target strings
|
||||
returned here can be used directly with the messages_send tool.
|
||||
|
||||
Args:
|
||||
platform: Filter by platform name (telegram, discord, slack, etc.)
|
||||
"""
|
||||
directory = _load_channel_directory()
|
||||
if not directory:
|
||||
entries = _load_sessions_index()
|
||||
targets = []
|
||||
seen = set()
|
||||
for key, entry in entries.items():
|
||||
origin = entry.get("origin", {})
|
||||
p = entry.get("platform") or origin.get("platform", "")
|
||||
chat_id = origin.get("chat_id", "")
|
||||
if not p or not chat_id:
|
||||
continue
|
||||
if platform and p.lower() != platform.lower():
|
||||
continue
|
||||
target_str = f"{p}:{chat_id}"
|
||||
if target_str in seen:
|
||||
continue
|
||||
seen.add(target_str)
|
||||
targets.append({
|
||||
"target": target_str,
|
||||
"platform": p,
|
||||
"name": entry.get("display_name") or origin.get("chat_name", ""),
|
||||
"chat_type": entry.get("chat_type", origin.get("chat_type", "")),
|
||||
})
|
||||
return json.dumps({"count": len(targets), "channels": targets}, indent=2)
|
||||
|
||||
channels = []
|
||||
for plat, entries_list in directory.items():
|
||||
if platform and plat.lower() != platform.lower():
|
||||
continue
|
||||
if isinstance(entries_list, list):
|
||||
for ch in entries_list:
|
||||
if isinstance(ch, dict):
|
||||
chat_id = ch.get("id", ch.get("chat_id", ""))
|
||||
channels.append({
|
||||
"target": f"{plat}:{chat_id}" if chat_id else plat,
|
||||
"platform": plat,
|
||||
"name": ch.get("name", ch.get("display_name", "")),
|
||||
"chat_type": ch.get("type", ""),
|
||||
})
|
||||
|
||||
return json.dumps({"count": len(channels), "channels": channels}, indent=2)
|
||||
|
||||
# -- permissions_list_open ---------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def permissions_list_open() -> str:
|
||||
"""List pending approval requests observed during this bridge session.
|
||||
|
||||
Returns exec and plugin approval requests that the bridge has seen
|
||||
since it started. Approvals are live-session only — older approvals
|
||||
from before the bridge connected are not included.
|
||||
"""
|
||||
approvals = bridge.list_pending_approvals()
|
||||
return json.dumps({
|
||||
"count": len(approvals),
|
||||
"approvals": approvals,
|
||||
}, indent=2)
|
||||
|
||||
# -- permissions_respond -----------------------------------------------
|
||||
|
||||
@mcp.tool()
|
||||
def permissions_respond(
|
||||
id: str,
|
||||
decision: str,
|
||||
) -> str:
|
||||
"""Respond to a pending approval request.
|
||||
|
||||
Args:
|
||||
id: The approval ID from permissions_list_open
|
||||
decision: One of "allow-once", "allow-always", or "deny"
|
||||
"""
|
||||
if decision not in ("allow-once", "allow-always", "deny"):
|
||||
return json.dumps({
|
||||
"error": f"Invalid decision: {decision}. "
|
||||
f"Must be allow-once, allow-always, or deny"
|
||||
})
|
||||
|
||||
result = bridge.respond_to_approval(id, decision)
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
return mcp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def run_mcp_server(verbose: bool = False) -> None:
|
||||
"""Start the Hermes MCP server on stdio."""
|
||||
if not _MCP_SERVER_AVAILABLE:
|
||||
print(
|
||||
"Error: MCP server requires the 'mcp' package.\n"
|
||||
"Install with: pip install 'hermes-agent[mcp]'",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if verbose:
|
||||
logging.basicConfig(level=logging.DEBUG, stream=sys.stderr)
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARNING, stream=sys.stderr)
|
||||
|
||||
bridge = EventBridge()
|
||||
bridge.start()
|
||||
|
||||
server = create_mcp_server(event_bridge=bridge)
|
||||
|
||||
import asyncio
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
await server.run_stdio_async()
|
||||
finally:
|
||||
bridge.stop()
|
||||
|
||||
try:
|
||||
asyncio.run(_run())
|
||||
except KeyboardInterrupt:
|
||||
bridge.stop()
|
||||
+32
-3
@@ -21,6 +21,7 @@ Public API (signatures preserved from the original 2,400-line version):
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
@@ -365,6 +366,33 @@ _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
|
||||
_READ_SEARCH_TOOLS = {"read_file", "search_files"}
|
||||
|
||||
|
||||
def _sanitize_tool_error(error_msg: str) -> str:
|
||||
"""Sanitize tool error messages before sending to the LLM.
|
||||
|
||||
- Strips XML/JSON boundary markers that could confuse the model
|
||||
- Truncates to 2000 chars max
|
||||
- Wraps in a clear error format so the LLM knows it's an error
|
||||
"""
|
||||
sanitized = error_msg
|
||||
# Strip XML-like tags that could confuse the LLM (role / framing tags)
|
||||
sanitized = re.sub(
|
||||
r'</?(?:tool_call|function_call|result|response|output|input|system|assistant|user)>',
|
||||
'', sanitized,
|
||||
)
|
||||
# Strip markdown code fences (opening and closing)
|
||||
sanitized = re.sub(r'^\s*```(?:json|xml)?\s*', '', sanitized)
|
||||
sanitized = re.sub(r'\s*```\s*$', '', sanitized)
|
||||
# Remove CDATA sections
|
||||
sanitized = re.sub(r'<!\[CDATA\[.*?\]\]>', '', sanitized, flags=re.DOTALL)
|
||||
|
||||
# Truncate very long error messages
|
||||
if len(sanitized) > 2000:
|
||||
sanitized = sanitized[:1997] + '...'
|
||||
|
||||
# Wrap in clear error format
|
||||
return f"[TOOL_ERROR] {sanitized}"
|
||||
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
@@ -438,9 +466,10 @@ def handle_function_call(
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing {function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||||
raw_error = f"Error executing {function_name}: {str(e)}"
|
||||
logger.error(raw_error)
|
||||
sanitized = _sanitize_tool_error(raw_error)
|
||||
return json.dumps({"error": sanitized}, ensure_ascii=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Communication and decision-making frameworks — structured response formats for proposals, trade-off analysis, and stakeholder-ready recommendations.
|
||||
@@ -0,0 +1,103 @@
|
||||
---
|
||||
name: one-three-one-rule
|
||||
description: >
|
||||
Structured decision-making framework for technical proposals and trade-off analysis.
|
||||
When the user faces a choice between multiple approaches (architecture decisions,
|
||||
tool selection, refactoring strategies, migration paths), this skill produces a
|
||||
1-3-1 format: one clear problem statement, three distinct options with pros/cons,
|
||||
and one concrete recommendation with definition of done and implementation plan.
|
||||
Use when the user asks for a "1-3-1", says "give me options", or needs help
|
||||
choosing between competing approaches.
|
||||
version: 1.0.0
|
||||
author: Willard Moore
|
||||
license: MIT
|
||||
category: communication
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [communication, decision-making, proposals, trade-offs]
|
||||
---
|
||||
|
||||
# 1-3-1 Communication Rule
|
||||
|
||||
Structured decision-making format for when a task has multiple viable approaches and the user needs a clear recommendation. Produces a concise problem framing, three options with trade-offs, and an actionable plan for the recommended path.
|
||||
|
||||
## When to Use
|
||||
|
||||
- The user explicitly asks for a "1-3-1" response.
|
||||
- The user says "give me options" or "what are my choices" for a technical decision.
|
||||
- A task has multiple viable approaches with meaningful trade-offs (architecture, tooling, migration strategy).
|
||||
- The user needs a proposal they can forward to a team or stakeholder.
|
||||
|
||||
Do NOT use for simple questions with one obvious answer, debugging sessions, or tasks where the user has already decided on an approach.
|
||||
|
||||
## Procedure
|
||||
|
||||
1. **Problem** (one sentence)
|
||||
- State the core decision or desired outcome in a single concise sentence.
|
||||
- Focus on the *what*, not the *how* — no implementation details, tool names, or specific technologies.
|
||||
- Keep it tight. If you need "and", you're describing two problems.
|
||||
|
||||
2. **Options** (exactly three)
|
||||
- Present three distinct, viable approaches labeled A, B, C.
|
||||
- Each option gets a brief description, pros, and cons.
|
||||
- Options should represent genuinely different strategies, not minor variations of the same approach.
|
||||
|
||||
3. **Recommendation** (one option)
|
||||
- State which option you recommend and why, based on the user's context and priorities.
|
||||
- Be direct — this is your professional judgment, not a hedge.
|
||||
|
||||
4. **Definition of Done**
|
||||
- List the specific success criteria for the recommended option.
|
||||
- These are concrete, verifiable outcomes — not vague aspirations.
|
||||
- If the user picks a different option, revise this section to match.
|
||||
|
||||
5. **Implementation Plan**
|
||||
- Concrete steps to execute the recommended option.
|
||||
- Include specific commands, tools, or actions where applicable.
|
||||
- If the user picks a different option, revise this section to match.
|
||||
|
||||
## Verification
|
||||
|
||||
- Response has exactly one Problem sentence.
|
||||
- Response has exactly three Options (A, B, C) with pros and cons for each.
|
||||
- Response has a single Recommendation that picks one option with reasoning.
|
||||
- Definition of Done and Implementation Plan align with the recommended option.
|
||||
- If the user selects a different option, Recommendation, DoD, and Implementation Plan update accordingly.
|
||||
|
||||
## Example
|
||||
|
||||
User: "Give me a 1-3-1 for adding retry logic to our API client."
|
||||
|
||||
Agent:
|
||||
|
||||
**Problem:** API calls to external services fail intermittently under load and need graceful retry handling.
|
||||
|
||||
**Options:**
|
||||
|
||||
- **Option A: Exponential backoff with jitter in a shared utility.**
|
||||
Pros: Simple, well-understood pattern; single implementation shared across all callers.
|
||||
Cons: No circuit-breaking; callers still block during retries.
|
||||
|
||||
- **Option B: Retry library (tenacity) as a decorator on each endpoint call.**
|
||||
Pros: Configurable per-endpoint; built-in support for circuit breakers and rate limits.
|
||||
Cons: Adds a dependency; decorator pattern can obscure control flow.
|
||||
|
||||
- **Option C: Queue-based async retry with dead letter handling.**
|
||||
Pros: Non-blocking; failed calls are tracked and replayable; scales well.
|
||||
Cons: Significant complexity increase; requires a queue backend; overkill for low-volume services.
|
||||
|
||||
**Recommendation:** Option A. The call volume doesn't justify queue infrastructure, and a shared utility with exponential backoff + jitter covers the failure mode cleanly without new dependencies.
|
||||
|
||||
**Definition of Done:**
|
||||
- `api_retry()` utility handles transient HTTP errors (429, 502, 503, 504).
|
||||
- Retries up to 3 times with exponential backoff and jitter.
|
||||
- All external API calls route through the retry utility.
|
||||
- Logs each retry attempt with endpoint and attempt number.
|
||||
- Tests cover retry exhaustion, immediate success, and transient-then-success scenarios.
|
||||
|
||||
**Implementation Plan:**
|
||||
1. Create `utils/api_retry.py` with configurable max retries, base delay, and retryable status codes.
|
||||
2. Add jitter using `random.uniform(0, base_delay)` to prevent thundering herd.
|
||||
3. Wrap existing API calls in `api_client.py` with the retry utility.
|
||||
4. Add unit tests mocking HTTP responses for each retry scenario.
|
||||
5. Verify under load with a simple stress test against a flaky endpoint mock.
|
||||
@@ -0,0 +1,97 @@
|
||||
---
|
||||
name: canvas
|
||||
description: Canvas LMS integration — fetch enrolled courses and assignments using API token authentication.
|
||||
version: 1.0.0
|
||||
author: community
|
||||
license: MIT
|
||||
prerequisites:
|
||||
env_vars: [CANVAS_API_TOKEN, CANVAS_BASE_URL]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Canvas, LMS, Education, Courses, Assignments]
|
||||
---
|
||||
|
||||
# Canvas LMS — Course & Assignment Access
|
||||
|
||||
Read-only access to Canvas LMS for listing courses and assignments.
|
||||
|
||||
## Scripts
|
||||
|
||||
- `scripts/canvas_api.py` — Python CLI for Canvas API calls
|
||||
|
||||
## Setup
|
||||
|
||||
1. Log in to your Canvas instance in a browser
|
||||
2. Go to **Account → Settings** (click your profile icon, then Settings)
|
||||
3. Scroll to **Approved Integrations** and click **+ New Access Token**
|
||||
4. Name the token (e.g., "Hermes Agent"), set an optional expiry, and click **Generate Token**
|
||||
5. Copy the token and add to `~/.hermes/.env`:
|
||||
|
||||
```
|
||||
CANVAS_API_TOKEN=your_token_here
|
||||
CANVAS_BASE_URL=https://yourschool.instructure.com
|
||||
```
|
||||
|
||||
The base URL is whatever appears in your browser when you're logged into Canvas (no trailing slash).
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
CANVAS="python $HERMES_HOME/skills/productivity/canvas/scripts/canvas_api.py"
|
||||
|
||||
# List all active courses
|
||||
$CANVAS list_courses --enrollment-state active
|
||||
|
||||
# List all courses (any state)
|
||||
$CANVAS list_courses
|
||||
|
||||
# List assignments for a specific course
|
||||
$CANVAS list_assignments 12345
|
||||
|
||||
# List assignments ordered by due date
|
||||
$CANVAS list_assignments 12345 --order-by due_at
|
||||
```
|
||||
|
||||
## Output Format
|
||||
|
||||
**list_courses** returns:
|
||||
```json
|
||||
[{"id": 12345, "name": "Intro to CS", "course_code": "CS101", "workflow_state": "available", "start_at": "...", "end_at": "..."}]
|
||||
```
|
||||
|
||||
**list_assignments** returns:
|
||||
```json
|
||||
[{"id": 67890, "name": "Homework 1", "due_at": "2025-02-15T23:59:00Z", "points_possible": 100, "submission_types": ["online_upload"], "html_url": "...", "description": "...", "course_id": 12345}]
|
||||
```
|
||||
|
||||
Note: Assignment descriptions are truncated to 500 characters. The `html_url` field links to the full assignment page in Canvas.
|
||||
|
||||
## API Reference (curl)
|
||||
|
||||
```bash
|
||||
# List courses
|
||||
curl -s -H "Authorization: Bearer $CANVAS_API_TOKEN" \
|
||||
"$CANVAS_BASE_URL/api/v1/courses?enrollment_state=active&per_page=10"
|
||||
|
||||
# List assignments for a course
|
||||
curl -s -H "Authorization: Bearer $CANVAS_API_TOKEN" \
|
||||
"$CANVAS_BASE_URL/api/v1/courses/COURSE_ID/assignments?per_page=10&order_by=due_at"
|
||||
```
|
||||
|
||||
Canvas uses `Link` headers for pagination. The Python script handles pagination automatically.
|
||||
|
||||
## Rules
|
||||
|
||||
- This skill is **read-only** — it only fetches data, never modifies courses or assignments
|
||||
- On first use, verify auth by running `$CANVAS list_courses` — if it fails with 401, guide the user through setup
|
||||
- Canvas rate-limits to ~700 requests per 10 minutes; check `X-Rate-Limit-Remaining` header if hitting limits
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Fix |
|
||||
|---------|-----|
|
||||
| 401 Unauthorized | Token invalid or expired — regenerate in Canvas Settings |
|
||||
| 403 Forbidden | Token lacks permission for this course |
|
||||
| Empty course list | Try `--enrollment-state active` or omit the flag to see all states |
|
||||
| Wrong institution | Verify `CANVAS_BASE_URL` matches the URL in your browser |
|
||||
| Timeout errors | Check network connectivity to your Canvas instance |
|
||||
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Canvas LMS API CLI for Hermes Agent.
|
||||
|
||||
A thin CLI wrapper around the Canvas REST API.
|
||||
Authenticates using a personal access token from environment variables.
|
||||
|
||||
Usage:
|
||||
python canvas_api.py list_courses [--per-page N] [--enrollment-state STATE]
|
||||
python canvas_api.py list_assignments COURSE_ID [--per-page N] [--order-by FIELD]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import requests
|
||||
|
||||
CANVAS_API_TOKEN = os.environ.get("CANVAS_API_TOKEN", "")
|
||||
CANVAS_BASE_URL = os.environ.get("CANVAS_BASE_URL", "").rstrip("/")
|
||||
|
||||
|
||||
def _check_config():
|
||||
"""Validate required environment variables are set."""
|
||||
missing = []
|
||||
if not CANVAS_API_TOKEN:
|
||||
missing.append("CANVAS_API_TOKEN")
|
||||
if not CANVAS_BASE_URL:
|
||||
missing.append("CANVAS_BASE_URL")
|
||||
if missing:
|
||||
print(
|
||||
f"Missing required environment variables: {', '.join(missing)}\n"
|
||||
"Set them in ~/.hermes/.env or export them in your shell.\n"
|
||||
"See the canvas skill SKILL.md for setup instructions.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _headers():
|
||||
return {"Authorization": f"Bearer {CANVAS_API_TOKEN}"}
|
||||
|
||||
|
||||
def _paginated_get(url, params=None, max_items=200):
|
||||
"""Fetch all pages up to max_items, following Canvas Link headers."""
|
||||
results = []
|
||||
while url and len(results) < max_items:
|
||||
resp = requests.get(url, headers=_headers(), params=params, timeout=30)
|
||||
resp.raise_for_status()
|
||||
results.extend(resp.json())
|
||||
params = None # params are included in the Link URL for subsequent pages
|
||||
url = None
|
||||
link = resp.headers.get("Link", "")
|
||||
for part in link.split(","):
|
||||
if 'rel="next"' in part:
|
||||
url = part.split(";")[0].strip().strip("<>")
|
||||
return results[:max_items]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Commands
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def list_courses(args):
|
||||
"""List enrolled courses."""
|
||||
_check_config()
|
||||
url = f"{CANVAS_BASE_URL}/api/v1/courses"
|
||||
params = {"per_page": args.per_page}
|
||||
if args.enrollment_state:
|
||||
params["enrollment_state"] = args.enrollment_state
|
||||
try:
|
||||
courses = _paginated_get(url, params)
|
||||
except requests.HTTPError as e:
|
||||
print(f"API error: {e.response.status_code} {e.response.text}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
output = [
|
||||
{
|
||||
"id": c["id"],
|
||||
"name": c.get("name", ""),
|
||||
"course_code": c.get("course_code", ""),
|
||||
"enrollment_term_id": c.get("enrollment_term_id"),
|
||||
"start_at": c.get("start_at"),
|
||||
"end_at": c.get("end_at"),
|
||||
"workflow_state": c.get("workflow_state", ""),
|
||||
}
|
||||
for c in courses
|
||||
]
|
||||
print(json.dumps(output, indent=2))
|
||||
|
||||
|
||||
def list_assignments(args):
|
||||
"""List assignments for a course."""
|
||||
_check_config()
|
||||
url = f"{CANVAS_BASE_URL}/api/v1/courses/{args.course_id}/assignments"
|
||||
params = {"per_page": args.per_page}
|
||||
if args.order_by:
|
||||
params["order_by"] = args.order_by
|
||||
try:
|
||||
assignments = _paginated_get(url, params)
|
||||
except requests.HTTPError as e:
|
||||
print(f"API error: {e.response.status_code} {e.response.text}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
output = [
|
||||
{
|
||||
"id": a["id"],
|
||||
"name": a.get("name", ""),
|
||||
"description": (a.get("description") or "")[:500],
|
||||
"due_at": a.get("due_at"),
|
||||
"points_possible": a.get("points_possible"),
|
||||
"submission_types": a.get("submission_types", []),
|
||||
"html_url": a.get("html_url", ""),
|
||||
"course_id": a.get("course_id"),
|
||||
}
|
||||
for a in assignments
|
||||
]
|
||||
print(json.dumps(output, indent=2))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CLI parser
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Canvas LMS API CLI for Hermes Agent"
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# --- list_courses ---
|
||||
p = sub.add_parser("list_courses", help="List enrolled courses")
|
||||
p.add_argument("--per-page", type=int, default=50, help="Results per page (default 50)")
|
||||
p.add_argument(
|
||||
"--enrollment-state",
|
||||
default="",
|
||||
help="Filter by enrollment state (active, invited_or_pending, completed)",
|
||||
)
|
||||
p.set_defaults(func=list_courses)
|
||||
|
||||
# --- list_assignments ---
|
||||
p = sub.add_parser("list_assignments", help="List assignments for a course")
|
||||
p.add_argument("course_id", help="Canvas course ID")
|
||||
p.add_argument("--per-page", type=int, default=50, help="Results per page (default 50)")
|
||||
p.add_argument(
|
||||
"--order-by",
|
||||
default="",
|
||||
help="Order by field (due_at, name, position)",
|
||||
)
|
||||
p.set_defaults(func=list_assignments)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,324 @@
|
||||
---
|
||||
name: memento-flashcards
|
||||
description: >-
|
||||
Spaced-repetition flashcard system. Create cards from facts or text,
|
||||
chat with flashcards using free-text answers graded by the agent,
|
||||
generate quizzes from YouTube transcripts, review due cards with
|
||||
adaptive scheduling, and export/import decks as CSV.
|
||||
version: 1.0.0
|
||||
author: Memento AI
|
||||
license: MIT
|
||||
platforms: [macos, linux]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Education, Flashcards, Spaced Repetition, Learning, Quiz, YouTube]
|
||||
requires_toolsets: [terminal]
|
||||
category: productivity
|
||||
---
|
||||
|
||||
# Memento Flashcards — Spaced-Repetition Flashcard Skill
|
||||
|
||||
## Overview
|
||||
|
||||
Memento gives you a local, file-based flashcard system with spaced-repetition scheduling.
|
||||
Users can chat with their flashcards by answering in free text and having the agent grade the response before scheduling the next review.
|
||||
Use it whenever the user wants to:
|
||||
|
||||
- **Remember a fact** — turn any statement into a Q/A flashcard
|
||||
- **Study with spaced repetition** — review due cards with adaptive intervals and agent-graded free-text answers
|
||||
- **Quiz from a YouTube video** — fetch a transcript and generate a 5-question quiz
|
||||
- **Manage decks** — organise cards into collections, export/import CSV
|
||||
|
||||
All card data lives in a single JSON file. No external API keys are required — you (the agent) generate flashcard content and quiz questions directly.
|
||||
|
||||
User-facing response style for Memento Flashcards:
|
||||
- Use plain text only. Do not use Markdown formatting in replies to the user.
|
||||
- Keep review and quiz feedback brief and neutral. Avoid extra praise, pep, or long explanations.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use this skill when the user wants to:
|
||||
- Save facts as flashcards for later review
|
||||
- Review due cards with spaced repetition
|
||||
- Generate a quiz from a YouTube video transcript
|
||||
- Import, export, inspect, or delete flashcard data
|
||||
|
||||
Do not use this skill for general Q&A, coding help, or non-memory tasks.
|
||||
|
||||
## Quick Reference
|
||||
|
||||
| User intent | Action |
|
||||
|---|---|
|
||||
| "Remember that X" / "save this as a flashcard" | Generate a Q/A card, call `memento_cards.py add` |
|
||||
| Sends a fact without mentioning flashcards | Ask "Want me to save this as a Memento flashcard?" — only create if confirmed |
|
||||
| "Create a flashcard" | Ask for Q, A, collection; call `memento_cards.py add` |
|
||||
| "Review my cards" | Call `memento_cards.py due`, present cards one-by-one |
|
||||
| "Quiz me on [YouTube URL]" | Call `youtube_quiz.py fetch VIDEO_ID`, generate 5 questions, call `memento_cards.py add-quiz` |
|
||||
| "Export my cards" | Call `memento_cards.py export --output PATH` |
|
||||
| "Import cards from CSV" | Call `memento_cards.py import --file PATH --collection NAME` |
|
||||
| "Show my stats" | Call `memento_cards.py stats` |
|
||||
| "Delete a card" | Call `memento_cards.py delete --id ID` |
|
||||
| "Delete a collection" | Call `memento_cards.py delete-collection --collection NAME` |
|
||||
|
||||
## Card Storage
|
||||
|
||||
Cards are stored in a JSON file at:
|
||||
|
||||
```
|
||||
~/.hermes/skills/productivity/memento-flashcards/data/cards.json
|
||||
```
|
||||
|
||||
**Never edit this file directly.** Always use `memento_cards.py` subcommands. The script handles atomic writes (write to temp file, then rename) to prevent corruption.
|
||||
|
||||
The file is created automatically on first use.
|
||||
|
||||
## Procedure
|
||||
|
||||
### Creating Cards from Facts
|
||||
|
||||
### Activation Rules
|
||||
|
||||
Not every factual statement should become a flashcard. Use this three-tier check:
|
||||
|
||||
1. **Explicit intent** — the user mentions "memento", "flashcard", "remember this", "save this card", "add a card", or similar phrasing that clearly requests a flashcard → **create the card directly**, no confirmation needed.
|
||||
2. **Implicit intent** — the user sends a factual statement without mentioning flashcards (e.g. "The speed of light is 299,792 km/s") → **ask first**: "Want me to save this as a Memento flashcard?" Only create the card if the user confirms.
|
||||
3. **No intent** — the message is a coding task, a question, instructions, normal conversation, or anything that is clearly not a fact to memorize → **do NOT activate this skill at all**. Let other skills or default behavior handle it.
|
||||
|
||||
When activation is confirmed (tier 1 directly, tier 2 after confirmation), generate a flashcard:
|
||||
|
||||
**Step 1:** Turn the statement into a Q/A pair. Use this format internally:
|
||||
|
||||
```
|
||||
Turn the factual statement into a front-back pair.
|
||||
Return exactly two lines:
|
||||
Q: <question text>
|
||||
A: <answer text>
|
||||
|
||||
Statement: "{statement}"
|
||||
```
|
||||
|
||||
Rules:
|
||||
- The question should test recall of the key fact
|
||||
- The answer should be concise and direct
|
||||
|
||||
**Step 2:** Call the script to store the card:
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py add \
|
||||
--question "What year did World War 2 end?" \
|
||||
--answer "1945" \
|
||||
--collection "History"
|
||||
```
|
||||
|
||||
If the user doesn't specify a collection, use `"General"` as the default.
|
||||
|
||||
The script outputs JSON confirming the created card.
|
||||
|
||||
### Manual Card Creation
|
||||
|
||||
When the user explicitly asks to create a flashcard, ask them for:
|
||||
1. The question (front of card)
|
||||
2. The answer (back of card)
|
||||
3. The collection name (optional — default to `"General"`)
|
||||
|
||||
Then call `memento_cards.py add` as above.
|
||||
|
||||
### Reviewing Due Cards
|
||||
|
||||
When the user wants to review, fetch all due cards:
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py due
|
||||
```
|
||||
|
||||
This returns a JSON array of cards where `next_review_at <= now`. If a collection filter is needed:
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py due --collection "History"
|
||||
```
|
||||
|
||||
**Review flow (free-text grading):**
|
||||
|
||||
Here is an example of the EXACT interaction pattern you must follow. The user answers, you grade them, tell them the correct answer, then rate the card.
|
||||
|
||||
**Example interaction:**
|
||||
|
||||
> **Agent:** What year did the Berlin Wall fall?
|
||||
>
|
||||
> **User:** 1991
|
||||
>
|
||||
> **Agent:** Not quite. The Berlin Wall fell in 1989. Next review is tomorrow.
|
||||
> *(agent calls: memento_cards.py rate --id ABC --rating hard --user-answer "1991")*
|
||||
>
|
||||
> Next question: Who was the first person to walk on the moon?
|
||||
|
||||
**The rules:**
|
||||
|
||||
1. Show only the question. Wait for the user to answer.
|
||||
2. After receiving their answer, compare it to the expected answer and grade it:
|
||||
- **correct** → user got the key fact right (even if worded differently)
|
||||
- **partial** → right track but missing the core detail
|
||||
- **incorrect** → wrong or off-topic
|
||||
3. **You MUST tell the user the correct answer and how they did.** Keep it short and plain-text. Use this format:
|
||||
- correct: "Correct. Answer: {answer}. Next review in 7 days."
|
||||
- partial: "Close. Answer: {answer}. {what they missed}. Next review in 3 days."
|
||||
- incorrect: "Not quite. Answer: {answer}. Next review tomorrow."
|
||||
4. Then call the rate command: correct→easy, partial→good, incorrect→hard.
|
||||
5. Then show the next question.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py rate \
|
||||
--id CARD_ID --rating easy --user-answer "what the user said"
|
||||
```
|
||||
|
||||
**Never skip step 3.** The user must always see the correct answer and feedback before you move on.
|
||||
|
||||
If no cards are due, tell the user: "No cards due for review right now. Check back later!"
|
||||
|
||||
**Retire override:** At any point the user can say "retire this card" to permanently remove it from reviews. Use `--rating retire` for this.
|
||||
|
||||
### Spaced Repetition Algorithm
|
||||
|
||||
The rating determines the next review interval:
|
||||
|
||||
| Rating | Interval | ease_streak | Status change |
|
||||
|---|---|---|---|
|
||||
| **hard** | +1 day | reset to 0 | stays learning |
|
||||
| **good** | +3 days | reset to 0 | stays learning |
|
||||
| **easy** | +7 days | +1 | if ease_streak >= 3 → retired |
|
||||
| **retire** | permanent | reset to 0 | → retired |
|
||||
|
||||
- **learning**: card is actively in rotation
|
||||
- **retired**: card won't appear in reviews (user has mastered it or manually retired it)
|
||||
- Three consecutive "easy" ratings automatically retire a card
|
||||
|
||||
### YouTube Quiz Generation
|
||||
|
||||
When the user sends a YouTube URL and wants a quiz:
|
||||
|
||||
**Step 1:** Extract the video ID from the URL (e.g. `dQw4w9WgXcQ` from `https://www.youtube.com/watch?v=dQw4w9WgXcQ`).
|
||||
|
||||
**Step 2:** Fetch the transcript:
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/youtube_quiz.py fetch VIDEO_ID
|
||||
```
|
||||
|
||||
This returns `{"title": "...", "transcript": "..."}` or an error.
|
||||
|
||||
If the script reports `missing_dependency`, tell the user to install it:
|
||||
```bash
|
||||
pip install youtube-transcript-api
|
||||
```
|
||||
|
||||
**Step 3:** Generate 5 quiz questions from the transcript. Use these rules:
|
||||
|
||||
```
|
||||
You are creating a 5-question quiz for a podcast episode.
|
||||
Return ONLY a JSON array with exactly 5 objects.
|
||||
Each object must contain keys 'question' and 'answer'.
|
||||
|
||||
Selection criteria:
|
||||
- Prioritize important, surprising, or foundational facts.
|
||||
- Skip filler, obvious details, and facts that require heavy context.
|
||||
- Never return true/false questions.
|
||||
- Never ask only for a date.
|
||||
|
||||
Question rules:
|
||||
- Each question must test exactly one discrete fact.
|
||||
- Use clear, unambiguous wording.
|
||||
- Prefer What, Who, How many, Which.
|
||||
- Avoid open-ended Describe or Explain prompts.
|
||||
|
||||
Answer rules:
|
||||
- Each answer must be under 240 characters.
|
||||
- Lead with the answer itself, not preamble.
|
||||
- Add only minimal clarifying detail if needed.
|
||||
```
|
||||
|
||||
Use the first 15,000 characters of the transcript as context. Generate the questions yourself (you are the LLM).
|
||||
|
||||
**Step 4:** Validate the output is valid JSON with exactly 5 items, each having non-empty `question` and `answer` strings. If validation fails, retry once.
|
||||
|
||||
**Step 5:** Store quiz cards:
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py add-quiz \
|
||||
--video-id "VIDEO_ID" \
|
||||
--questions '[{"question":"...","answer":"..."},...]' \
|
||||
--collection "Quiz - Episode Title"
|
||||
```
|
||||
|
||||
The script deduplicates by `video_id` — if cards for that video already exist, it skips creation and reports the existing cards.
|
||||
|
||||
**Step 6:** Present questions one-by-one using the same free-text grading flow:
|
||||
1. Show "Question 1/5: ..." and wait for the user's answer. Never include the answer or any hint about revealing it.
|
||||
2. Wait for the user to answer in their own words
|
||||
3. Grade their answer using the grading prompt (see "Reviewing Due Cards" section)
|
||||
4. **IMPORTANT: You MUST reply to the user with feedback before doing anything else.** Show the grade, the correct answer, and when the card is next due. Do NOT silently skip to the next question. Keep it short and plain-text. Example: "Not quite. Answer: {answer}. Next review tomorrow."
|
||||
5. **After showing feedback**, call the rate command and then show the next question in the same message:
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py rate \
|
||||
--id CARD_ID --rating easy --user-answer "what the user said"
|
||||
```
|
||||
6. Repeat. Every answer MUST receive visible feedback before the next question.
|
||||
|
||||
### Export/Import CSV
|
||||
|
||||
**Export:**
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py export \
|
||||
--output ~/flashcards.csv
|
||||
```
|
||||
|
||||
Produces a 3-column CSV: `question,answer,collection` (no header row).
|
||||
|
||||
**Import:**
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py import \
|
||||
--file ~/flashcards.csv \
|
||||
--collection "Imported"
|
||||
```
|
||||
|
||||
Reads a CSV with columns: question, answer, and optionally collection (column 3). If the collection column is missing, uses the `--collection` argument.
|
||||
|
||||
### Statistics
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py stats
|
||||
```
|
||||
|
||||
Returns JSON with:
|
||||
- `total`: total card count
|
||||
- `learning`: cards in active rotation
|
||||
- `retired`: mastered cards
|
||||
- `due_now`: cards due for review right now
|
||||
- `collections`: breakdown by collection name
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- **Never edit `cards.json` directly** — always use the script subcommands to avoid corruption
|
||||
- **Transcript failures** — some YouTube videos have no English transcript or have transcripts disabled; inform the user and suggest another video
|
||||
- **Optional dependency** — `youtube_quiz.py` needs `youtube-transcript-api`; if missing, tell the user to run `pip install youtube-transcript-api`
|
||||
- **Large imports** — CSV imports with thousands of rows work fine but the JSON output may be verbose; summarize the result for the user
|
||||
- **Video ID extraction** — support both `youtube.com/watch?v=ID` and `youtu.be/ID` URL formats
|
||||
|
||||
## Verification
|
||||
|
||||
Verify the helper scripts directly:
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py stats
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py add --question "Capital of France?" --answer "Paris" --collection "General"
|
||||
python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py due
|
||||
```
|
||||
|
||||
If you are testing from the repo checkout, run:
|
||||
|
||||
```bash
|
||||
pytest tests/skills/test_memento_cards.py tests/skills/test_youtube_quiz.py -q
|
||||
```
|
||||
|
||||
Agent-level verification:
|
||||
- Start a review and confirm feedback is plain text, brief, and always includes the correct answer before the next card
|
||||
- Run a YouTube quiz flow and confirm each answer receives visible feedback before the next question
|
||||
@@ -0,0 +1,353 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Memento card storage, spaced-repetition engine, and CSV I/O.
|
||||
|
||||
Stdlib-only. All output is JSON for agent parsing.
|
||||
Data file: $HERMES_HOME/skills/productivity/memento-flashcards/data/cards.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
_HERMES_HOME = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
DATA_DIR = _HERMES_HOME / "skills" / "productivity" / "memento-flashcards" / "data"
|
||||
CARDS_FILE = DATA_DIR / "cards.json"
|
||||
|
||||
RETIRED_SENTINEL = "9999-12-31T23:59:59+00:00"
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _iso(dt: datetime) -> str:
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
def _parse_iso(s: str) -> datetime:
|
||||
return datetime.fromisoformat(s)
|
||||
|
||||
|
||||
def _empty_store() -> dict:
|
||||
return {"cards": [], "version": 1}
|
||||
|
||||
|
||||
def _load() -> dict:
|
||||
if not CARDS_FILE.exists():
|
||||
return _empty_store()
|
||||
try:
|
||||
with open(CARDS_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, dict) or "cards" not in data:
|
||||
return _empty_store()
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return _empty_store()
|
||||
|
||||
|
||||
def _save(data: dict) -> None:
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
fd, tmp = tempfile.mkstemp(dir=DATA_DIR, suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
os.replace(tmp, CARDS_FILE)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def _out(obj: object) -> None:
|
||||
json.dump(obj, sys.stdout, indent=2, ensure_ascii=False)
|
||||
sys.stdout.write("\n")
|
||||
|
||||
|
||||
# ── Subcommands ──────────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_add(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
now = _now()
|
||||
card = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"question": args.question,
|
||||
"answer": args.answer,
|
||||
"collection": args.collection or "General",
|
||||
"status": "learning",
|
||||
"ease_streak": 0,
|
||||
"next_review_at": _iso(now),
|
||||
"created_at": _iso(now),
|
||||
"video_id": None,
|
||||
"last_user_answer": None,
|
||||
}
|
||||
data["cards"].append(card)
|
||||
_save(data)
|
||||
_out({"ok": True, "card": card})
|
||||
|
||||
|
||||
def cmd_add_quiz(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
now = _now()
|
||||
|
||||
try:
|
||||
questions = json.loads(args.questions)
|
||||
except json.JSONDecodeError as exc:
|
||||
_out({"ok": False, "error": f"Invalid JSON for --questions: {exc}"})
|
||||
sys.exit(1)
|
||||
|
||||
# Dedup: skip if cards with this video_id already exist
|
||||
existing_ids = {c["video_id"] for c in data["cards"] if c.get("video_id")}
|
||||
if args.video_id in existing_ids:
|
||||
existing = [c for c in data["cards"] if c.get("video_id") == args.video_id]
|
||||
_out({"ok": True, "skipped": True, "reason": "duplicate_video_id", "existing_count": len(existing), "cards": existing})
|
||||
return
|
||||
|
||||
created = []
|
||||
for qa in questions:
|
||||
card = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"question": qa["question"],
|
||||
"answer": qa["answer"],
|
||||
"collection": args.collection or "Quiz",
|
||||
"status": "learning",
|
||||
"ease_streak": 0,
|
||||
"next_review_at": _iso(now),
|
||||
"created_at": _iso(now),
|
||||
"video_id": args.video_id,
|
||||
"last_user_answer": None,
|
||||
}
|
||||
data["cards"].append(card)
|
||||
created.append(card)
|
||||
|
||||
_save(data)
|
||||
_out({"ok": True, "created_count": len(created), "cards": created})
|
||||
|
||||
|
||||
def cmd_due(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
now = _now()
|
||||
due = []
|
||||
for card in data["cards"]:
|
||||
if card["status"] == "retired":
|
||||
continue
|
||||
review_at = _parse_iso(card["next_review_at"])
|
||||
if review_at <= now:
|
||||
if args.collection and card["collection"] != args.collection:
|
||||
continue
|
||||
due.append(card)
|
||||
_out({"ok": True, "count": len(due), "cards": due})
|
||||
|
||||
|
||||
def cmd_rate(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
now = _now()
|
||||
card = None
|
||||
for c in data["cards"]:
|
||||
if c["id"] == args.id:
|
||||
card = c
|
||||
break
|
||||
if not card:
|
||||
_out({"ok": False, "error": f"Card not found: {args.id}"})
|
||||
sys.exit(1)
|
||||
|
||||
rating = args.rating
|
||||
user_answer = getattr(args, "user_answer", None)
|
||||
if user_answer is not None:
|
||||
card["last_user_answer"] = user_answer
|
||||
|
||||
if rating == "retire":
|
||||
card["status"] = "retired"
|
||||
card["next_review_at"] = RETIRED_SENTINEL
|
||||
card["ease_streak"] = 0
|
||||
elif rating == "hard":
|
||||
card["next_review_at"] = _iso(now + timedelta(days=1))
|
||||
card["ease_streak"] = 0
|
||||
elif rating == "good":
|
||||
card["next_review_at"] = _iso(now + timedelta(days=3))
|
||||
card["ease_streak"] = 0
|
||||
elif rating == "easy":
|
||||
card["next_review_at"] = _iso(now + timedelta(days=7))
|
||||
card["ease_streak"] = card.get("ease_streak", 0) + 1
|
||||
if card["ease_streak"] >= 3:
|
||||
card["status"] = "retired"
|
||||
|
||||
_save(data)
|
||||
_out({"ok": True, "card": card})
|
||||
|
||||
|
||||
def cmd_list(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
cards = data["cards"]
|
||||
if args.collection:
|
||||
cards = [c for c in cards if c["collection"] == args.collection]
|
||||
if args.status:
|
||||
cards = [c for c in cards if c["status"] == args.status]
|
||||
_out({"ok": True, "count": len(cards), "cards": cards})
|
||||
|
||||
|
||||
def cmd_stats(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
now = _now()
|
||||
total = len(data["cards"])
|
||||
learning = sum(1 for c in data["cards"] if c["status"] == "learning")
|
||||
retired = sum(1 for c in data["cards"] if c["status"] == "retired")
|
||||
due_now = 0
|
||||
for c in data["cards"]:
|
||||
if c["status"] != "retired" and _parse_iso(c["next_review_at"]) <= now:
|
||||
due_now += 1
|
||||
|
||||
collections: dict[str, int] = {}
|
||||
for c in data["cards"]:
|
||||
name = c["collection"]
|
||||
collections[name] = collections.get(name, 0) + 1
|
||||
|
||||
_out({
|
||||
"ok": True,
|
||||
"total": total,
|
||||
"learning": learning,
|
||||
"retired": retired,
|
||||
"due_now": due_now,
|
||||
"collections": collections,
|
||||
})
|
||||
|
||||
|
||||
def cmd_export(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
output_path = Path(args.output).expanduser()
|
||||
with open(output_path, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.writer(f, lineterminator="\n")
|
||||
for card in data["cards"]:
|
||||
writer.writerow([card["question"], card["answer"], card["collection"]])
|
||||
_out({"ok": True, "exported": len(data["cards"]), "path": str(output_path)})
|
||||
|
||||
|
||||
def cmd_import(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
now = _now()
|
||||
file_path = Path(args.file).expanduser()
|
||||
|
||||
if not file_path.exists():
|
||||
_out({"ok": False, "error": f"File not found: {file_path}"})
|
||||
sys.exit(1)
|
||||
|
||||
created = 0
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
for row in reader:
|
||||
if len(row) < 2:
|
||||
continue
|
||||
question = row[0].strip()
|
||||
answer = row[1].strip()
|
||||
collection = row[2].strip() if len(row) >= 3 and row[2].strip() else (args.collection or "Imported")
|
||||
if not question or not answer:
|
||||
continue
|
||||
card = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"collection": collection,
|
||||
"status": "learning",
|
||||
"ease_streak": 0,
|
||||
"next_review_at": _iso(now),
|
||||
"created_at": _iso(now),
|
||||
"video_id": None,
|
||||
"last_user_answer": None,
|
||||
}
|
||||
data["cards"].append(card)
|
||||
created += 1
|
||||
|
||||
_save(data)
|
||||
_out({"ok": True, "imported": created})
|
||||
|
||||
|
||||
def cmd_delete(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
original = len(data["cards"])
|
||||
data["cards"] = [c for c in data["cards"] if c["id"] != args.id]
|
||||
removed = original - len(data["cards"])
|
||||
if removed == 0:
|
||||
_out({"ok": False, "error": f"Card not found: {args.id}"})
|
||||
sys.exit(1)
|
||||
_save(data)
|
||||
_out({"ok": True, "deleted": args.id})
|
||||
|
||||
|
||||
def cmd_delete_collection(args: argparse.Namespace) -> None:
|
||||
data = _load()
|
||||
original = len(data["cards"])
|
||||
data["cards"] = [c for c in data["cards"] if c["collection"] != args.collection]
|
||||
removed = original - len(data["cards"])
|
||||
_save(data)
|
||||
_out({"ok": True, "deleted_count": removed, "collection": args.collection})
|
||||
|
||||
|
||||
# ── CLI ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Memento flashcard manager")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
p_add = sub.add_parser("add", help="Create one card")
|
||||
p_add.add_argument("--question", required=True)
|
||||
p_add.add_argument("--answer", required=True)
|
||||
p_add.add_argument("--collection", default="General")
|
||||
|
||||
p_quiz = sub.add_parser("add-quiz", help="Batch-add quiz cards")
|
||||
p_quiz.add_argument("--video-id", required=True)
|
||||
p_quiz.add_argument("--questions", required=True, help="JSON array of {question, answer}")
|
||||
p_quiz.add_argument("--collection", default="Quiz")
|
||||
|
||||
p_due = sub.add_parser("due", help="List due cards")
|
||||
p_due.add_argument("--collection", default=None)
|
||||
|
||||
p_rate = sub.add_parser("rate", help="Rate a card")
|
||||
p_rate.add_argument("--id", required=True)
|
||||
p_rate.add_argument("--rating", required=True, choices=["easy", "good", "hard", "retire"])
|
||||
p_rate.add_argument("--user-answer", default=None)
|
||||
|
||||
p_list = sub.add_parser("list", help="List cards")
|
||||
p_list.add_argument("--collection", default=None)
|
||||
p_list.add_argument("--status", default=None, choices=["learning", "retired"])
|
||||
|
||||
sub.add_parser("stats", help="Show statistics")
|
||||
|
||||
p_export = sub.add_parser("export", help="Export cards to CSV")
|
||||
p_export.add_argument("--output", required=True)
|
||||
|
||||
p_import = sub.add_parser("import", help="Import cards from CSV")
|
||||
p_import.add_argument("--file", required=True)
|
||||
p_import.add_argument("--collection", default="Imported")
|
||||
|
||||
p_del = sub.add_parser("delete", help="Delete one card")
|
||||
p_del.add_argument("--id", required=True)
|
||||
|
||||
p_delcol = sub.add_parser("delete-collection", help="Delete all cards in a collection")
|
||||
p_delcol.add_argument("--collection", required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
cmd_map = {
|
||||
"add": cmd_add,
|
||||
"add-quiz": cmd_add_quiz,
|
||||
"due": cmd_due,
|
||||
"rate": cmd_rate,
|
||||
"list": cmd_list,
|
||||
"stats": cmd_stats,
|
||||
"export": cmd_export,
|
||||
"import": cmd_import,
|
||||
"delete": cmd_delete,
|
||||
"delete-collection": cmd_delete_collection,
|
||||
}
|
||||
cmd_map[args.command](args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Fetch YouTube transcripts for Memento quiz generation.
|
||||
|
||||
Requires: pip install youtube-transcript-api
|
||||
The quiz question *generation* is done by the agent's LLM — this script only fetches transcripts.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
def _out(obj: object) -> None:
|
||||
json.dump(obj, sys.stdout, indent=2, ensure_ascii=False)
|
||||
sys.stdout.write("\n")
|
||||
|
||||
|
||||
def _normalize_segments(segments: list) -> str:
|
||||
parts = []
|
||||
for seg in segments:
|
||||
text = str(seg.get("text", "")).strip()
|
||||
if text:
|
||||
parts.append(text)
|
||||
return re.sub(r"\s+", " ", " ".join(parts)).strip()
|
||||
|
||||
|
||||
def cmd_fetch(args: argparse.Namespace) -> None:
|
||||
try:
|
||||
import youtube_transcript_api # noqa: F811
|
||||
except ImportError:
|
||||
_out({
|
||||
"ok": False,
|
||||
"error": "missing_dependency",
|
||||
"message": "Run: pip install youtube-transcript-api",
|
||||
})
|
||||
sys.exit(1)
|
||||
|
||||
video_id = args.video_id
|
||||
languages = ["en", "en-US", "en-GB", "en-CA", "en-AU"]
|
||||
|
||||
api = youtube_transcript_api.YouTubeTranscriptApi()
|
||||
try:
|
||||
raw = api.fetch(video_id, languages=languages)
|
||||
except Exception as exc:
|
||||
error_type = type(exc).__name__
|
||||
_out({
|
||||
"ok": False,
|
||||
"error": "transcript_unavailable",
|
||||
"error_type": error_type,
|
||||
"message": f"Could not fetch transcript for {video_id}: {exc}",
|
||||
})
|
||||
sys.exit(1)
|
||||
|
||||
segments = raw
|
||||
if hasattr(raw, "to_raw_data"):
|
||||
segments = raw.to_raw_data()
|
||||
|
||||
text = _normalize_segments(segments)
|
||||
if not text:
|
||||
_out({
|
||||
"ok": False,
|
||||
"error": "empty_transcript",
|
||||
"message": f"Transcript for {video_id} contained no usable text.",
|
||||
})
|
||||
sys.exit(1)
|
||||
|
||||
_out({
|
||||
"ok": True,
|
||||
"video_id": video_id,
|
||||
"transcript": text,
|
||||
})
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Memento YouTube transcript fetcher")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
p_fetch = sub.add_parser("fetch", help="Fetch transcript for a video")
|
||||
p_fetch.add_argument("video_id", help="YouTube video ID")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.command == "fetch":
|
||||
cmd_fetch(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,297 @@
|
||||
---
|
||||
name: siyuan
|
||||
description: SiYuan Note API for searching, reading, creating, and managing blocks and documents in a self-hosted knowledge base via curl.
|
||||
version: 1.0.0
|
||||
author: FEUAZUR
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [SiYuan, Notes, Knowledge Base, PKM, API]
|
||||
related_skills: [obsidian, notion]
|
||||
homepage: https://github.com/siyuan-note/siyuan
|
||||
prerequisites:
|
||||
env_vars: [SIYUAN_TOKEN]
|
||||
commands: [curl, jq]
|
||||
required_environment_variables:
|
||||
- name: SIYUAN_TOKEN
|
||||
prompt: SiYuan API token
|
||||
help: "Settings > About in SiYuan desktop app"
|
||||
- name: SIYUAN_URL
|
||||
prompt: SiYuan instance URL (default http://127.0.0.1:6806)
|
||||
required_for: remote instances
|
||||
---
|
||||
|
||||
# SiYuan Note API
|
||||
|
||||
Use the [SiYuan](https://github.com/siyuan-note/siyuan) kernel API via curl to search, read, create, update, and delete blocks and documents in a self-hosted knowledge base. No extra tools needed -- just curl and an API token.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. Install and run SiYuan (desktop or Docker)
|
||||
2. Get your API token: **Settings > About > API token**
|
||||
3. Store it in `~/.hermes/.env`:
|
||||
```
|
||||
SIYUAN_TOKEN=your_token_here
|
||||
SIYUAN_URL=http://127.0.0.1:6806
|
||||
```
|
||||
`SIYUAN_URL` defaults to `http://127.0.0.1:6806` if not set.
|
||||
|
||||
## API Basics
|
||||
|
||||
All SiYuan API calls are **POST with JSON body**. Every request follows this pattern:
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/..." \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"param": "value"}'
|
||||
```
|
||||
|
||||
Responses are JSON with this structure:
|
||||
```json
|
||||
{"code": 0, "msg": "", "data": { ... }}
|
||||
```
|
||||
`code: 0` means success. Any other value is an error -- check `msg` for details.
|
||||
|
||||
**ID format:** SiYuan IDs look like `20210808180117-6v0mkxr` (14-digit timestamp + 7 alphanumeric chars).
|
||||
|
||||
## Quick Reference
|
||||
|
||||
| Operation | Endpoint |
|
||||
|-----------|----------|
|
||||
| Full-text search | `/api/search/fullTextSearchBlock` |
|
||||
| SQL query | `/api/query/sql` |
|
||||
| Read block | `/api/block/getBlockKramdown` |
|
||||
| Read children | `/api/block/getChildBlocks` |
|
||||
| Get path | `/api/filetree/getHPathByID` |
|
||||
| Get attributes | `/api/attr/getBlockAttrs` |
|
||||
| List notebooks | `/api/notebook/lsNotebooks` |
|
||||
| List documents | `/api/filetree/listDocsByPath` |
|
||||
| Create notebook | `/api/notebook/createNotebook` |
|
||||
| Create document | `/api/filetree/createDocWithMd` |
|
||||
| Append block | `/api/block/appendBlock` |
|
||||
| Update block | `/api/block/updateBlock` |
|
||||
| Rename document | `/api/filetree/renameDocByID` |
|
||||
| Set attributes | `/api/attr/setBlockAttrs` |
|
||||
| Delete block | `/api/block/deleteBlock` |
|
||||
| Delete document | `/api/filetree/removeDocByID` |
|
||||
| Export as Markdown | `/api/export/exportMdContent` |
|
||||
|
||||
## Common Operations
|
||||
|
||||
### Search (Full-Text)
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/search/fullTextSearchBlock" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"query": "meeting notes", "page": 0}' | jq '.data.blocks[:5]'
|
||||
```
|
||||
|
||||
### Search (SQL)
|
||||
|
||||
Query the blocks database directly. Only SELECT statements are safe.
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/query/sql" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"stmt": "SELECT id, content, type, box FROM blocks WHERE content LIKE '\''%keyword%'\'' AND type='\''p'\'' LIMIT 20"}' | jq '.data'
|
||||
```
|
||||
|
||||
Useful columns: `id`, `parent_id`, `root_id`, `box` (notebook ID), `path`, `content`, `type`, `subtype`, `created`, `updated`.
|
||||
|
||||
### Read Block Content
|
||||
|
||||
Returns block content in Kramdown (Markdown-like) format.
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/getBlockKramdown" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"id": "20210808180117-6v0mkxr"}' | jq '.data.kramdown'
|
||||
```
|
||||
|
||||
### Read Child Blocks
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/getChildBlocks" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"id": "20210808180117-6v0mkxr"}' | jq '.data'
|
||||
```
|
||||
|
||||
### Get Human-Readable Path
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/filetree/getHPathByID" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"id": "20210808180117-6v0mkxr"}' | jq '.data'
|
||||
```
|
||||
|
||||
### Get Block Attributes
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/attr/getBlockAttrs" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"id": "20210808180117-6v0mkxr"}' | jq '.data'
|
||||
```
|
||||
|
||||
### List Notebooks
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/notebook/lsNotebooks" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{}' | jq '.data.notebooks[] | {id, name, closed}'
|
||||
```
|
||||
|
||||
### List Documents in a Notebook
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/filetree/listDocsByPath" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"notebook": "NOTEBOOK_ID", "path": "/"}' | jq '.data.files[] | {id, name}'
|
||||
```
|
||||
|
||||
### Create a Document
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/filetree/createDocWithMd" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"notebook": "NOTEBOOK_ID",
|
||||
"path": "/Meeting Notes/2026-03-22",
|
||||
"markdown": "# Meeting Notes\n\n- Discussed project timeline\n- Assigned tasks"
|
||||
}' | jq '.data'
|
||||
```
|
||||
|
||||
### Create a Notebook
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/notebook/createNotebook" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"name": "My New Notebook"}' | jq '.data.notebook.id'
|
||||
```
|
||||
|
||||
### Append Block to Document
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/appendBlock" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"parentID": "DOCUMENT_OR_BLOCK_ID",
|
||||
"data": "New paragraph added at the end.",
|
||||
"dataType": "markdown"
|
||||
}' | jq '.data'
|
||||
```
|
||||
|
||||
Also available: `/api/block/prependBlock` (same params, inserts at the beginning) and `/api/block/insertBlock` (uses `previousID` instead of `parentID` to insert after a specific block).
|
||||
|
||||
### Update Block Content
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/updateBlock" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"id": "BLOCK_ID",
|
||||
"data": "Updated content here.",
|
||||
"dataType": "markdown"
|
||||
}' | jq '.data'
|
||||
```
|
||||
|
||||
### Rename a Document
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/filetree/renameDocByID" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"id": "DOCUMENT_ID", "title": "New Title"}'
|
||||
```
|
||||
|
||||
### Set Block Attributes
|
||||
|
||||
Custom attributes must be prefixed with `custom-`:
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/attr/setBlockAttrs" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"id": "BLOCK_ID",
|
||||
"attrs": {
|
||||
"custom-status": "reviewed",
|
||||
"custom-priority": "high"
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### Delete a Block
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/deleteBlock" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"id": "BLOCK_ID"}'
|
||||
```
|
||||
|
||||
To delete a whole document: use `/api/filetree/removeDocByID` with `{"id": "DOC_ID"}`.
|
||||
To delete a notebook: use `/api/notebook/removeNotebook` with `{"notebook": "NOTEBOOK_ID"}`.
|
||||
|
||||
### Export Document as Markdown
|
||||
|
||||
```bash
|
||||
curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/export/exportMdContent" \
|
||||
-H "Authorization: Token $SIYUAN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"id": "DOCUMENT_ID"}' | jq -r '.data.content'
|
||||
```
|
||||
|
||||
## Block Types
|
||||
|
||||
Common `type` values in SQL queries:
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `d` | Document (root block) |
|
||||
| `p` | Paragraph |
|
||||
| `h` | Heading |
|
||||
| `l` | List |
|
||||
| `i` | List item |
|
||||
| `c` | Code block |
|
||||
| `m` | Math block |
|
||||
| `t` | Table |
|
||||
| `b` | Blockquote |
|
||||
| `s` | Super block |
|
||||
| `html` | HTML block |
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- **All endpoints are POST** -- even read-only operations. Do not use GET.
|
||||
- **SQL safety**: only use SELECT queries. INSERT/UPDATE/DELETE/DROP are dangerous and should never be sent.
|
||||
- **ID validation**: IDs match the pattern `YYYYMMDDHHmmss-xxxxxxx`. Reject anything else.
|
||||
- **Error responses**: always check `code != 0` in responses before processing `data`.
|
||||
- **Large documents**: block content and export results can be very large. Use `LIMIT` in SQL and pipe through `jq` to extract only what you need.
|
||||
- **Notebook IDs**: when working with a specific notebook, get its ID first via `lsNotebooks`.
|
||||
|
||||
## Alternative: MCP Server
|
||||
|
||||
If you prefer a native integration instead of curl, install the SiYuan MCP server:
|
||||
|
||||
```yaml
|
||||
# In ~/.hermes/config.yaml under mcp_servers:
|
||||
mcp_servers:
|
||||
siyuan:
|
||||
command: npx
|
||||
args: ["-y", "@porkll/siyuan-mcp"]
|
||||
env:
|
||||
SIYUAN_TOKEN: "your_token"
|
||||
SIYUAN_URL: "http://127.0.0.1:6806"
|
||||
```
|
||||
@@ -0,0 +1,335 @@
|
||||
---
|
||||
name: scrapling
|
||||
description: Web scraping with Scrapling - HTTP fetching, stealth browser automation, Cloudflare bypass, and spider crawling via CLI and Python.
|
||||
version: 1.0.0
|
||||
author: FEUAZUR
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Web Scraping, Browser, Cloudflare, Stealth, Crawling, Spider]
|
||||
related_skills: [duckduckgo-search, domain-intel]
|
||||
homepage: https://github.com/D4Vinci/Scrapling
|
||||
prerequisites:
|
||||
commands: [scrapling, python]
|
||||
---
|
||||
|
||||
# Scrapling
|
||||
|
||||
[Scrapling](https://github.com/D4Vinci/Scrapling) is a web scraping framework with anti-bot bypass, stealth browser automation, and a spider framework. It provides three fetching strategies (HTTP, dynamic JS, stealth/Cloudflare) and a full CLI.
|
||||
|
||||
**This skill is for educational and research purposes only.** Users must comply with local/international data scraping laws and respect website Terms of Service.
|
||||
|
||||
## When to Use
|
||||
|
||||
- Scraping static HTML pages (faster than browser tools)
|
||||
- Scraping JS-rendered pages that need a real browser
|
||||
- Bypassing Cloudflare Turnstile or bot detection
|
||||
- Crawling multiple pages with a spider
|
||||
- When the built-in `web_extract` tool does not return the data you need
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install "scrapling[all]"
|
||||
scrapling install
|
||||
```
|
||||
|
||||
Minimal install (HTTP only, no browser):
|
||||
```bash
|
||||
pip install scrapling
|
||||
```
|
||||
|
||||
With browser automation only:
|
||||
```bash
|
||||
pip install "scrapling[fetchers]"
|
||||
scrapling install
|
||||
```
|
||||
|
||||
## Quick Reference
|
||||
|
||||
| Approach | Class | Use When |
|
||||
|----------|-------|----------|
|
||||
| HTTP | `Fetcher` / `FetcherSession` | Static pages, APIs, fast bulk requests |
|
||||
| Dynamic | `DynamicFetcher` / `DynamicSession` | JS-rendered content, SPAs |
|
||||
| Stealth | `StealthyFetcher` / `StealthySession` | Cloudflare, anti-bot protected sites |
|
||||
| Spider | `Spider` | Multi-page crawling with link following |
|
||||
|
||||
## CLI Usage
|
||||
|
||||
### Extract Static Page
|
||||
|
||||
```bash
|
||||
scrapling extract get 'https://example.com' output.md
|
||||
```
|
||||
|
||||
With CSS selector and browser impersonation:
|
||||
|
||||
```bash
|
||||
scrapling extract get 'https://example.com' output.md \
|
||||
--css-selector '.content' \
|
||||
--impersonate 'chrome'
|
||||
```
|
||||
|
||||
### Extract JS-Rendered Page
|
||||
|
||||
```bash
|
||||
scrapling extract fetch 'https://example.com' output.md \
|
||||
--css-selector '.dynamic-content' \
|
||||
--disable-resources \
|
||||
--network-idle
|
||||
```
|
||||
|
||||
### Extract Cloudflare-Protected Page
|
||||
|
||||
```bash
|
||||
scrapling extract stealthy-fetch 'https://protected-site.com' output.html \
|
||||
--solve-cloudflare \
|
||||
--block-webrtc \
|
||||
--hide-canvas
|
||||
```
|
||||
|
||||
### POST Request
|
||||
|
||||
```bash
|
||||
scrapling extract post 'https://example.com/api' output.json \
|
||||
--json '{"query": "search term"}'
|
||||
```
|
||||
|
||||
### Output Formats
|
||||
|
||||
The output format is determined by the file extension:
|
||||
- `.html` -- raw HTML
|
||||
- `.md` -- converted to Markdown
|
||||
- `.txt` -- plain text
|
||||
- `.json` / `.jsonl` -- JSON
|
||||
|
||||
## Python: HTTP Scraping
|
||||
|
||||
### Single Request
|
||||
|
||||
```python
|
||||
from scrapling.fetchers import Fetcher
|
||||
|
||||
page = Fetcher.get('https://quotes.toscrape.com/')
|
||||
quotes = page.css('.quote .text::text').getall()
|
||||
for q in quotes:
|
||||
print(q)
|
||||
```
|
||||
|
||||
### Session (Persistent Cookies)
|
||||
|
||||
```python
|
||||
from scrapling.fetchers import FetcherSession
|
||||
|
||||
with FetcherSession(impersonate='chrome') as session:
|
||||
page = session.get('https://example.com/', stealthy_headers=True)
|
||||
links = page.css('a::attr(href)').getall()
|
||||
for link in links[:5]:
|
||||
sub = session.get(link)
|
||||
print(sub.css('h1::text').get())
|
||||
```
|
||||
|
||||
### POST / PUT / DELETE
|
||||
|
||||
```python
|
||||
page = Fetcher.post('https://api.example.com/data', json={"key": "value"})
|
||||
page = Fetcher.put('https://api.example.com/item/1', data={"name": "updated"})
|
||||
page = Fetcher.delete('https://api.example.com/item/1')
|
||||
```
|
||||
|
||||
### With Proxy
|
||||
|
||||
```python
|
||||
page = Fetcher.get('https://example.com', proxy='http://user:pass@proxy:8080')
|
||||
```
|
||||
|
||||
## Python: Dynamic Pages (JS-Rendered)
|
||||
|
||||
For pages that require JavaScript execution (SPAs, lazy-loaded content):
|
||||
|
||||
```python
|
||||
from scrapling.fetchers import DynamicFetcher
|
||||
|
||||
page = DynamicFetcher.fetch('https://example.com', headless=True)
|
||||
data = page.css('.js-loaded-content::text').getall()
|
||||
```
|
||||
|
||||
### Wait for Specific Element
|
||||
|
||||
```python
|
||||
page = DynamicFetcher.fetch(
|
||||
'https://example.com',
|
||||
wait_selector=('.results', 'visible'),
|
||||
network_idle=True,
|
||||
)
|
||||
```
|
||||
|
||||
### Disable Resources for Speed
|
||||
|
||||
Blocks fonts, images, media, stylesheets (~25% faster):
|
||||
|
||||
```python
|
||||
from scrapling.fetchers import DynamicSession
|
||||
|
||||
with DynamicSession(headless=True, disable_resources=True, network_idle=True) as session:
|
||||
page = session.fetch('https://example.com')
|
||||
items = page.css('.item::text').getall()
|
||||
```
|
||||
|
||||
### Custom Page Automation
|
||||
|
||||
```python
|
||||
from playwright.sync_api import Page
|
||||
from scrapling.fetchers import DynamicFetcher
|
||||
|
||||
def scroll_and_click(page: Page):
|
||||
page.mouse.wheel(0, 3000)
|
||||
page.wait_for_timeout(1000)
|
||||
page.click('button.load-more')
|
||||
page.wait_for_selector('.extra-results')
|
||||
|
||||
page = DynamicFetcher.fetch('https://example.com', page_action=scroll_and_click)
|
||||
results = page.css('.extra-results .item::text').getall()
|
||||
```
|
||||
|
||||
## Python: Stealth Mode (Anti-Bot Bypass)
|
||||
|
||||
For Cloudflare-protected or heavily fingerprinted sites:
|
||||
|
||||
```python
|
||||
from scrapling.fetchers import StealthyFetcher
|
||||
|
||||
page = StealthyFetcher.fetch(
|
||||
'https://protected-site.com',
|
||||
headless=True,
|
||||
solve_cloudflare=True,
|
||||
block_webrtc=True,
|
||||
hide_canvas=True,
|
||||
)
|
||||
content = page.css('.protected-content::text').getall()
|
||||
```
|
||||
|
||||
### Stealth Session
|
||||
|
||||
```python
|
||||
from scrapling.fetchers import StealthySession
|
||||
|
||||
with StealthySession(headless=True, solve_cloudflare=True) as session:
|
||||
page1 = session.fetch('https://protected-site.com/page1')
|
||||
page2 = session.fetch('https://protected-site.com/page2')
|
||||
```
|
||||
|
||||
## Element Selection
|
||||
|
||||
All fetchers return a `Selector` object with these methods:
|
||||
|
||||
### CSS Selectors
|
||||
|
||||
```python
|
||||
page.css('h1::text').get() # First h1 text
|
||||
page.css('a::attr(href)').getall() # All link hrefs
|
||||
page.css('.quote .text::text').getall() # Nested selection
|
||||
```
|
||||
|
||||
### XPath
|
||||
|
||||
```python
|
||||
page.xpath('//div[@class="content"]/text()').getall()
|
||||
page.xpath('//a/@href').getall()
|
||||
```
|
||||
|
||||
### Find Methods
|
||||
|
||||
```python
|
||||
page.find_all('div', class_='quote') # By tag + attribute
|
||||
page.find_by_text('Read more', tag='a') # By text content
|
||||
page.find_by_regex(r'\$\d+\.\d{2}') # By regex pattern
|
||||
```
|
||||
|
||||
### Similar Elements
|
||||
|
||||
Find elements with similar structure (useful for product listings, etc.):
|
||||
|
||||
```python
|
||||
first_product = page.css('.product')[0]
|
||||
all_similar = first_product.find_similar()
|
||||
```
|
||||
|
||||
### Navigation
|
||||
|
||||
```python
|
||||
el = page.css('.target')[0]
|
||||
el.parent # Parent element
|
||||
el.children # Child elements
|
||||
el.next_sibling # Next sibling
|
||||
el.prev_sibling # Previous sibling
|
||||
```
|
||||
|
||||
## Python: Spider Framework
|
||||
|
||||
For multi-page crawling with link following:
|
||||
|
||||
```python
|
||||
from scrapling.spiders import Spider, Request, Response
|
||||
|
||||
class QuotesSpider(Spider):
|
||||
name = "quotes"
|
||||
start_urls = ["https://quotes.toscrape.com/"]
|
||||
concurrent_requests = 10
|
||||
download_delay = 1
|
||||
|
||||
async def parse(self, response: Response):
|
||||
for quote in response.css('.quote'):
|
||||
yield {
|
||||
"text": quote.css('.text::text').get(),
|
||||
"author": quote.css('.author::text').get(),
|
||||
"tags": quote.css('.tag::text').getall(),
|
||||
}
|
||||
|
||||
next_page = response.css('.next a::attr(href)').get()
|
||||
if next_page:
|
||||
yield response.follow(next_page)
|
||||
|
||||
result = QuotesSpider().start()
|
||||
print(f"Scraped {len(result.items)} quotes")
|
||||
result.items.to_json("quotes.json")
|
||||
```
|
||||
|
||||
### Multi-Session Spider
|
||||
|
||||
Route requests to different fetcher types:
|
||||
|
||||
```python
|
||||
from scrapling.fetchers import FetcherSession, AsyncStealthySession
|
||||
|
||||
class SmartSpider(Spider):
|
||||
name = "smart"
|
||||
start_urls = ["https://example.com/"]
|
||||
|
||||
def configure_sessions(self, manager):
|
||||
manager.add("fast", FetcherSession(impersonate="chrome"))
|
||||
manager.add("stealth", AsyncStealthySession(headless=True), lazy=True)
|
||||
|
||||
async def parse(self, response: Response):
|
||||
for link in response.css('a::attr(href)').getall():
|
||||
if "protected" in link:
|
||||
yield Request(link, sid="stealth")
|
||||
else:
|
||||
yield Request(link, sid="fast", callback=self.parse)
|
||||
```
|
||||
|
||||
### Pause/Resume Crawling
|
||||
|
||||
```python
|
||||
spider = QuotesSpider(crawldir="./crawl_checkpoint")
|
||||
spider.start() # Ctrl+C to pause, re-run to resume from checkpoint
|
||||
```
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- **Browser install required**: run `scrapling install` after pip install -- without it, `DynamicFetcher` and `StealthyFetcher` will fail
|
||||
- **Timeouts**: DynamicFetcher/StealthyFetcher timeout is in **milliseconds** (default 30000), Fetcher timeout is in **seconds**
|
||||
- **Cloudflare bypass**: `solve_cloudflare=True` adds 5-15 seconds to fetch time -- only enable when needed
|
||||
- **Resource usage**: StealthyFetcher runs a real browser -- limit concurrent usage
|
||||
- **Legal**: always check robots.txt and website ToS before scraping. This library is for educational and research purposes
|
||||
- **Python version**: requires Python 3.10+
|
||||
+142
-31
@@ -896,16 +896,30 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize OpenAI client: {e}")
|
||||
|
||||
# Provider fallback — a single backup model/provider tried when the
|
||||
# primary is exhausted (rate-limit, overload, connection failure).
|
||||
# Config shape: {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}
|
||||
self._fallback_model = fallback_model if isinstance(fallback_model, dict) else None
|
||||
# Provider fallback chain — ordered list of backup providers tried
|
||||
# when the primary is exhausted (rate-limit, overload, connection
|
||||
# failure). Supports both legacy single-dict ``fallback_model`` and
|
||||
# new list ``fallback_providers`` format.
|
||||
if isinstance(fallback_model, list):
|
||||
self._fallback_chain = [
|
||||
f for f in fallback_model
|
||||
if isinstance(f, dict) and f.get("provider") and f.get("model")
|
||||
]
|
||||
elif isinstance(fallback_model, dict) and fallback_model.get("provider") and fallback_model.get("model"):
|
||||
self._fallback_chain = [fallback_model]
|
||||
else:
|
||||
self._fallback_chain = []
|
||||
self._fallback_index = 0
|
||||
self._fallback_activated = False
|
||||
if self._fallback_model:
|
||||
fb_p = self._fallback_model.get("provider", "")
|
||||
fb_m = self._fallback_model.get("model", "")
|
||||
if fb_p and fb_m and not self.quiet_mode:
|
||||
print(f"🔄 Fallback model: {fb_m} ({fb_p})")
|
||||
# Legacy attribute kept for backward compat (tests, external callers)
|
||||
self._fallback_model = self._fallback_chain[0] if self._fallback_chain else None
|
||||
if self._fallback_chain and not self.quiet_mode:
|
||||
if len(self._fallback_chain) == 1:
|
||||
fb = self._fallback_chain[0]
|
||||
print(f"🔄 Fallback model: {fb['model']} ({fb['provider']})")
|
||||
else:
|
||||
print(f"🔄 Fallback chain ({len(self._fallback_chain)} providers): " +
|
||||
" → ".join(f"{f['model']} ({f['provider']})" for f in self._fallback_chain))
|
||||
|
||||
# Get available tools with filtering
|
||||
self.tools = get_tool_definitions(
|
||||
@@ -4318,25 +4332,26 @@ class AIAgent:
|
||||
# ── Provider fallback ──────────────────────────────────────────────────
|
||||
|
||||
def _try_activate_fallback(self) -> bool:
|
||||
"""Switch to the configured fallback model/provider.
|
||||
"""Switch to the next fallback model/provider in the chain.
|
||||
|
||||
Called when the primary model is failing after retries. Swaps the
|
||||
Called when the current model is failing after retries. Swaps the
|
||||
OpenAI client, model slug, and provider in-place so the retry loop
|
||||
can continue with the new backend. One-shot: returns False if
|
||||
already activated or not configured.
|
||||
can continue with the new backend. Advances through the chain on
|
||||
each call; returns False when exhausted.
|
||||
|
||||
Uses the centralized provider router (resolve_provider_client) for
|
||||
auth resolution and client construction — no duplicated provider→key
|
||||
mappings.
|
||||
"""
|
||||
if self._fallback_activated or not self._fallback_model:
|
||||
if self._fallback_index >= len(self._fallback_chain):
|
||||
return False
|
||||
|
||||
fb = self._fallback_model
|
||||
fb = self._fallback_chain[self._fallback_index]
|
||||
self._fallback_index += 1
|
||||
fb_provider = (fb.get("provider") or "").strip().lower()
|
||||
fb_model = (fb.get("model") or "").strip()
|
||||
if not fb_provider or not fb_model:
|
||||
return False
|
||||
return self._try_activate_fallback() # skip invalid, try next
|
||||
|
||||
# Use centralized router for client construction.
|
||||
# raw_codex=True because the main agent needs direct responses.stream()
|
||||
@@ -4349,7 +4364,7 @@ class AIAgent:
|
||||
logging.warning(
|
||||
"Fallback to %s failed: provider not configured",
|
||||
fb_provider)
|
||||
return False
|
||||
return self._try_activate_fallback() # try next in chain
|
||||
|
||||
# Determine api_mode from provider / base URL
|
||||
fb_api_mode = "chat_completions"
|
||||
@@ -4424,8 +4439,8 @@ class AIAgent:
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("Failed to activate fallback model: %s", e)
|
||||
return False
|
||||
logging.error("Failed to activate fallback %s: %s", fb_model, e)
|
||||
return self._try_activate_fallback() # try next in chain
|
||||
|
||||
# ── End provider fallback ──────────────────────────────────────────────
|
||||
|
||||
@@ -4706,9 +4721,10 @@ class AIAgent:
|
||||
api_kwargs = {
|
||||
"model": self.model,
|
||||
"messages": sanitized_messages,
|
||||
"tools": self.tools if self.tools else None,
|
||||
"timeout": float(os.getenv("HERMES_API_TIMEOUT", 1800.0)),
|
||||
}
|
||||
if self.tools:
|
||||
api_kwargs["tools"] = self.tools
|
||||
|
||||
if self.max_tokens is not None:
|
||||
api_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
@@ -6254,6 +6270,7 @@ class AIAgent:
|
||||
codex_ack_continuations = 0
|
||||
length_continue_retries = 0
|
||||
truncated_response_prefix = ""
|
||||
truncated_tool_call_count = 0
|
||||
compression_attempts = 0
|
||||
|
||||
# Clear any stale interrupt state at start
|
||||
@@ -6418,6 +6435,11 @@ class AIAgent:
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
api_kwargs = self._build_api_kwargs(api_messages)
|
||||
# Feature: Temporarily disable tools after repeated truncations
|
||||
if getattr(self, '_tools_temporarily_disabled', False):
|
||||
api_kwargs['tools'] = None
|
||||
self._tools_temporarily_disabled = False
|
||||
self._vprint(f"{self.log_prefix}ℹ️ Tools temporarily disabled for this call")
|
||||
if self.api_mode == "codex_responses":
|
||||
api_kwargs = self._preflight_codex_api_kwargs(api_kwargs, allow_stream=False)
|
||||
|
||||
@@ -6528,9 +6550,9 @@ class AIAgent:
|
||||
# Eager fallback: empty/malformed responses are a common
|
||||
# rate-limit symptom. Switch to fallback immediately
|
||||
# rather than retrying with extended backoff.
|
||||
if not self._fallback_activated:
|
||||
if self._fallback_index < len(self._fallback_chain):
|
||||
self._emit_status("⚠️ Empty/malformed response — switching to fallback...")
|
||||
if not self._fallback_activated and self._try_activate_fallback():
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
|
||||
@@ -6681,6 +6703,46 @@ class AIAgent:
|
||||
|
||||
if self.api_mode == "chat_completions":
|
||||
assistant_message = response.choices[0].message
|
||||
if assistant_message.tool_calls:
|
||||
# Feature: Discard truncated tool calls (Ironclaw #1632)
|
||||
# When finish_reason=length with tool_calls, the calls
|
||||
# are likely truncated (incomplete JSON). Discard them.
|
||||
truncated_tool_call_count += 1
|
||||
tc_count = len(assistant_message.tool_calls)
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Discarding {tc_count} truncated tool call(s) "
|
||||
f"(finish_reason='length', consecutive={truncated_tool_call_count})",
|
||||
force=True,
|
||||
)
|
||||
# Save any text content that preceded the truncated calls
|
||||
partial_content = assistant_message.content or ""
|
||||
if partial_content:
|
||||
truncated_response_prefix += partial_content
|
||||
# Build message WITHOUT tool_calls
|
||||
assistant_message.tool_calls = None
|
||||
interim_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
messages.append(interim_msg)
|
||||
|
||||
truncation_nudge = (
|
||||
'Your previous response was truncated due to context length limits. '
|
||||
'The tool calls were discarded. Please summarize your progress so '
|
||||
'far and continue with a shorter response.'
|
||||
)
|
||||
messages.append({"role": "user", "content": truncation_nudge})
|
||||
|
||||
# After 3 consecutive truncations, temporarily disable tools
|
||||
if truncated_tool_call_count >= 3:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ 3 consecutive truncations with tool calls — "
|
||||
f"temporarily disabling tools for next call",
|
||||
force=True,
|
||||
)
|
||||
self._tools_temporarily_disabled = True
|
||||
|
||||
self._session_messages = messages
|
||||
self._save_session_log(messages)
|
||||
continue
|
||||
|
||||
if not assistant_message.tool_calls:
|
||||
length_continue_retries += 1
|
||||
interim_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
@@ -6924,8 +6986,10 @@ class AIAgent:
|
||||
print(f"{self.log_prefix} Auth method: {auth_method}")
|
||||
print(f"{self.log_prefix} Token prefix: {key[:12]}..." if key and len(key) > 12 else f"{self.log_prefix} Token: (empty or short)")
|
||||
print(f"{self.log_prefix} Troubleshooting:")
|
||||
print(f"{self.log_prefix} • Check ANTHROPIC_TOKEN in ~/.hermes/.env for Hermes-managed OAuth/setup tokens")
|
||||
print(f"{self.log_prefix} • Check ANTHROPIC_API_KEY in ~/.hermes/.env for API keys or legacy token values")
|
||||
from hermes_constants import display_hermes_home as _dhh_fn
|
||||
_dhh = _dhh_fn()
|
||||
print(f"{self.log_prefix} • Check ANTHROPIC_TOKEN in {_dhh}/.env for Hermes-managed OAuth/setup tokens")
|
||||
print(f"{self.log_prefix} • Check ANTHROPIC_API_KEY in {_dhh}/.env for API keys or legacy token values")
|
||||
print(f"{self.log_prefix} • For API keys: verify at https://console.anthropic.com/settings/keys")
|
||||
print(f"{self.log_prefix} • For Claude Code: run 'claude /login' to refresh, then retry")
|
||||
print(f"{self.log_prefix} • Clear stale keys: hermes config set ANTHROPIC_TOKEN \"\"")
|
||||
@@ -6991,7 +7055,7 @@ class AIAgent:
|
||||
or "usage limit" in error_msg
|
||||
or "quota" in error_msg
|
||||
)
|
||||
if is_rate_limited and not self._fallback_activated:
|
||||
if is_rate_limited and self._fallback_index < len(self._fallback_chain):
|
||||
self._emit_status("⚠️ Rate limited — switching to fallback provider...")
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
@@ -7227,7 +7291,10 @@ class AIAgent:
|
||||
retry_count = 0
|
||||
continue
|
||||
_final_summary = self._summarize_api_error(api_error)
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.", force=True)
|
||||
if is_rate_limited:
|
||||
self._vprint(f"{self.log_prefix}❌ Rate limit persisted after {max_retries} retries. Please try again later.", force=True)
|
||||
else:
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.", force=True)
|
||||
self._vprint(f"{self.log_prefix} 💀 Final error: {_final_summary}", force=True)
|
||||
|
||||
# Detect SSE stream-drop pattern (e.g. "Network
|
||||
@@ -7287,8 +7354,22 @@ class AIAgent:
|
||||
"error": _final_summary,
|
||||
}
|
||||
|
||||
wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s
|
||||
self._emit_status(f"⏳ Retrying in {wait_time}s (attempt {retry_count}/{max_retries})...")
|
||||
# For rate limits, respect the Retry-After header if present
|
||||
_retry_after = None
|
||||
if is_rate_limited:
|
||||
_resp_headers = getattr(getattr(api_error, "response", None), "headers", None)
|
||||
if _resp_headers and hasattr(_resp_headers, "get"):
|
||||
_ra_raw = _resp_headers.get("retry-after") or _resp_headers.get("Retry-After")
|
||||
if _ra_raw:
|
||||
try:
|
||||
_retry_after = min(int(_ra_raw), 120) # Cap at 2 minutes
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
wait_time = _retry_after if _retry_after else min(2 ** retry_count, 60)
|
||||
if is_rate_limited:
|
||||
self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...")
|
||||
else:
|
||||
self._emit_status(f"⏳ Retrying in {wait_time}s (attempt {retry_count}/{max_retries})...")
|
||||
logger.warning(
|
||||
"Retrying API call in %ss (attempt %s/%s) %s error=%s",
|
||||
wait_time,
|
||||
@@ -7483,6 +7564,8 @@ class AIAgent:
|
||||
|
||||
# Check for tool calls
|
||||
if assistant_message.tool_calls:
|
||||
# Reset truncated tool call counter on successful (non-truncated) tool calls
|
||||
truncated_tool_call_count = 0
|
||||
if not self.quiet_mode:
|
||||
self._vprint(f"{self.log_prefix}🔧 Processing {len(assistant_message.tool_calls)} tool call(s)...")
|
||||
|
||||
@@ -7758,11 +7841,39 @@ class AIAgent:
|
||||
content_preview = final_response[:80] + "..." if len(final_response) > 80 else final_response
|
||||
self._vprint(f"{self.log_prefix} Content: '{content_preview}'")
|
||||
|
||||
if self._empty_content_retries < 3:
|
||||
self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/3)...")
|
||||
if self._empty_content_retries < 2:
|
||||
self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/2)...")
|
||||
# Feature: Empty response recovery (Ironclaw #1677 + #1720)
|
||||
# On first empty retry, check for prior meaningful output
|
||||
if self._empty_content_retries == 1:
|
||||
_has_prior_output = any(
|
||||
isinstance(m, dict)
|
||||
and m.get("role") == "assistant"
|
||||
and m.get("content")
|
||||
and self._has_content_after_think_block(m["content"])
|
||||
for m in messages
|
||||
)
|
||||
if _has_prior_output:
|
||||
# Model already produced output earlier; treat as completion
|
||||
self._vprint(f"{self.log_prefix}ℹ️ Prior meaningful output exists — treating empty response as completion")
|
||||
for m in reversed(messages):
|
||||
if (isinstance(m, dict) and m.get("role") == "assistant"
|
||||
and m.get("content") and self._has_content_after_think_block(m["content"])):
|
||||
final_response = self._strip_think_blocks(m["content"]).strip()
|
||||
break
|
||||
if final_response:
|
||||
self._empty_content_retries = 0
|
||||
break
|
||||
else:
|
||||
# No prior output — inject a nudge to help the model
|
||||
nudge_msg = {
|
||||
"role": "user",
|
||||
"content": "Your previous response was empty. Please continue with the task.",
|
||||
}
|
||||
messages.append(nudge_msg)
|
||||
continue
|
||||
else:
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries (3) for empty content exceeded.", force=True)
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries (2) for empty content exceeded.", force=True)
|
||||
self._empty_content_retries = 0
|
||||
|
||||
# If a prior tool_calls turn had real content, salvage it:
|
||||
|
||||
@@ -4,6 +4,11 @@ description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration via
|
||||
version: 1.0.0
|
||||
author: Nous Research
|
||||
license: MIT
|
||||
required_credential_files:
|
||||
- path: google_token.json
|
||||
description: Google OAuth2 token (created by setup script)
|
||||
- path: google_client_secret.json
|
||||
description: Google OAuth2 client credentials (downloaded from Google Cloud Console)
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
---
|
||||
name: duckduckgo-search
|
||||
description: Free web search via DuckDuckGo — text, news, images, videos. No API key needed. Use the Python DDGS library or CLI to search, then web_extract for full content.
|
||||
version: 1.2.0
|
||||
description: Free web search via DuckDuckGo — text, news, images, videos. No API key needed. Prefer the `ddgs` CLI when installed; use the Python DDGS library only after verifying that `ddgs` is available in the current runtime.
|
||||
version: 1.3.0
|
||||
author: gamedevCloudy
|
||||
license: MIT
|
||||
metadata:
|
||||
@@ -9,26 +9,96 @@ metadata:
|
||||
tags: [search, duckduckgo, web-search, free, fallback]
|
||||
related_skills: [arxiv]
|
||||
fallback_for_toolsets: [web]
|
||||
prerequisites:
|
||||
commands: [ddgs]
|
||||
---
|
||||
|
||||
# DuckDuckGo Search
|
||||
|
||||
Free web search using DuckDuckGo. **No API key required.**
|
||||
|
||||
Preferred when `web_search` tool is unavailable or unsuitable (no `FIRECRAWL_API_KEY` set). Can also be used as a standalone search tool.
|
||||
Preferred when `web_search` is unavailable or unsuitable (for example when `FIRECRAWL_API_KEY` is not set). Can also be used as a standalone search path when DuckDuckGo results are specifically desired.
|
||||
|
||||
## Setup
|
||||
## Detection Flow
|
||||
|
||||
Check what is actually available before choosing an approach:
|
||||
|
||||
```bash
|
||||
# Install the ddgs package (one-time)
|
||||
pip install ddgs
|
||||
# Check CLI availability
|
||||
command -v ddgs >/dev/null && echo "DDGS_CLI=installed" || echo "DDGS_CLI=missing"
|
||||
```
|
||||
|
||||
## Python API (Primary)
|
||||
Decision tree:
|
||||
1. If `ddgs` CLI is installed, prefer `terminal` + `ddgs`
|
||||
2. If `ddgs` CLI is missing, do not assume `execute_code` can import `ddgs`
|
||||
3. If the user wants DuckDuckGo specifically, install `ddgs` first in the relevant environment
|
||||
4. Otherwise fall back to built-in web/browser tools
|
||||
|
||||
Use the `DDGS` class in `execute_code` for structured results with typed fields.
|
||||
Important runtime note:
|
||||
- Terminal and `execute_code` are separate runtimes
|
||||
- A successful shell install does not guarantee `execute_code` can import `ddgs`
|
||||
- Never assume third-party Python packages are preinstalled inside `execute_code`
|
||||
|
||||
## Installation
|
||||
|
||||
Install `ddgs` only when DuckDuckGo search is specifically needed and the runtime does not already provide it.
|
||||
|
||||
```bash
|
||||
# Python package + CLI entrypoint
|
||||
pip install ddgs
|
||||
|
||||
# Verify CLI
|
||||
ddgs --help
|
||||
```
|
||||
|
||||
If a workflow depends on Python imports, verify that same runtime can import `ddgs` before using `from ddgs import DDGS`.
|
||||
|
||||
## Method 1: CLI Search (Preferred)
|
||||
|
||||
Use the `ddgs` command via `terminal` when it exists. This is the preferred path because it avoids assuming the `execute_code` sandbox has the `ddgs` Python package installed.
|
||||
|
||||
```bash
|
||||
# Text search
|
||||
ddgs text -k "python async programming" -m 5
|
||||
|
||||
# News search
|
||||
ddgs news -k "artificial intelligence" -m 5
|
||||
|
||||
# Image search
|
||||
ddgs images -k "landscape photography" -m 10
|
||||
|
||||
# Video search
|
||||
ddgs videos -k "python tutorial" -m 5
|
||||
|
||||
# With region filter
|
||||
ddgs text -k "best restaurants" -m 5 -r us-en
|
||||
|
||||
# Recent results only (d=day, w=week, m=month, y=year)
|
||||
ddgs text -k "latest AI news" -m 5 -t w
|
||||
|
||||
# JSON output for parsing
|
||||
ddgs text -k "fastapi tutorial" -m 5 -o json
|
||||
```
|
||||
|
||||
### CLI Flags
|
||||
|
||||
| Flag | Description | Example |
|
||||
|------|-------------|---------|
|
||||
| `-k` | Keywords (query) — **required** | `-k "search terms"` |
|
||||
| `-m` | Max results | `-m 5` |
|
||||
| `-r` | Region | `-r us-en` |
|
||||
| `-t` | Time limit | `-t w` (week) |
|
||||
| `-s` | Safe search | `-s off` |
|
||||
| `-o` | Output format | `-o json` |
|
||||
|
||||
## Method 2: Python API (Only After Verification)
|
||||
|
||||
Use the `DDGS` class in `execute_code` or another Python runtime only after verifying that `ddgs` is installed there. Do not assume `execute_code` includes third-party packages by default.
|
||||
|
||||
Safe wording:
|
||||
- "Use `execute_code` with `ddgs` after installing or verifying the package if needed"
|
||||
|
||||
Avoid saying:
|
||||
- "`execute_code` includes `ddgs`"
|
||||
- "DuckDuckGo search works by default in `execute_code`"
|
||||
|
||||
**Important:** `max_results` must always be passed as a **keyword argument** — positional usage raises an error on all methods.
|
||||
|
||||
@@ -76,7 +146,7 @@ from ddgs import DDGS
|
||||
with DDGS() as ddgs:
|
||||
for r in ddgs.images("semiconductor chip", max_results=5):
|
||||
print(r["title"])
|
||||
print(r["image"]) # direct image URL
|
||||
print(r["image"])
|
||||
print(r.get("thumbnail", ""))
|
||||
print(r.get("source", ""))
|
||||
print()
|
||||
@@ -94,9 +164,9 @@ from ddgs import DDGS
|
||||
with DDGS() as ddgs:
|
||||
for r in ddgs.videos("FastAPI tutorial", max_results=5):
|
||||
print(r["title"])
|
||||
print(r.get("content", "")) # video URL
|
||||
print(r.get("duration", "")) # e.g. "26:03"
|
||||
print(r.get("provider", "")) # YouTube, etc.
|
||||
print(r.get("content", ""))
|
||||
print(r.get("duration", ""))
|
||||
print(r.get("provider", ""))
|
||||
print(r.get("published", ""))
|
||||
print()
|
||||
```
|
||||
@@ -112,50 +182,17 @@ Returns: `title`, `content`, `description`, `duration`, `provider`, `published`,
|
||||
| `images()` | Visuals, diagrams | title, image, thumbnail, url |
|
||||
| `videos()` | Tutorials, demos | title, content, duration, provider |
|
||||
|
||||
## CLI (Alternative)
|
||||
|
||||
Use the `ddgs` command via terminal when you don't need structured field access.
|
||||
|
||||
```bash
|
||||
# Text search
|
||||
ddgs text -k "python async programming" -m 5
|
||||
|
||||
# News search
|
||||
ddgs news -k "artificial intelligence" -m 5
|
||||
|
||||
# Image search
|
||||
ddgs images -k "landscape photography" -m 10
|
||||
|
||||
# Video search
|
||||
ddgs videos -k "python tutorial" -m 5
|
||||
|
||||
# With region filter
|
||||
ddgs text -k "best restaurants" -m 5 -r us-en
|
||||
|
||||
# Recent results only (d=day, w=week, m=month, y=year)
|
||||
ddgs text -k "latest AI news" -m 5 -t w
|
||||
|
||||
# JSON output for parsing
|
||||
ddgs text -k "fastapi tutorial" -m 5 -o json
|
||||
```
|
||||
|
||||
### CLI Flags
|
||||
|
||||
| Flag | Description | Example |
|
||||
|------|-------------|---------|
|
||||
| `-k` | Keywords (query) — **required** | `-k "search terms"` |
|
||||
| `-m` | Max results | `-m 5` |
|
||||
| `-r` | Region | `-r us-en` |
|
||||
| `-t` | Time limit | `-t w` (week) |
|
||||
| `-s` | Safe search | `-s off` |
|
||||
| `-o` | Output format | `-o json` |
|
||||
|
||||
## Workflow: Search then Extract
|
||||
|
||||
DuckDuckGo returns titles, URLs, and snippets — not full page content. To get full content, follow up with `web_extract`:
|
||||
DuckDuckGo returns titles, URLs, and snippets — not full page content. To get full page content, search first and then extract the most relevant URL with `web_extract`, browser tools, or curl.
|
||||
|
||||
1. **Search** with ddgs to find relevant URLs
|
||||
2. **Extract** content using the `web_extract` tool (if available) or curl
|
||||
CLI example:
|
||||
|
||||
```bash
|
||||
ddgs text -k "fastapi deployment guide" -m 3 -o json
|
||||
```
|
||||
|
||||
Python example, only after verifying `ddgs` is installed in that runtime:
|
||||
|
||||
```python
|
||||
from ddgs import DDGS
|
||||
@@ -164,25 +201,37 @@ with DDGS() as ddgs:
|
||||
results = list(ddgs.text("fastapi deployment guide", max_results=3))
|
||||
for r in results:
|
||||
print(r["title"], "->", r["href"])
|
||||
|
||||
# Then use web_extract tool on the best URL
|
||||
```
|
||||
|
||||
Then extract the best URL with `web_extract` or another content-retrieval tool.
|
||||
|
||||
## Limitations
|
||||
|
||||
- **Rate limiting**: DuckDuckGo may throttle after many rapid requests. Add a short delay between searches if needed.
|
||||
- **No content extraction**: ddgs returns snippets, not full page content. Use `web_extract` or curl for that.
|
||||
- **No content extraction**: `ddgs` returns snippets, not full page content. Use `web_extract`, browser tools, or curl for the full article/page.
|
||||
- **Results quality**: Generally good but less configurable than Firecrawl's search.
|
||||
- **Availability**: DuckDuckGo may block requests from some cloud IPs. If searches return empty, try different keywords or wait a few seconds.
|
||||
- **Field variability**: Return fields may vary between results or ddgs versions. Use `.get()` for optional fields to avoid KeyError.
|
||||
- **Field variability**: Return fields may vary between results or `ddgs` versions. Use `.get()` for optional fields to avoid `KeyError`.
|
||||
- **Separate runtimes**: A successful `ddgs` install in terminal does not automatically mean `execute_code` can import it.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Likely Cause | What To Do |
|
||||
|---------|--------------|------------|
|
||||
| `ddgs: command not found` | CLI not installed in the shell environment | Install `ddgs`, or use built-in web/browser tools instead |
|
||||
| `ModuleNotFoundError: No module named 'ddgs'` | Python runtime does not have the package installed | Do not use Python DDGS there until that runtime is prepared |
|
||||
| Search returns nothing | Temporary rate limiting or poor query | Wait a few seconds, retry, or adjust the query |
|
||||
| CLI works but `execute_code` import fails | Terminal and `execute_code` are different runtimes | Keep using CLI, or separately prepare the Python runtime |
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- **`max_results` is keyword-only**: `ddgs.text("query", 5)` raises an error. Use `ddgs.text("query", max_results=5)`.
|
||||
- **Do not assume the CLI exists**: Check `command -v ddgs` before using it.
|
||||
- **Do not assume `execute_code` can import `ddgs`**: `from ddgs import DDGS` may fail with `ModuleNotFoundError` unless that runtime was prepared separately.
|
||||
- **Package name**: The package is `ddgs` (previously `duckduckgo-search`). Install with `pip install ddgs`.
|
||||
- **Don't confuse `-k` and `-m`** (CLI): `-k` is for keywords, `-m` is for max results count.
|
||||
- **Package name**: The package is `ddgs` (was previously `duckduckgo-search`). Install with `pip install ddgs`.
|
||||
- **Empty results**: If ddgs returns nothing, it may be rate-limited. Wait a few seconds and retry.
|
||||
- **Empty results**: If `ddgs` returns nothing, it may be rate-limited. Wait a few seconds and retry.
|
||||
|
||||
## Validated With
|
||||
|
||||
Smoke-tested with `ddgs==9.11.2` on Python 3.13. All four methods (text, news, images, videos) confirmed working with keyword `max_results`.
|
||||
Validated examples against `ddgs==9.11.2` semantics. Skill guidance now treats CLI availability and Python import availability as separate concerns so the documented workflow matches actual runtime behavior.
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Tests for acp_adapter.entry startup wiring."""
|
||||
|
||||
import acp
|
||||
|
||||
from acp_adapter import entry
|
||||
|
||||
|
||||
def test_main_enables_unstable_protocol(monkeypatch):
|
||||
calls = {}
|
||||
|
||||
async def fake_run_agent(agent, **kwargs):
|
||||
calls["kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(entry, "_setup_logging", lambda: None)
|
||||
monkeypatch.setattr(entry, "_load_env", lambda: None)
|
||||
monkeypatch.setattr(acp, "run_agent", fake_run_agent)
|
||||
|
||||
entry.main()
|
||||
|
||||
assert calls["kwargs"]["use_unstable_protocol"] is True
|
||||
@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.agent.router import build_agent_router
|
||||
from acp.schema import (
|
||||
AgentCapabilities,
|
||||
AuthenticateResponse,
|
||||
@@ -18,6 +19,8 @@ from acp.schema import (
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SetSessionConfigOptionResponse,
|
||||
SetSessionModeResponse,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
Usage,
|
||||
@@ -168,6 +171,74 @@ class TestListAndFork:
|
||||
assert fork_resp.session_id != new_resp.session_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session configuration / model routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionConfiguration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_session_mode_returns_response(self, agent):
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
resp = await agent.set_session_mode(mode_id="chat", session_id=new_resp.session_id)
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, SetSessionModeResponse)
|
||||
assert getattr(state, "mode", None) == "chat"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_config_option_returns_response(self, agent):
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
resp = await agent.set_config_option(
|
||||
config_id="approval_mode",
|
||||
session_id=new_resp.session_id,
|
||||
value="auto",
|
||||
)
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, SetSessionConfigOptionResponse)
|
||||
assert getattr(state, "config_options", {}) == {"approval_mode": "auto"}
|
||||
assert resp.config_options == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_accepts_stable_session_config_methods(self, agent):
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
router = build_agent_router(agent)
|
||||
|
||||
mode_result = await router(
|
||||
"session/set_mode",
|
||||
{"modeId": "chat", "sessionId": new_resp.session_id},
|
||||
False,
|
||||
)
|
||||
config_result = await router(
|
||||
"session/set_config_option",
|
||||
{
|
||||
"configId": "approval_mode",
|
||||
"sessionId": new_resp.session_id,
|
||||
"value": "auto",
|
||||
},
|
||||
False,
|
||||
)
|
||||
|
||||
assert mode_result == {}
|
||||
assert config_result == {"configOptions": []}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_accepts_unstable_model_switch_when_enabled(self, agent):
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
router = build_agent_router(agent, use_unstable_protocol=True)
|
||||
|
||||
result = await router(
|
||||
"session/set_model",
|
||||
{"modelId": "gpt-5.4", "sessionId": new_resp.session_id},
|
||||
False,
|
||||
)
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
assert result == {}
|
||||
assert state.model == "gpt-5.4"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Tests for external skill directories (skills.external_dirs config)."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def external_skills_dir(tmp_path):
|
||||
"""Create a temp dir with a sample external skill."""
|
||||
ext_dir = tmp_path / "external-skills"
|
||||
skill_dir = ext_dir / "my-external-skill"
|
||||
skill_dir.mkdir(parents=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: my-external-skill\ndescription: A skill from an external directory\n---\n\n# My External Skill\n\nDo external things.\n"
|
||||
)
|
||||
return ext_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hermes_home(tmp_path):
|
||||
"""Create a minimal HERMES_HOME with config."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
(home / "skills").mkdir()
|
||||
return home
|
||||
|
||||
|
||||
class TestGetExternalSkillsDirs:
|
||||
def test_empty_config(self, hermes_home):
|
||||
(hermes_home / "config.yaml").write_text("skills:\n external_dirs: []\n")
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
result = get_external_skills_dirs()
|
||||
assert result == []
|
||||
|
||||
def test_nonexistent_dir_skipped(self, hermes_home):
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
"skills:\n external_dirs:\n - /nonexistent/path\n"
|
||||
)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
result = get_external_skills_dirs()
|
||||
assert result == []
|
||||
|
||||
def test_valid_dir_returned(self, hermes_home, external_skills_dir):
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
f"skills:\n external_dirs:\n - {external_skills_dir}\n"
|
||||
)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
result = get_external_skills_dirs()
|
||||
assert len(result) == 1
|
||||
assert result[0] == external_skills_dir.resolve()
|
||||
|
||||
def test_duplicate_dirs_deduplicated(self, hermes_home, external_skills_dir):
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
f"skills:\n external_dirs:\n - {external_skills_dir}\n - {external_skills_dir}\n"
|
||||
)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
result = get_external_skills_dirs()
|
||||
assert len(result) == 1
|
||||
|
||||
def test_local_skills_dir_excluded(self, hermes_home):
|
||||
local_skills = hermes_home / "skills"
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
f"skills:\n external_dirs:\n - {local_skills}\n"
|
||||
)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
result = get_external_skills_dirs()
|
||||
assert result == []
|
||||
|
||||
def test_no_config_file(self, hermes_home):
|
||||
# No config.yaml at all
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
result = get_external_skills_dirs()
|
||||
assert result == []
|
||||
|
||||
def test_string_value_converted_to_list(self, hermes_home, external_skills_dir):
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
f"skills:\n external_dirs: {external_skills_dir}\n"
|
||||
)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
result = get_external_skills_dirs()
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestGetAllSkillsDirs:
|
||||
def test_local_always_first(self, hermes_home, external_skills_dir):
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
f"skills:\n external_dirs:\n - {external_skills_dir}\n"
|
||||
)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
from agent.skill_utils import get_all_skills_dirs
|
||||
result = get_all_skills_dirs()
|
||||
assert result[0] == hermes_home / "skills"
|
||||
assert result[1] == external_skills_dir.resolve()
|
||||
|
||||
|
||||
class TestExternalSkillsInFindAll:
|
||||
def test_external_skills_found(self, hermes_home, external_skills_dir):
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
f"skills:\n external_dirs:\n - {external_skills_dir}\n"
|
||||
)
|
||||
local_skills = hermes_home / "skills"
|
||||
with (
|
||||
patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}),
|
||||
patch("tools.skills_tool.SKILLS_DIR", local_skills),
|
||||
):
|
||||
from tools.skills_tool import _find_all_skills
|
||||
skills = _find_all_skills()
|
||||
names = [s["name"] for s in skills]
|
||||
assert "my-external-skill" in names
|
||||
|
||||
def test_local_takes_precedence(self, hermes_home, external_skills_dir):
|
||||
"""If the same skill name exists locally and externally, local wins."""
|
||||
local_skills = hermes_home / "skills"
|
||||
local_skill = local_skills / "my-external-skill"
|
||||
local_skill.mkdir(parents=True)
|
||||
(local_skill / "SKILL.md").write_text(
|
||||
"---\nname: my-external-skill\ndescription: Local version\n---\n\nLocal.\n"
|
||||
)
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
f"skills:\n external_dirs:\n - {external_skills_dir}\n"
|
||||
)
|
||||
with (
|
||||
patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}),
|
||||
patch("tools.skills_tool.SKILLS_DIR", local_skills),
|
||||
):
|
||||
from tools.skills_tool import _find_all_skills
|
||||
skills = _find_all_skills()
|
||||
matching = [s for s in skills if s["name"] == "my-external-skill"]
|
||||
assert len(matching) == 1
|
||||
assert matching[0]["description"] == "Local version"
|
||||
|
||||
|
||||
class TestExternalSkillView:
|
||||
def test_skill_view_finds_external(self, hermes_home, external_skills_dir):
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
f"skills:\n external_dirs:\n - {external_skills_dir}\n"
|
||||
)
|
||||
local_skills = hermes_home / "skills"
|
||||
with (
|
||||
patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}),
|
||||
patch("tools.skills_tool.SKILLS_DIR", local_skills),
|
||||
):
|
||||
from tools.skills_tool import skill_view
|
||||
result = json.loads(skill_view("my-external-skill"))
|
||||
assert result["success"] is True
|
||||
assert "external things" in result["content"]
|
||||
@@ -5,6 +5,8 @@ import importlib
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.prompt_builder import (
|
||||
_scan_context_content,
|
||||
_truncate_content,
|
||||
@@ -194,7 +196,7 @@ class TestParseSkillFile:
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
mock_sys.platform = "linux"
|
||||
is_compat, _, _ = _parse_skill_file(skill_file)
|
||||
assert is_compat is False
|
||||
@@ -234,9 +236,6 @@ class TestPromptBuilderImports:
|
||||
# =========================================================================
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestBuildSkillsSystemPrompt:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_skills_cache(self):
|
||||
@@ -296,7 +295,7 @@ class TestBuildSkillsSystemPrompt:
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
mock_sys.platform = "linux"
|
||||
result = build_skills_system_prompt()
|
||||
|
||||
@@ -574,6 +573,10 @@ class TestBuildContextFilesPrompt:
|
||||
result = build_context_files_prompt(cwd=str(tmp_path))
|
||||
assert "Lowercase claude rules" in result
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "darwin",
|
||||
reason="APFS default volume is case-insensitive; CLAUDE.md and claude.md alias the same path",
|
||||
)
|
||||
def test_claude_md_uppercase_takes_priority(self, tmp_path):
|
||||
(tmp_path / "CLAUDE.md").write_text("From uppercase.")
|
||||
(tmp_path / "claude.md").write_text("From lowercase.")
|
||||
|
||||
@@ -246,20 +246,10 @@ Generate some audio.
|
||||
def test_preserves_remaining_remote_setup_warning(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_ENV", "ssh")
|
||||
monkeypatch.delenv("TENOR_API_KEY", raising=False)
|
||||
|
||||
def fake_secret_callback(var_name, prompt, metadata=None):
|
||||
os.environ[var_name] = "stored-in-test"
|
||||
return {
|
||||
"success": True,
|
||||
"stored_as": var_name,
|
||||
"validated": False,
|
||||
"skipped": False,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
skills_tool_module,
|
||||
"_secret_capture_callback",
|
||||
fake_secret_callback,
|
||||
None,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -167,6 +167,32 @@ class TestDeliverResultWrapping:
|
||||
sent_content = send_mock.call_args.kwargs.get("content") or send_mock.call_args[0][-1]
|
||||
assert "Cronjob Response: abc-123" in sent_content
|
||||
|
||||
def test_delivery_skips_wrapping_when_config_disabled(self):
|
||||
"""When cron.wrap_response is false, deliver raw content without header/footer."""
|
||||
from gateway.config import Platform
|
||||
|
||||
pconfig = MagicMock()
|
||||
pconfig.enabled = True
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.platforms = {Platform.TELEGRAM: pconfig}
|
||||
|
||||
with patch("gateway.config.load_gateway_config", return_value=mock_cfg), \
|
||||
patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})) as send_mock, \
|
||||
patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}):
|
||||
job = {
|
||||
"id": "test-job",
|
||||
"name": "daily-report",
|
||||
"deliver": "origin",
|
||||
"origin": {"platform": "telegram", "chat_id": "123"},
|
||||
}
|
||||
_deliver_result(job, "Clean output only.")
|
||||
|
||||
send_mock.assert_called_once()
|
||||
sent_content = send_mock.call_args.kwargs.get("content") or send_mock.call_args[0][-1]
|
||||
assert sent_content == "Clean output only."
|
||||
assert "Cronjob Response" not in sent_content
|
||||
assert "The agent cannot see" not in sent_content
|
||||
|
||||
def test_no_mirror_to_session_call(self):
|
||||
"""Cron deliveries should NOT mirror into the gateway session."""
|
||||
from gateway.config import Platform
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""Tests for gateway configuration management."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from gateway.config import (
|
||||
GatewayConfig,
|
||||
HomeChannel,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
SessionResetPolicy,
|
||||
_apply_env_overrides,
|
||||
load_gateway_config,
|
||||
)
|
||||
|
||||
@@ -192,3 +196,75 @@ class TestLoadGatewayConfig:
|
||||
|
||||
assert config.unauthorized_dm_behavior == "ignore"
|
||||
assert config.platforms[Platform.WHATSAPP].extra["unauthorized_dm_behavior"] == "pair"
|
||||
|
||||
|
||||
class TestHomeChannelEnvOverrides:
|
||||
"""Home channel env vars should apply even when the platform was already
|
||||
configured via config.yaml (not just when credential env vars create it)."""
|
||||
|
||||
def test_existing_platform_configs_accept_home_channel_env_overrides(self):
|
||||
cases = [
|
||||
(
|
||||
Platform.SLACK,
|
||||
PlatformConfig(enabled=True, token="xoxb-from-config"),
|
||||
{"SLACK_HOME_CHANNEL": "C123", "SLACK_HOME_CHANNEL_NAME": "Ops"},
|
||||
("C123", "Ops"),
|
||||
),
|
||||
(
|
||||
Platform.SIGNAL,
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"http_url": "http://localhost:9090", "account": "+15551234567"},
|
||||
),
|
||||
{"SIGNAL_HOME_CHANNEL": "+1555000", "SIGNAL_HOME_CHANNEL_NAME": "Phone"},
|
||||
("+1555000", "Phone"),
|
||||
),
|
||||
(
|
||||
Platform.MATTERMOST,
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
token="mm-token",
|
||||
extra={"url": "https://mm.example.com"},
|
||||
),
|
||||
{"MATTERMOST_HOME_CHANNEL": "ch_abc123", "MATTERMOST_HOME_CHANNEL_NAME": "General"},
|
||||
("ch_abc123", "General"),
|
||||
),
|
||||
(
|
||||
Platform.MATRIX,
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
token="syt_abc123",
|
||||
extra={"homeserver": "https://matrix.example.org"},
|
||||
),
|
||||
{"MATRIX_HOME_ROOM": "!room123:example.org", "MATRIX_HOME_ROOM_NAME": "Bot Room"},
|
||||
("!room123:example.org", "Bot Room"),
|
||||
),
|
||||
(
|
||||
Platform.EMAIL,
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
extra={
|
||||
"address": "hermes@test.com",
|
||||
"imap_host": "imap.test.com",
|
||||
"smtp_host": "smtp.test.com",
|
||||
},
|
||||
),
|
||||
{"EMAIL_HOME_ADDRESS": "user@test.com", "EMAIL_HOME_ADDRESS_NAME": "Inbox"},
|
||||
("user@test.com", "Inbox"),
|
||||
),
|
||||
(
|
||||
Platform.SMS,
|
||||
PlatformConfig(enabled=True, api_key="token_abc"),
|
||||
{"SMS_HOME_CHANNEL": "+15559876543", "SMS_HOME_CHANNEL_NAME": "My Phone"},
|
||||
("+15559876543", "My Phone"),
|
||||
),
|
||||
]
|
||||
|
||||
for platform, platform_config, env, expected in cases:
|
||||
config = GatewayConfig(platforms={platform: platform_config})
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
_apply_env_overrides(config)
|
||||
|
||||
home = config.platforms[platform].home_channel
|
||||
assert home is not None, f"{platform.value}: home_channel should not be None"
|
||||
assert (home.chat_id, home.name) == expected, platform.value
|
||||
|
||||
@@ -1057,5 +1057,122 @@ class TestSendEmailStandalone(unittest.TestCase):
|
||||
self.assertIn("not configured", result["error"])
|
||||
|
||||
|
||||
class TestSmtpConnectionCleanup(unittest.TestCase):
|
||||
"""Verify SMTP connections are closed even when send_message raises."""
|
||||
|
||||
@patch.dict(os.environ, {
|
||||
"EMAIL_ADDRESS": "hermes@test.com",
|
||||
"EMAIL_PASSWORD": "secret",
|
||||
"EMAIL_IMAP_HOST": "imap.test.com",
|
||||
"EMAIL_SMTP_HOST": "smtp.test.com",
|
||||
"EMAIL_SMTP_PORT": "587",
|
||||
}, clear=False)
|
||||
def _make_adapter(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.email import EmailAdapter
|
||||
return EmailAdapter(PlatformConfig(enabled=True))
|
||||
|
||||
@patch.dict(os.environ, {
|
||||
"EMAIL_ADDRESS": "hermes@test.com",
|
||||
"EMAIL_PASSWORD": "secret",
|
||||
"EMAIL_IMAP_HOST": "imap.test.com",
|
||||
"EMAIL_SMTP_HOST": "smtp.test.com",
|
||||
"EMAIL_SMTP_PORT": "587",
|
||||
}, clear=False)
|
||||
def test_smtp_quit_called_on_send_message_failure(self):
|
||||
"""SMTP quit() must be called even when send_message() raises."""
|
||||
adapter = self._make_adapter()
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.send_message.side_effect = Exception("send failed")
|
||||
|
||||
with patch("smtplib.SMTP", return_value=mock_smtp):
|
||||
with self.assertRaises(Exception):
|
||||
adapter._send_email("user@test.com", "Hello")
|
||||
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
@patch.dict(os.environ, {
|
||||
"EMAIL_ADDRESS": "hermes@test.com",
|
||||
"EMAIL_PASSWORD": "secret",
|
||||
"EMAIL_IMAP_HOST": "imap.test.com",
|
||||
"EMAIL_SMTP_HOST": "smtp.test.com",
|
||||
"EMAIL_SMTP_PORT": "587",
|
||||
}, clear=False)
|
||||
def test_smtp_close_called_when_quit_also_fails(self):
|
||||
"""If both send_message() and quit() fail, close() is the fallback."""
|
||||
adapter = self._make_adapter()
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.send_message.side_effect = Exception("send failed")
|
||||
mock_smtp.quit.side_effect = Exception("quit failed")
|
||||
|
||||
with patch("smtplib.SMTP", return_value=mock_smtp):
|
||||
with self.assertRaises(Exception):
|
||||
adapter._send_email("user@test.com", "Hello")
|
||||
|
||||
mock_smtp.close.assert_called_once()
|
||||
|
||||
|
||||
class TestImapConnectionCleanup(unittest.TestCase):
|
||||
"""Verify IMAP connections are closed even when fetch raises."""
|
||||
|
||||
@patch.dict(os.environ, {
|
||||
"EMAIL_ADDRESS": "hermes@test.com",
|
||||
"EMAIL_PASSWORD": "secret",
|
||||
"EMAIL_IMAP_HOST": "imap.test.com",
|
||||
"EMAIL_IMAP_PORT": "993",
|
||||
"EMAIL_SMTP_HOST": "smtp.test.com",
|
||||
}, clear=False)
|
||||
def _make_adapter(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.email import EmailAdapter
|
||||
return EmailAdapter(PlatformConfig(enabled=True))
|
||||
|
||||
@patch.dict(os.environ, {
|
||||
"EMAIL_ADDRESS": "hermes@test.com",
|
||||
"EMAIL_PASSWORD": "secret",
|
||||
"EMAIL_IMAP_HOST": "imap.test.com",
|
||||
"EMAIL_IMAP_PORT": "993",
|
||||
"EMAIL_SMTP_HOST": "smtp.test.com",
|
||||
}, clear=False)
|
||||
def test_imap_logout_called_on_uid_fetch_failure(self):
|
||||
"""IMAP logout() must be called even when uid fetch raises."""
|
||||
adapter = self._make_adapter()
|
||||
mock_imap = MagicMock()
|
||||
|
||||
def uid_handler(command, *args):
|
||||
if command == "search":
|
||||
return ("OK", [b"1"])
|
||||
if command == "fetch":
|
||||
raise Exception("fetch failed")
|
||||
return ("NO", [])
|
||||
|
||||
mock_imap.uid.side_effect = uid_handler
|
||||
|
||||
with patch("imaplib.IMAP4_SSL", return_value=mock_imap):
|
||||
results = adapter._fetch_new_messages()
|
||||
|
||||
self.assertEqual(results, [])
|
||||
mock_imap.logout.assert_called_once()
|
||||
|
||||
@patch.dict(os.environ, {
|
||||
"EMAIL_ADDRESS": "hermes@test.com",
|
||||
"EMAIL_PASSWORD": "secret",
|
||||
"EMAIL_IMAP_HOST": "imap.test.com",
|
||||
"EMAIL_IMAP_PORT": "993",
|
||||
"EMAIL_SMTP_HOST": "smtp.test.com",
|
||||
}, clear=False)
|
||||
def test_imap_logout_called_on_early_return(self):
|
||||
"""IMAP logout() must be called even when returning early (no unseen)."""
|
||||
adapter = self._make_adapter()
|
||||
mock_imap = MagicMock()
|
||||
mock_imap.uid.return_value = ("OK", [b""])
|
||||
|
||||
with patch("imaplib.IMAP4_SSL", return_value=mock_imap):
|
||||
results = adapter._fetch_new_messages()
|
||||
|
||||
self.assertEqual(results, [])
|
||||
mock_imap.logout.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for Mattermost platform adapter."""
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
@@ -269,6 +270,7 @@ class TestMattermostWebSocketParsing:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
self.adapter._bot_user_id = "bot_user_id"
|
||||
self.adapter._bot_username = "hermes-bot"
|
||||
# Mock handle_message to capture the MessageEvent without processing
|
||||
self.adapter.handle_message = AsyncMock()
|
||||
|
||||
@@ -293,7 +295,8 @@ class TestMattermostWebSocketParsing:
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert self.adapter.handle_message.called
|
||||
msg_event = self.adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.text == "@bot_user_id Hello from Matrix!"
|
||||
# @mention is stripped from the message text
|
||||
assert msg_event.text == "Hello from Matrix!"
|
||||
assert msg_event.message_id == "post_abc"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -410,6 +413,87 @@ class TestMattermostWebSocketParsing:
|
||||
assert not self.adapter.handle_message.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mention behavior (require_mention + free_response_channels)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMattermostMentionBehavior:
|
||||
def setup_method(self):
|
||||
self.adapter = _make_adapter()
|
||||
self.adapter._bot_user_id = "bot_user_id"
|
||||
self.adapter._bot_username = "hermes-bot"
|
||||
self.adapter.handle_message = AsyncMock()
|
||||
|
||||
def _make_event(self, message, channel_type="O", channel_id="chan_456"):
|
||||
post_data = {
|
||||
"id": "post_mention",
|
||||
"user_id": "user_123",
|
||||
"channel_id": channel_id,
|
||||
"message": message,
|
||||
}
|
||||
return {
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": json.dumps(post_data),
|
||||
"channel_type": channel_type,
|
||||
"sender_name": "@alice",
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_mention_true_skips_without_mention(self):
|
||||
"""Default: messages without @mention in channels are skipped."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("MATTERMOST_REQUIRE_MENTION", None)
|
||||
os.environ.pop("MATTERMOST_FREE_RESPONSE_CHANNELS", None)
|
||||
await self.adapter._handle_ws_event(self._make_event("hello"))
|
||||
assert not self.adapter.handle_message.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_mention_false_responds_to_all(self):
|
||||
"""MATTERMOST_REQUIRE_MENTION=false: respond to all channel messages."""
|
||||
with patch.dict(os.environ, {"MATTERMOST_REQUIRE_MENTION": "false"}):
|
||||
await self.adapter._handle_ws_event(self._make_event("hello"))
|
||||
assert self.adapter.handle_message.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_response_channel_responds_without_mention(self):
|
||||
"""Messages in free-response channels don't need @mention."""
|
||||
with patch.dict(os.environ, {"MATTERMOST_FREE_RESPONSE_CHANNELS": "chan_456,chan_789"}):
|
||||
os.environ.pop("MATTERMOST_REQUIRE_MENTION", None)
|
||||
await self.adapter._handle_ws_event(self._make_event("hello", channel_id="chan_456"))
|
||||
assert self.adapter.handle_message.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_free_channel_still_requires_mention(self):
|
||||
"""Channels NOT in free-response list still require @mention."""
|
||||
with patch.dict(os.environ, {"MATTERMOST_FREE_RESPONSE_CHANNELS": "chan_789"}):
|
||||
os.environ.pop("MATTERMOST_REQUIRE_MENTION", None)
|
||||
await self.adapter._handle_ws_event(self._make_event("hello", channel_id="chan_456"))
|
||||
assert not self.adapter.handle_message.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_always_responds(self):
|
||||
"""DMs (channel_type=D) always respond regardless of mention settings."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("MATTERMOST_REQUIRE_MENTION", None)
|
||||
await self.adapter._handle_ws_event(self._make_event("hello", channel_type="D"))
|
||||
assert self.adapter.handle_message.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mention_stripped_from_text(self):
|
||||
"""@mention is stripped from message text."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("MATTERMOST_REQUIRE_MENTION", None)
|
||||
await self.adapter._handle_ws_event(
|
||||
self._make_event("@hermes-bot what is 2+2")
|
||||
)
|
||||
assert self.adapter.handle_message.called
|
||||
msg = self.adapter.handle_message.call_args[0][0]
|
||||
assert "@hermes-bot" not in msg.text
|
||||
assert "2+2" in msg.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File upload (send_image)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
+102
-30
@@ -1,11 +1,42 @@
|
||||
"""Tests for Signal messenger platform adapter."""
|
||||
import base64
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from urllib.parse import quote
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_signal_adapter(monkeypatch, account="+15551234567", **extra):
|
||||
"""Create a SignalAdapter with sensible test defaults."""
|
||||
monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", extra.pop("group_allowed", ""))
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
config = PlatformConfig()
|
||||
config.enabled = True
|
||||
config.extra = {
|
||||
"http_url": "http://localhost:8080",
|
||||
"account": account,
|
||||
**extra,
|
||||
}
|
||||
return SignalAdapter(config)
|
||||
|
||||
|
||||
def _stub_rpc(return_value):
|
||||
"""Return an async mock for SignalAdapter._rpc that captures call params."""
|
||||
captured = []
|
||||
|
||||
async def mock_rpc(method, params, rpc_id=None):
|
||||
captured.append({"method": method, "params": dict(params)})
|
||||
return return_value
|
||||
|
||||
return mock_rpc, captured
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform & Config
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -61,48 +92,22 @@ class TestSignalConfigLoading:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalAdapterInit:
|
||||
def _make_config(self, **extra):
|
||||
config = PlatformConfig()
|
||||
config.enabled = True
|
||||
config.extra = {
|
||||
"http_url": "http://localhost:8080",
|
||||
"account": "+15551234567",
|
||||
**extra,
|
||||
}
|
||||
return config
|
||||
|
||||
def test_init_parses_config(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "group123,group456")
|
||||
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
adapter = SignalAdapter(self._make_config())
|
||||
|
||||
adapter = _make_signal_adapter(monkeypatch, group_allowed="group123,group456")
|
||||
assert adapter.http_url == "http://localhost:8080"
|
||||
assert adapter.account == "+15551234567"
|
||||
assert "group123" in adapter.group_allow_from
|
||||
|
||||
def test_init_empty_allowlist(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
|
||||
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
adapter = SignalAdapter(self._make_config())
|
||||
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
assert len(adapter.group_allow_from) == 0
|
||||
|
||||
def test_init_strips_trailing_slash(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
|
||||
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
adapter = SignalAdapter(self._make_config(http_url="http://localhost:8080/"))
|
||||
|
||||
adapter = _make_signal_adapter(monkeypatch, http_url="http://localhost:8080/")
|
||||
assert adapter.http_url == "http://localhost:8080"
|
||||
|
||||
def test_self_message_filtering(self, monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "")
|
||||
|
||||
from gateway.platforms.signal import SignalAdapter
|
||||
adapter = SignalAdapter(self._make_config())
|
||||
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
assert adapter._account_normalized == "+15551234567"
|
||||
|
||||
|
||||
@@ -189,6 +194,73 @@ class TestSignalHelpers:
|
||||
assert check_signal_requirements() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE URL Encoding (Bug Fix: phone numbers with + must be URL-encoded)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalSSEUrlEncoding:
|
||||
"""Verify that phone numbers with + are URL-encoded in the SSE endpoint."""
|
||||
|
||||
def test_sse_url_encodes_plus_in_account(self):
|
||||
"""The + in E.164 phone numbers must be percent-encoded in the SSE query string."""
|
||||
encoded = quote("+31612345678", safe="")
|
||||
assert encoded == "%2B31612345678"
|
||||
|
||||
def test_sse_url_encoding_preserves_digits(self):
|
||||
"""Digits and country codes should pass through URL encoding unchanged."""
|
||||
assert quote("+15551234567", safe="") == "%2B15551234567"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attachment Fetch (Bug Fix: parameter must be "id" not "attachmentId")
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalAttachmentFetch:
|
||||
"""Verify that _fetch_attachment uses the correct RPC parameter name."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_attachment_uses_id_parameter(self, monkeypatch):
|
||||
"""RPC getAttachment must use 'id', not 'attachmentId' (signal-cli requirement)."""
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
|
||||
png_data = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
|
||||
b64_data = base64.b64encode(png_data).decode()
|
||||
|
||||
adapter._rpc, captured = _stub_rpc({"data": b64_data})
|
||||
|
||||
with patch("gateway.platforms.signal.cache_image_from_bytes", return_value="/tmp/test.png"):
|
||||
await adapter._fetch_attachment("attachment-123")
|
||||
|
||||
call = captured[0]
|
||||
assert call["method"] == "getAttachment"
|
||||
assert call["params"]["id"] == "attachment-123"
|
||||
assert "attachmentId" not in call["params"], "Must NOT use 'attachmentId' — causes NullPointerException in signal-cli"
|
||||
assert call["params"]["account"] == "+15551234567"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_attachment_returns_none_on_empty(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
adapter._rpc, _ = _stub_rpc(None)
|
||||
path, ext = await adapter._fetch_attachment("missing-id")
|
||||
assert path is None
|
||||
assert ext == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_attachment_handles_dict_response(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
|
||||
pdf_data = b"%PDF-1.4" + b"\x00" * 100
|
||||
b64_data = base64.b64encode(pdf_data).decode()
|
||||
|
||||
adapter._rpc, _ = _stub_rpc({"data": b64_data})
|
||||
|
||||
with patch("gateway.platforms.signal.cache_document_from_bytes", return_value="/tmp/test.pdf"):
|
||||
path, ext = await adapter._fetch_attachment("doc-456")
|
||||
|
||||
assert path == "/tmp/test.pdf"
|
||||
assert ext == ".pdf"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session Source
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -63,6 +63,7 @@ def _make_adapter():
|
||||
adapter._background_tasks = set()
|
||||
adapter._auto_tts_disabled_chats = set()
|
||||
adapter._message_queue = asyncio.Queue()
|
||||
adapter._http_session = None
|
||||
return adapter
|
||||
|
||||
|
||||
@@ -219,6 +220,7 @@ class TestBridgeRuntimeFailure:
|
||||
fatal_handler = AsyncMock()
|
||||
adapter.set_fatal_error_handler(fatal_handler)
|
||||
adapter._running = True
|
||||
adapter._http_session = MagicMock() # Persistent session active
|
||||
mock_fh = MagicMock()
|
||||
adapter._bridge_log_fh = mock_fh
|
||||
|
||||
@@ -242,6 +244,7 @@ class TestBridgeRuntimeFailure:
|
||||
fatal_handler = AsyncMock()
|
||||
adapter.set_fatal_error_handler(fatal_handler)
|
||||
adapter._running = True
|
||||
adapter._http_session = MagicMock() # Persistent session active
|
||||
mock_fh = MagicMock()
|
||||
adapter._bridge_log_fh = mock_fh
|
||||
|
||||
@@ -417,3 +420,83 @@ class TestKillPortProcess:
|
||||
with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \
|
||||
patch("gateway.platforms.whatsapp.subprocess.run", side_effect=OSError("no netstat")):
|
||||
_kill_port_process(3000) # must not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistent HTTP session lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHttpSessionLifecycle:
|
||||
"""Verify persistent aiohttp.ClientSession is created and cleaned up."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_closed_on_disconnect(self):
|
||||
"""disconnect() should close self._http_session."""
|
||||
adapter = _make_adapter()
|
||||
mock_session = AsyncMock()
|
||||
mock_session.closed = False
|
||||
adapter._http_session = mock_session
|
||||
adapter._poll_task = None
|
||||
adapter._bridge_process = None
|
||||
adapter._running = True
|
||||
adapter._session_lock_identity = None
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
mock_session.close.assert_called_once()
|
||||
assert adapter._http_session is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_not_closed_when_already_closed(self):
|
||||
"""disconnect() should skip close() when session is already closed."""
|
||||
adapter = _make_adapter()
|
||||
mock_session = AsyncMock()
|
||||
mock_session.closed = True
|
||||
adapter._http_session = mock_session
|
||||
adapter._poll_task = None
|
||||
adapter._bridge_process = None
|
||||
adapter._running = True
|
||||
adapter._session_lock_identity = None
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
mock_session.close.assert_not_called()
|
||||
assert adapter._http_session is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_task_cancelled_on_disconnect(self):
|
||||
"""disconnect() should cancel the poll task."""
|
||||
adapter = _make_adapter()
|
||||
mock_task = MagicMock()
|
||||
mock_task.done.return_value = False
|
||||
mock_task.cancel = MagicMock()
|
||||
mock_future = asyncio.Future()
|
||||
mock_future.set_exception(asyncio.CancelledError())
|
||||
mock_task.__await__ = mock_future.__await__
|
||||
adapter._poll_task = mock_task
|
||||
adapter._http_session = None
|
||||
adapter._bridge_process = None
|
||||
adapter._running = True
|
||||
adapter._session_lock_identity = None
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
mock_task.cancel.assert_called_once()
|
||||
assert adapter._poll_task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_skips_done_poll_task(self):
|
||||
"""disconnect() should not cancel an already-done poll task."""
|
||||
adapter = _make_adapter()
|
||||
mock_task = MagicMock()
|
||||
mock_task.done.return_value = True
|
||||
adapter._poll_task = mock_task
|
||||
adapter._http_session = None
|
||||
adapter._bridge_process = None
|
||||
adapter._running = True
|
||||
adapter._session_lock_identity = None
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
mock_task.cancel.assert_not_called()
|
||||
assert adapter._poll_task is None
|
||||
|
||||
@@ -0,0 +1,622 @@
|
||||
"""Comprehensive tests for hermes_cli.profiles module.
|
||||
|
||||
Tests cover: validation, directory resolution, CRUD operations, active profile
|
||||
management, export/import, renaming, alias collision checks, profile isolation,
|
||||
and shell completion generation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.profiles import (
|
||||
validate_profile_name,
|
||||
get_profile_dir,
|
||||
create_profile,
|
||||
delete_profile,
|
||||
list_profiles,
|
||||
set_active_profile,
|
||||
get_active_profile,
|
||||
get_active_profile_name,
|
||||
resolve_profile_env,
|
||||
check_alias_collision,
|
||||
rename_profile,
|
||||
export_profile,
|
||||
import_profile,
|
||||
generate_bash_completion,
|
||||
generate_zsh_completion,
|
||||
_get_profiles_root,
|
||||
_get_default_hermes_home,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixture: redirect Path.home() and HERMES_HOME for profile tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
|
||||
def profile_env(tmp_path, monkeypatch):
|
||||
"""Set up an isolated environment for profile tests.
|
||||
|
||||
* Path.home() -> tmp_path (so _get_profiles_root() = tmp_path/.hermes/profiles)
|
||||
* HERMES_HOME -> tmp_path/.hermes (so get_hermes_home() agrees)
|
||||
* Creates the bare-minimum ~/.hermes directory.
|
||||
"""
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
default_home = tmp_path / ".hermes"
|
||||
default_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(default_home))
|
||||
return tmp_path
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestValidateProfileName
|
||||
# ===================================================================
|
||||
|
||||
class TestValidateProfileName:
|
||||
"""Tests for validate_profile_name()."""
|
||||
|
||||
@pytest.mark.parametrize("name", ["coder", "work-bot", "a1", "my_agent"])
|
||||
def test_valid_names_accepted(self, name):
|
||||
# Should not raise
|
||||
validate_profile_name(name)
|
||||
|
||||
@pytest.mark.parametrize("name", ["UPPER", "has space", ".hidden", "-leading"])
|
||||
def test_invalid_names_rejected(self, name):
|
||||
with pytest.raises(ValueError):
|
||||
validate_profile_name(name)
|
||||
|
||||
def test_too_long_rejected(self):
|
||||
long_name = "a" * 65
|
||||
with pytest.raises(ValueError):
|
||||
validate_profile_name(long_name)
|
||||
|
||||
def test_max_length_accepted(self):
|
||||
# 64 chars total: 1 leading + 63 remaining = 64, within [0,63] range
|
||||
name = "a" * 64
|
||||
validate_profile_name(name)
|
||||
|
||||
def test_default_accepted(self):
|
||||
# 'default' is a special-case pass-through
|
||||
validate_profile_name("default")
|
||||
|
||||
def test_empty_string_rejected(self):
|
||||
with pytest.raises(ValueError):
|
||||
validate_profile_name("")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestGetProfileDir
|
||||
# ===================================================================
|
||||
|
||||
class TestGetProfileDir:
|
||||
"""Tests for get_profile_dir()."""
|
||||
|
||||
def test_default_returns_hermes_home(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
result = get_profile_dir("default")
|
||||
assert result == tmp_path / ".hermes"
|
||||
|
||||
def test_named_profile_returns_profiles_subdir(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
result = get_profile_dir("coder")
|
||||
assert result == tmp_path / ".hermes" / "profiles" / "coder"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestCreateProfile
|
||||
# ===================================================================
|
||||
|
||||
class TestCreateProfile:
|
||||
"""Tests for create_profile()."""
|
||||
|
||||
def test_creates_directory_with_subdirs(self, profile_env):
|
||||
profile_dir = create_profile("coder", no_alias=True)
|
||||
assert profile_dir.is_dir()
|
||||
for subdir in ["memories", "sessions", "skills", "skins", "logs",
|
||||
"plans", "workspace", "cron"]:
|
||||
assert (profile_dir / subdir).is_dir(), f"Missing subdir: {subdir}"
|
||||
|
||||
def test_duplicate_raises_file_exists(self, profile_env):
|
||||
create_profile("coder", no_alias=True)
|
||||
with pytest.raises(FileExistsError):
|
||||
create_profile("coder", no_alias=True)
|
||||
|
||||
def test_default_raises_value_error(self, profile_env):
|
||||
with pytest.raises(ValueError, match="default"):
|
||||
create_profile("default", no_alias=True)
|
||||
|
||||
def test_invalid_name_raises_value_error(self, profile_env):
|
||||
with pytest.raises(ValueError):
|
||||
create_profile("INVALID!", no_alias=True)
|
||||
|
||||
def test_clone_config_copies_files(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
default_home = tmp_path / ".hermes"
|
||||
# Create source config files in default profile
|
||||
(default_home / "config.yaml").write_text("model: test")
|
||||
(default_home / ".env").write_text("KEY=val")
|
||||
(default_home / "SOUL.md").write_text("Be helpful.")
|
||||
|
||||
profile_dir = create_profile("coder", clone_config=True, no_alias=True)
|
||||
|
||||
assert (profile_dir / "config.yaml").read_text() == "model: test"
|
||||
assert (profile_dir / ".env").read_text() == "KEY=val"
|
||||
assert (profile_dir / "SOUL.md").read_text() == "Be helpful."
|
||||
|
||||
def test_clone_all_copies_entire_tree(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
default_home = tmp_path / ".hermes"
|
||||
# Populate default with some content
|
||||
(default_home / "memories").mkdir(exist_ok=True)
|
||||
(default_home / "memories" / "note.md").write_text("remember this")
|
||||
(default_home / "config.yaml").write_text("model: gpt-4")
|
||||
# Runtime files that should be stripped
|
||||
(default_home / "gateway.pid").write_text("12345")
|
||||
(default_home / "gateway_state.json").write_text("{}")
|
||||
(default_home / "processes.json").write_text("[]")
|
||||
|
||||
profile_dir = create_profile("coder", clone_all=True, no_alias=True)
|
||||
|
||||
# Content should be copied
|
||||
assert (profile_dir / "memories" / "note.md").read_text() == "remember this"
|
||||
assert (profile_dir / "config.yaml").read_text() == "model: gpt-4"
|
||||
# Runtime files should be stripped
|
||||
assert not (profile_dir / "gateway.pid").exists()
|
||||
assert not (profile_dir / "gateway_state.json").exists()
|
||||
assert not (profile_dir / "processes.json").exists()
|
||||
|
||||
def test_clone_config_missing_files_skipped(self, profile_env):
|
||||
"""Clone config gracefully skips files that don't exist in source."""
|
||||
profile_dir = create_profile("coder", clone_config=True, no_alias=True)
|
||||
# No error; optional files just not copied
|
||||
assert not (profile_dir / "config.yaml").exists()
|
||||
assert not (profile_dir / ".env").exists()
|
||||
assert not (profile_dir / "SOUL.md").exists()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestDeleteProfile
|
||||
# ===================================================================
|
||||
|
||||
class TestDeleteProfile:
|
||||
"""Tests for delete_profile()."""
|
||||
|
||||
def test_removes_directory(self, profile_env):
|
||||
profile_dir = create_profile("coder", no_alias=True)
|
||||
assert profile_dir.is_dir()
|
||||
# Mock gateway import to avoid real systemd/launchd interaction
|
||||
with patch("hermes_cli.profiles._cleanup_gateway_service"):
|
||||
delete_profile("coder", yes=True)
|
||||
assert not profile_dir.is_dir()
|
||||
|
||||
def test_default_raises_value_error(self, profile_env):
|
||||
with pytest.raises(ValueError, match="default"):
|
||||
delete_profile("default", yes=True)
|
||||
|
||||
def test_nonexistent_raises_file_not_found(self, profile_env):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
delete_profile("nonexistent", yes=True)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestListProfiles
|
||||
# ===================================================================
|
||||
|
||||
class TestListProfiles:
|
||||
"""Tests for list_profiles()."""
|
||||
|
||||
def test_returns_default_when_no_named_profiles(self, profile_env):
|
||||
profiles = list_profiles()
|
||||
names = [p.name for p in profiles]
|
||||
assert "default" in names
|
||||
|
||||
def test_includes_named_profiles(self, profile_env):
|
||||
create_profile("alpha", no_alias=True)
|
||||
create_profile("beta", no_alias=True)
|
||||
profiles = list_profiles()
|
||||
names = [p.name for p in profiles]
|
||||
assert "alpha" in names
|
||||
assert "beta" in names
|
||||
|
||||
def test_sorted_alphabetically(self, profile_env):
|
||||
create_profile("zebra", no_alias=True)
|
||||
create_profile("alpha", no_alias=True)
|
||||
create_profile("middle", no_alias=True)
|
||||
profiles = list_profiles()
|
||||
named = [p.name for p in profiles if not p.is_default]
|
||||
assert named == sorted(named)
|
||||
|
||||
def test_default_is_first(self, profile_env):
|
||||
create_profile("alpha", no_alias=True)
|
||||
profiles = list_profiles()
|
||||
assert profiles[0].name == "default"
|
||||
assert profiles[0].is_default is True
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestActiveProfile
|
||||
# ===================================================================
|
||||
|
||||
class TestActiveProfile:
|
||||
"""Tests for set_active_profile() / get_active_profile()."""
|
||||
|
||||
def test_set_and_get_roundtrip(self, profile_env):
|
||||
create_profile("coder", no_alias=True)
|
||||
set_active_profile("coder")
|
||||
assert get_active_profile() == "coder"
|
||||
|
||||
def test_no_file_returns_default(self, profile_env):
|
||||
assert get_active_profile() == "default"
|
||||
|
||||
def test_empty_file_returns_default(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
active_path = tmp_path / ".hermes" / "active_profile"
|
||||
active_path.write_text("")
|
||||
assert get_active_profile() == "default"
|
||||
|
||||
def test_set_to_default_removes_file(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
create_profile("coder", no_alias=True)
|
||||
set_active_profile("coder")
|
||||
active_path = tmp_path / ".hermes" / "active_profile"
|
||||
assert active_path.exists()
|
||||
|
||||
set_active_profile("default")
|
||||
assert not active_path.exists()
|
||||
|
||||
def test_set_nonexistent_raises(self, profile_env):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
set_active_profile("nonexistent")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestGetActiveProfileName
|
||||
# ===================================================================
|
||||
|
||||
class TestGetActiveProfileName:
|
||||
"""Tests for get_active_profile_name()."""
|
||||
|
||||
def test_default_hermes_home_returns_default(self, profile_env):
|
||||
# HERMES_HOME points to tmp_path/.hermes which is the default
|
||||
assert get_active_profile_name() == "default"
|
||||
|
||||
def test_profile_path_returns_profile_name(self, profile_env, monkeypatch):
|
||||
tmp_path = profile_env
|
||||
create_profile("coder", no_alias=True)
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "coder"
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_dir))
|
||||
assert get_active_profile_name() == "coder"
|
||||
|
||||
def test_custom_path_returns_custom(self, profile_env, monkeypatch):
|
||||
tmp_path = profile_env
|
||||
custom = tmp_path / "some" / "other" / "path"
|
||||
custom.mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(custom))
|
||||
assert get_active_profile_name() == "custom"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestResolveProfileEnv
|
||||
# ===================================================================
|
||||
|
||||
class TestResolveProfileEnv:
|
||||
"""Tests for resolve_profile_env()."""
|
||||
|
||||
def test_existing_profile_returns_path(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
create_profile("coder", no_alias=True)
|
||||
result = resolve_profile_env("coder")
|
||||
assert result == str(tmp_path / ".hermes" / "profiles" / "coder")
|
||||
|
||||
def test_default_returns_default_home(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
result = resolve_profile_env("default")
|
||||
assert result == str(tmp_path / ".hermes")
|
||||
|
||||
def test_nonexistent_raises_file_not_found(self, profile_env):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
resolve_profile_env("nonexistent")
|
||||
|
||||
def test_invalid_name_raises_value_error(self, profile_env):
|
||||
with pytest.raises(ValueError):
|
||||
resolve_profile_env("INVALID!")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestAliasCollision
|
||||
# ===================================================================
|
||||
|
||||
class TestAliasCollision:
|
||||
"""Tests for check_alias_collision()."""
|
||||
|
||||
def test_normal_name_returns_none(self, profile_env):
|
||||
# Mock 'which' to return not-found
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=1, stdout="")
|
||||
result = check_alias_collision("mybot")
|
||||
assert result is None
|
||||
|
||||
def test_reserved_name_returns_message(self, profile_env):
|
||||
result = check_alias_collision("hermes")
|
||||
assert result is not None
|
||||
assert "reserved" in result.lower()
|
||||
|
||||
def test_subcommand_returns_message(self, profile_env):
|
||||
result = check_alias_collision("chat")
|
||||
assert result is not None
|
||||
assert "subcommand" in result.lower()
|
||||
|
||||
def test_default_is_reserved(self, profile_env):
|
||||
result = check_alias_collision("default")
|
||||
assert result is not None
|
||||
assert "reserved" in result.lower()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestRenameProfile
|
||||
# ===================================================================
|
||||
|
||||
class TestRenameProfile:
|
||||
"""Tests for rename_profile()."""
|
||||
|
||||
def test_renames_directory(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
create_profile("oldname", no_alias=True)
|
||||
old_dir = tmp_path / ".hermes" / "profiles" / "oldname"
|
||||
assert old_dir.is_dir()
|
||||
|
||||
# Mock alias collision to avoid subprocess calls
|
||||
with patch("hermes_cli.profiles.check_alias_collision", return_value="skip"):
|
||||
new_dir = rename_profile("oldname", "newname")
|
||||
|
||||
assert not old_dir.is_dir()
|
||||
assert new_dir.is_dir()
|
||||
assert new_dir == tmp_path / ".hermes" / "profiles" / "newname"
|
||||
|
||||
def test_default_raises_value_error(self, profile_env):
|
||||
with pytest.raises(ValueError, match="default"):
|
||||
rename_profile("default", "newname")
|
||||
|
||||
def test_rename_to_default_raises_value_error(self, profile_env):
|
||||
create_profile("coder", no_alias=True)
|
||||
with pytest.raises(ValueError, match="default"):
|
||||
rename_profile("coder", "default")
|
||||
|
||||
def test_nonexistent_raises_file_not_found(self, profile_env):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
rename_profile("nonexistent", "newname")
|
||||
|
||||
def test_target_exists_raises_file_exists(self, profile_env):
|
||||
create_profile("alpha", no_alias=True)
|
||||
create_profile("beta", no_alias=True)
|
||||
with pytest.raises(FileExistsError):
|
||||
rename_profile("alpha", "beta")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestExportImport
|
||||
# ===================================================================
|
||||
|
||||
class TestExportImport:
|
||||
"""Tests for export_profile() / import_profile()."""
|
||||
|
||||
def test_export_creates_tar_gz(self, profile_env, tmp_path):
|
||||
create_profile("coder", no_alias=True)
|
||||
# Put a marker file so we can verify content
|
||||
profile_dir = get_profile_dir("coder")
|
||||
(profile_dir / "marker.txt").write_text("hello")
|
||||
|
||||
output = tmp_path / "export" / "coder.tar.gz"
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
result = export_profile("coder", str(output))
|
||||
|
||||
assert Path(result).exists()
|
||||
assert tarfile.is_tarfile(str(result))
|
||||
|
||||
def test_import_restores_from_archive(self, profile_env, tmp_path):
|
||||
# Create and export a profile
|
||||
create_profile("coder", no_alias=True)
|
||||
profile_dir = get_profile_dir("coder")
|
||||
(profile_dir / "marker.txt").write_text("hello")
|
||||
|
||||
archive_path = tmp_path / "export" / "coder.tar.gz"
|
||||
archive_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
export_profile("coder", str(archive_path))
|
||||
|
||||
# Delete the profile, then import it back under a new name
|
||||
import shutil
|
||||
shutil.rmtree(profile_dir)
|
||||
assert not profile_dir.is_dir()
|
||||
|
||||
imported = import_profile(str(archive_path), name="coder")
|
||||
assert imported.is_dir()
|
||||
assert (imported / "marker.txt").read_text() == "hello"
|
||||
|
||||
def test_import_to_existing_name_raises(self, profile_env, tmp_path):
|
||||
create_profile("coder", no_alias=True)
|
||||
profile_dir = get_profile_dir("coder")
|
||||
|
||||
archive_path = tmp_path / "export" / "coder.tar.gz"
|
||||
archive_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
export_profile("coder", str(archive_path))
|
||||
|
||||
# Importing to same existing name should fail
|
||||
with pytest.raises(FileExistsError):
|
||||
import_profile(str(archive_path), name="coder")
|
||||
|
||||
def test_export_nonexistent_raises(self, profile_env, tmp_path):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
export_profile("nonexistent", str(tmp_path / "out.tar.gz"))
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestProfileIsolation
|
||||
# ===================================================================
|
||||
|
||||
class TestProfileIsolation:
|
||||
"""Verify that two profiles have completely separate paths."""
|
||||
|
||||
def test_separate_config_paths(self, profile_env):
|
||||
create_profile("alpha", no_alias=True)
|
||||
create_profile("beta", no_alias=True)
|
||||
alpha_dir = get_profile_dir("alpha")
|
||||
beta_dir = get_profile_dir("beta")
|
||||
assert alpha_dir / "config.yaml" != beta_dir / "config.yaml"
|
||||
assert str(alpha_dir) not in str(beta_dir)
|
||||
|
||||
def test_separate_state_db_paths(self, profile_env):
|
||||
alpha_dir = get_profile_dir("alpha")
|
||||
beta_dir = get_profile_dir("beta")
|
||||
assert alpha_dir / "state.db" != beta_dir / "state.db"
|
||||
|
||||
def test_separate_skills_paths(self, profile_env):
|
||||
create_profile("alpha", no_alias=True)
|
||||
create_profile("beta", no_alias=True)
|
||||
alpha_dir = get_profile_dir("alpha")
|
||||
beta_dir = get_profile_dir("beta")
|
||||
assert alpha_dir / "skills" != beta_dir / "skills"
|
||||
# Verify both exist and are independent dirs
|
||||
assert (alpha_dir / "skills").is_dir()
|
||||
assert (beta_dir / "skills").is_dir()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestCompletion
|
||||
# ===================================================================
|
||||
|
||||
class TestCompletion:
|
||||
"""Tests for bash/zsh completion generators."""
|
||||
|
||||
def test_bash_completion_contains_complete(self):
|
||||
script = generate_bash_completion()
|
||||
assert len(script) > 0
|
||||
assert "complete" in script
|
||||
|
||||
def test_zsh_completion_contains_compdef(self):
|
||||
script = generate_zsh_completion()
|
||||
assert len(script) > 0
|
||||
assert "compdef" in script
|
||||
|
||||
def test_bash_completion_has_hermes_profiles_function(self):
|
||||
script = generate_bash_completion()
|
||||
assert "_hermes_profiles" in script
|
||||
|
||||
def test_zsh_completion_has_hermes_function(self):
|
||||
script = generate_zsh_completion()
|
||||
assert "_hermes" in script
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestGetProfilesRoot / TestGetDefaultHermesHome (internal helpers)
|
||||
# ===================================================================
|
||||
|
||||
class TestInternalHelpers:
|
||||
"""Tests for _get_profiles_root() and _get_default_hermes_home()."""
|
||||
|
||||
def test_profiles_root_under_home(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
root = _get_profiles_root()
|
||||
assert root == tmp_path / ".hermes" / "profiles"
|
||||
|
||||
def test_default_hermes_home(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
home = _get_default_hermes_home()
|
||||
assert home == tmp_path / ".hermes"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Edge cases and additional coverage
|
||||
# ===================================================================
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Additional edge-case tests."""
|
||||
|
||||
def test_create_profile_returns_correct_path(self, profile_env):
|
||||
tmp_path = profile_env
|
||||
result = create_profile("mybot", no_alias=True)
|
||||
expected = tmp_path / ".hermes" / "profiles" / "mybot"
|
||||
assert result == expected
|
||||
|
||||
def test_list_profiles_default_info_fields(self, profile_env):
|
||||
profiles = list_profiles()
|
||||
default = [p for p in profiles if p.name == "default"][0]
|
||||
assert default.is_default is True
|
||||
assert default.gateway_running is False
|
||||
assert default.skill_count == 0
|
||||
|
||||
def test_gateway_running_check_with_pid_file(self, profile_env):
|
||||
"""Verify _check_gateway_running reads pid file and probes os.kill."""
|
||||
from hermes_cli.profiles import _check_gateway_running
|
||||
tmp_path = profile_env
|
||||
default_home = tmp_path / ".hermes"
|
||||
|
||||
# No pid file -> not running
|
||||
assert _check_gateway_running(default_home) is False
|
||||
|
||||
# Write a PID file with a JSON payload
|
||||
pid_file = default_home / "gateway.pid"
|
||||
pid_file.write_text(json.dumps({"pid": 99999}))
|
||||
|
||||
# os.kill(99999, 0) should raise ProcessLookupError -> not running
|
||||
assert _check_gateway_running(default_home) is False
|
||||
|
||||
# Mock os.kill to simulate a running process
|
||||
with patch("os.kill", return_value=None):
|
||||
assert _check_gateway_running(default_home) is True
|
||||
|
||||
def test_gateway_running_check_plain_pid(self, profile_env):
|
||||
"""Pid file containing just a number (legacy format)."""
|
||||
from hermes_cli.profiles import _check_gateway_running
|
||||
tmp_path = profile_env
|
||||
default_home = tmp_path / ".hermes"
|
||||
pid_file = default_home / "gateway.pid"
|
||||
pid_file.write_text("99999")
|
||||
|
||||
with patch("os.kill", return_value=None):
|
||||
assert _check_gateway_running(default_home) is True
|
||||
|
||||
def test_profile_name_boundary_single_char(self):
|
||||
"""Single alphanumeric character is valid."""
|
||||
validate_profile_name("a")
|
||||
validate_profile_name("1")
|
||||
|
||||
def test_profile_name_boundary_all_hyphens(self):
|
||||
"""Name starting with hyphen is invalid."""
|
||||
with pytest.raises(ValueError):
|
||||
validate_profile_name("-abc")
|
||||
|
||||
def test_profile_name_underscore_start(self):
|
||||
"""Name starting with underscore is invalid (must start with [a-z0-9])."""
|
||||
with pytest.raises(ValueError):
|
||||
validate_profile_name("_abc")
|
||||
|
||||
def test_clone_from_named_profile(self, profile_env):
|
||||
"""Clone config from a named (non-default) profile."""
|
||||
tmp_path = profile_env
|
||||
# Create source profile with config
|
||||
source_dir = create_profile("source", no_alias=True)
|
||||
(source_dir / "config.yaml").write_text("model: cloned")
|
||||
(source_dir / ".env").write_text("SECRET=yes")
|
||||
|
||||
target_dir = create_profile(
|
||||
"target", clone_from="source", clone_config=True, no_alias=True,
|
||||
)
|
||||
assert (target_dir / "config.yaml").read_text() == "model: cloned"
|
||||
assert (target_dir / ".env").read_text() == "SECRET=yes"
|
||||
|
||||
def test_delete_clears_active_profile(self, profile_env):
|
||||
"""Deleting the active profile resets active to default."""
|
||||
tmp_path = profile_env
|
||||
create_profile("coder", no_alias=True)
|
||||
set_active_profile("coder")
|
||||
assert get_active_profile() == "coder"
|
||||
|
||||
with patch("hermes_cli.profiles._cleanup_gateway_service"):
|
||||
delete_profile("coder", yes=True)
|
||||
|
||||
assert get_active_profile() == "default"
|
||||
@@ -0,0 +1,271 @@
|
||||
"""Tests for tool token estimation and curses_ui status_fn support."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ─── Token Estimation Tests ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_estimate_tool_tokens_returns_positive_counts():
|
||||
"""_estimate_tool_tokens should return a non-empty dict with positive values."""
|
||||
from hermes_cli.tools_config import _estimate_tool_tokens, _tool_token_cache
|
||||
|
||||
# Clear cache to force fresh computation
|
||||
import hermes_cli.tools_config as tc
|
||||
tc._tool_token_cache = None
|
||||
|
||||
tokens = _estimate_tool_tokens()
|
||||
|
||||
assert isinstance(tokens, dict)
|
||||
assert len(tokens) > 0
|
||||
for name, count in tokens.items():
|
||||
assert isinstance(name, str)
|
||||
assert isinstance(count, int)
|
||||
assert count > 0, f"Tool {name} has non-positive token count: {count}"
|
||||
|
||||
|
||||
def test_estimate_tool_tokens_is_cached():
|
||||
"""Second call should return the same cached dict object."""
|
||||
import hermes_cli.tools_config as tc
|
||||
tc._tool_token_cache = None
|
||||
|
||||
first = tc._estimate_tool_tokens()
|
||||
second = tc._estimate_tool_tokens()
|
||||
|
||||
assert first is second
|
||||
|
||||
|
||||
def test_estimate_tool_tokens_returns_empty_when_tiktoken_unavailable(monkeypatch):
|
||||
"""Graceful degradation when tiktoken cannot be imported."""
|
||||
import hermes_cli.tools_config as tc
|
||||
tc._tool_token_cache = None
|
||||
|
||||
import builtins
|
||||
real_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "tiktoken":
|
||||
raise ImportError("mocked")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", mock_import)
|
||||
|
||||
result = tc._estimate_tool_tokens()
|
||||
|
||||
assert result == {}
|
||||
|
||||
# Reset cache for other tests
|
||||
tc._tool_token_cache = None
|
||||
|
||||
|
||||
def test_estimate_tool_tokens_covers_known_tools():
|
||||
"""Should include schemas for well-known tools like terminal, web_search."""
|
||||
import hermes_cli.tools_config as tc
|
||||
tc._tool_token_cache = None
|
||||
|
||||
tokens = tc._estimate_tool_tokens()
|
||||
|
||||
# These tools should always be discoverable
|
||||
for expected in ("terminal", "web_search", "read_file"):
|
||||
assert expected in tokens, f"Expected {expected!r} in token estimates"
|
||||
|
||||
|
||||
# ─── Status Function Tests ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_prompt_toolset_checklist_passes_status_fn(monkeypatch):
|
||||
"""_prompt_toolset_checklist should pass a status_fn to curses_checklist."""
|
||||
import hermes_cli.tools_config as tc
|
||||
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None):
|
||||
captured_kwargs["status_fn"] = status_fn
|
||||
captured_kwargs["title"] = title
|
||||
return selected # Return pre-selected unchanged
|
||||
|
||||
monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist)
|
||||
|
||||
tc._prompt_toolset_checklist("CLI", {"web", "terminal"})
|
||||
|
||||
assert "status_fn" in captured_kwargs
|
||||
# If tiktoken is available, status_fn should be set
|
||||
tokens = tc._estimate_tool_tokens()
|
||||
if tokens:
|
||||
assert captured_kwargs["status_fn"] is not None
|
||||
|
||||
|
||||
def test_status_fn_returns_formatted_token_count(monkeypatch):
|
||||
"""The status_fn should return a human-readable token count string."""
|
||||
import hermes_cli.tools_config as tc
|
||||
from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None):
|
||||
captured["status_fn"] = status_fn
|
||||
return selected
|
||||
|
||||
monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist)
|
||||
|
||||
tc._prompt_toolset_checklist("CLI", {"web", "terminal"})
|
||||
|
||||
status_fn = captured.get("status_fn")
|
||||
if status_fn is None:
|
||||
pytest.skip("tiktoken unavailable; status_fn not created")
|
||||
|
||||
# Find the indices for web and terminal
|
||||
idx_map = {ts_key: i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)}
|
||||
|
||||
# Call status_fn with web + terminal selected
|
||||
result = status_fn({idx_map["web"], idx_map["terminal"]})
|
||||
assert "tokens" in result
|
||||
assert "Est. tool context" in result
|
||||
|
||||
|
||||
def test_status_fn_deduplicates_overlapping_tools(monkeypatch):
|
||||
"""When toolsets overlap (browser includes web_search), tokens should not double-count."""
|
||||
import hermes_cli.tools_config as tc
|
||||
from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None):
|
||||
captured["status_fn"] = status_fn
|
||||
return selected
|
||||
|
||||
monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist)
|
||||
|
||||
tc._prompt_toolset_checklist("CLI", {"web"})
|
||||
|
||||
status_fn = captured.get("status_fn")
|
||||
if status_fn is None:
|
||||
pytest.skip("tiktoken unavailable; status_fn not created")
|
||||
|
||||
idx_map = {ts_key: i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)}
|
||||
|
||||
# web alone
|
||||
web_only = status_fn({idx_map["web"]})
|
||||
# browser includes web_search, so browser + web should not double-count web_search
|
||||
browser_only = status_fn({idx_map["browser"]})
|
||||
both = status_fn({idx_map["web"], idx_map["browser"]})
|
||||
|
||||
# Extract numeric token counts from strings like "~8.3k tokens" or "~350 tokens"
|
||||
import re
|
||||
|
||||
def parse_tokens(s):
|
||||
m = re.search(r"~([\d.]+)k?\s+tokens", s)
|
||||
if not m:
|
||||
return 0
|
||||
val = float(m.group(1))
|
||||
if "k" in s[m.start():m.end()]:
|
||||
val *= 1000
|
||||
return val
|
||||
|
||||
web_tok = parse_tokens(web_only)
|
||||
browser_tok = parse_tokens(browser_only)
|
||||
both_tok = parse_tokens(both)
|
||||
|
||||
# Both together should be LESS than naive sum (due to web_search dedup)
|
||||
naive_sum = web_tok + browser_tok
|
||||
assert both_tok < naive_sum, (
|
||||
f"Expected deduplication: web({web_tok}) + browser({browser_tok}) = {naive_sum} "
|
||||
f"but combined = {both_tok}"
|
||||
)
|
||||
|
||||
|
||||
def test_status_fn_empty_selection():
|
||||
"""Status function with no tools selected should return ~0 tokens."""
|
||||
import hermes_cli.tools_config as tc
|
||||
|
||||
tc._tool_token_cache = None
|
||||
tokens = tc._estimate_tool_tokens()
|
||||
if not tokens:
|
||||
pytest.skip("tiktoken unavailable")
|
||||
|
||||
from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
ts_keys = [ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS]
|
||||
|
||||
def status_fn(chosen: set) -> str:
|
||||
all_tools: set = set()
|
||||
for idx in chosen:
|
||||
all_tools.update(resolve_toolset(ts_keys[idx]))
|
||||
total = sum(tokens.get(name, 0) for name in all_tools)
|
||||
if total >= 1000:
|
||||
return f"Est. tool context: ~{total / 1000:.1f}k tokens"
|
||||
return f"Est. tool context: ~{total} tokens"
|
||||
|
||||
result = status_fn(set())
|
||||
assert "~0 tokens" in result
|
||||
|
||||
|
||||
# ─── Curses UI Status Bar Tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_curses_checklist_numbered_fallback_shows_status(monkeypatch, capsys):
|
||||
"""The numbered fallback should print the status_fn output."""
|
||||
from hermes_cli.curses_ui import _numbered_fallback
|
||||
|
||||
def my_status(chosen):
|
||||
return f"Selected {len(chosen)} items"
|
||||
|
||||
# Simulate user pressing Enter immediately (empty input → confirm)
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": "")
|
||||
|
||||
result = _numbered_fallback(
|
||||
"Test title",
|
||||
["Item A", "Item B", "Item C"],
|
||||
{0, 2},
|
||||
{0, 2},
|
||||
status_fn=my_status,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Selected 2 items" in captured.out
|
||||
assert result == {0, 2}
|
||||
|
||||
|
||||
def test_curses_checklist_numbered_fallback_without_status(monkeypatch, capsys):
|
||||
"""The numbered fallback should work fine without status_fn."""
|
||||
from hermes_cli.curses_ui import _numbered_fallback
|
||||
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": "")
|
||||
|
||||
result = _numbered_fallback(
|
||||
"Test title",
|
||||
["Item A", "Item B"],
|
||||
{0},
|
||||
{0},
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Est. tool context" not in captured.out
|
||||
assert result == {0}
|
||||
|
||||
|
||||
# ─── Registry get_schema Tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_registry_get_schema_returns_schema():
|
||||
"""registry.get_schema() should return a tool's schema dict."""
|
||||
from tools.registry import registry
|
||||
|
||||
# Import to trigger discovery
|
||||
import model_tools # noqa: F401
|
||||
|
||||
schema = registry.get_schema("terminal")
|
||||
assert schema is not None
|
||||
assert "name" in schema
|
||||
assert schema["name"] == "terminal"
|
||||
assert "parameters" in schema
|
||||
|
||||
|
||||
def test_registry_get_schema_returns_none_for_unknown():
|
||||
"""registry.get_schema() should return None for unknown tools."""
|
||||
from tools.registry import registry
|
||||
|
||||
assert registry.get_schema("nonexistent_tool_xyz") is None
|
||||
@@ -0,0 +1,427 @@
|
||||
"""Tests for optional-skills/productivity/memento-flashcards/scripts/memento_cards.py"""
|
||||
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the scripts dir so we can import the module directly
|
||||
SCRIPTS_DIR = Path(__file__).resolve().parents[2] / "optional-skills" / "productivity" / "memento-flashcards" / "scripts"
|
||||
sys.path.insert(0, str(SCRIPTS_DIR))
|
||||
|
||||
import memento_cards
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_data(tmp_path, monkeypatch):
|
||||
"""Redirect card storage to a temp directory for every test."""
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
monkeypatch.setattr(memento_cards, "DATA_DIR", data_dir)
|
||||
monkeypatch.setattr(memento_cards, "CARDS_FILE", data_dir / "cards.json")
|
||||
return data_dir
|
||||
|
||||
|
||||
def _run(capsys, argv: list[str]) -> dict:
|
||||
"""Run main() with given argv and return parsed JSON output."""
|
||||
with mock.patch("sys.argv", ["memento_cards"] + argv):
|
||||
memento_cards.main()
|
||||
captured = capsys.readouterr()
|
||||
return json.loads(captured.out)
|
||||
|
||||
|
||||
# ── Add / List / Delete ──────────────────────────────────────────────────────
|
||||
|
||||
class TestCardCRUD:
|
||||
def test_add_creates_card(self, capsys):
|
||||
result = _run(capsys, ["add", "--question", "What is 2+2?", "--answer", "4", "--collection", "Math"])
|
||||
assert result["ok"] is True
|
||||
card = result["card"]
|
||||
assert card["question"] == "What is 2+2?"
|
||||
assert card["answer"] == "4"
|
||||
assert card["collection"] == "Math"
|
||||
assert card["status"] == "learning"
|
||||
assert card["ease_streak"] == 0
|
||||
uuid.UUID(card["id"]) # validates it's a real UUID
|
||||
|
||||
def test_add_default_collection(self, capsys):
|
||||
result = _run(capsys, ["add", "--question", "Q?", "--answer", "A"])
|
||||
assert result["card"]["collection"] == "General"
|
||||
|
||||
def test_list_all(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"])
|
||||
_run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C2"])
|
||||
result = _run(capsys, ["list"])
|
||||
assert result["count"] == 2
|
||||
|
||||
def test_list_by_collection(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"])
|
||||
_run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C2"])
|
||||
result = _run(capsys, ["list", "--collection", "C1"])
|
||||
assert result["count"] == 1
|
||||
assert result["cards"][0]["collection"] == "C1"
|
||||
|
||||
def test_list_by_status(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q1", "--answer", "A1"])
|
||||
result = _run(capsys, ["list", "--status", "learning"])
|
||||
assert result["count"] == 1
|
||||
result = _run(capsys, ["list", "--status", "retired"])
|
||||
assert result["count"] == 0
|
||||
|
||||
def test_delete_card(self, capsys):
|
||||
result = _run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = result["card"]["id"]
|
||||
del_result = _run(capsys, ["delete", "--id", card_id])
|
||||
assert del_result["ok"] is True
|
||||
assert del_result["deleted"] == card_id
|
||||
# Verify gone
|
||||
list_result = _run(capsys, ["list"])
|
||||
assert list_result["count"] == 0
|
||||
|
||||
def test_delete_nonexistent(self, capsys):
|
||||
with pytest.raises(SystemExit):
|
||||
_run(capsys, ["delete", "--id", "nonexistent"])
|
||||
|
||||
def test_delete_collection(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "ToDelete"])
|
||||
_run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "ToDelete"])
|
||||
_run(capsys, ["add", "--question", "Q3", "--answer", "A3", "--collection", "Keep"])
|
||||
result = _run(capsys, ["delete-collection", "--collection", "ToDelete"])
|
||||
assert result["ok"] is True
|
||||
assert result["deleted_count"] == 2
|
||||
list_result = _run(capsys, ["list"])
|
||||
assert list_result["count"] == 1
|
||||
assert list_result["cards"][0]["collection"] == "Keep"
|
||||
|
||||
|
||||
# ── Due Filtering ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestDueFiltering:
|
||||
def test_new_card_is_due(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
result = _run(capsys, ["due"])
|
||||
assert result["count"] == 1
|
||||
|
||||
def test_future_card_not_due(self, capsys, monkeypatch):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
# Rate it good (pushes next_review_at to +3 days)
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
_run(capsys, ["rate", "--id", card_id, "--rating", "good"])
|
||||
result = _run(capsys, ["due"])
|
||||
assert result["count"] == 0
|
||||
|
||||
def test_retired_card_not_due(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
_run(capsys, ["rate", "--id", card_id, "--rating", "retire"])
|
||||
result = _run(capsys, ["due"])
|
||||
assert result["count"] == 0
|
||||
|
||||
def test_due_with_collection_filter(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"])
|
||||
_run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C2"])
|
||||
result = _run(capsys, ["due", "--collection", "C1"])
|
||||
assert result["count"] == 1
|
||||
assert result["cards"][0]["collection"] == "C1"
|
||||
|
||||
|
||||
# ── Rating and Rescheduling ──────────────────────────────────────────────────
|
||||
|
||||
class TestRating:
|
||||
def test_hard_adds_1_day(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
before = datetime.now(timezone.utc)
|
||||
result = _run(capsys, ["rate", "--id", card_id, "--rating", "hard"])
|
||||
after = datetime.now(timezone.utc)
|
||||
next_review = datetime.fromisoformat(result["card"]["next_review_at"])
|
||||
assert before + timedelta(days=1) <= next_review <= after + timedelta(days=1)
|
||||
assert result["card"]["ease_streak"] == 0
|
||||
|
||||
def test_good_adds_3_days(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
before = datetime.now(timezone.utc)
|
||||
result = _run(capsys, ["rate", "--id", card_id, "--rating", "good"])
|
||||
next_review = datetime.fromisoformat(result["card"]["next_review_at"])
|
||||
assert next_review >= before + timedelta(days=3)
|
||||
assert result["card"]["ease_streak"] == 0
|
||||
|
||||
def test_easy_adds_7_days_and_increments_streak(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
result = _run(capsys, ["rate", "--id", card_id, "--rating", "easy"])
|
||||
assert result["card"]["ease_streak"] == 1
|
||||
assert result["card"]["status"] == "learning"
|
||||
|
||||
def test_retire_sets_retired(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
result = _run(capsys, ["rate", "--id", card_id, "--rating", "retire"])
|
||||
assert result["card"]["status"] == "retired"
|
||||
assert result["card"]["ease_streak"] == 0
|
||||
|
||||
def test_auto_retire_after_3_easys(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
|
||||
# Force card to be due by manipulating next_review_at through rate
|
||||
for i in range(3):
|
||||
# Load and directly set next_review_at to now so it's ratable
|
||||
data = memento_cards._load()
|
||||
for c in data["cards"]:
|
||||
if c["id"] == card_id:
|
||||
c["next_review_at"] = memento_cards._iso(memento_cards._now())
|
||||
memento_cards._save(data)
|
||||
|
||||
result = _run(capsys, ["rate", "--id", card_id, "--rating", "easy"])
|
||||
|
||||
assert result["card"]["ease_streak"] == 3
|
||||
assert result["card"]["status"] == "retired"
|
||||
|
||||
def test_hard_resets_ease_streak(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
|
||||
# Easy twice
|
||||
for _ in range(2):
|
||||
data = memento_cards._load()
|
||||
for c in data["cards"]:
|
||||
if c["id"] == card_id:
|
||||
c["next_review_at"] = memento_cards._iso(memento_cards._now())
|
||||
memento_cards._save(data)
|
||||
_run(capsys, ["rate", "--id", card_id, "--rating", "easy"])
|
||||
|
||||
# Verify streak is 2
|
||||
check = _run(capsys, ["list"])
|
||||
assert check["cards"][0]["ease_streak"] == 2
|
||||
|
||||
# Hard resets
|
||||
data = memento_cards._load()
|
||||
for c in data["cards"]:
|
||||
if c["id"] == card_id:
|
||||
c["next_review_at"] = memento_cards._iso(memento_cards._now())
|
||||
memento_cards._save(data)
|
||||
result = _run(capsys, ["rate", "--id", card_id, "--rating", "hard"])
|
||||
assert result["card"]["ease_streak"] == 0
|
||||
assert result["card"]["status"] == "learning"
|
||||
|
||||
def test_rate_nonexistent_card(self, capsys):
|
||||
with pytest.raises(SystemExit):
|
||||
_run(capsys, ["rate", "--id", "nonexistent", "--rating", "easy"])
|
||||
|
||||
|
||||
# ── CSV Export/Import ────────────────────────────────────────────────────────
|
||||
|
||||
class TestCSV:
|
||||
def test_export_import_roundtrip(self, capsys, tmp_path):
|
||||
_run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"])
|
||||
_run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C2"])
|
||||
|
||||
csv_path = str(tmp_path / "export.csv")
|
||||
result = _run(capsys, ["export", "--output", csv_path])
|
||||
assert result["ok"] is True
|
||||
assert result["exported"] == 2
|
||||
|
||||
# Verify CSV content
|
||||
with open(csv_path, "r") as f:
|
||||
reader = csv.reader(f)
|
||||
rows = list(reader)
|
||||
assert len(rows) == 2
|
||||
assert rows[0] == ["Q1", "A1", "C1"]
|
||||
assert rows[1] == ["Q2", "A2", "C2"]
|
||||
|
||||
# Delete all and reimport
|
||||
data = memento_cards._load()
|
||||
data["cards"] = []
|
||||
memento_cards._save(data)
|
||||
|
||||
result = _run(capsys, ["import", "--file", csv_path, "--collection", "Fallback"])
|
||||
assert result["ok"] is True
|
||||
assert result["imported"] == 2
|
||||
|
||||
# Verify imported cards use CSV collection column
|
||||
list_result = _run(capsys, ["list"])
|
||||
collections = {c["collection"] for c in list_result["cards"]}
|
||||
assert collections == {"C1", "C2"}
|
||||
|
||||
def test_import_without_collection_column(self, capsys, tmp_path):
|
||||
csv_path = str(tmp_path / "no_col.csv")
|
||||
with open(csv_path, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["Q1", "A1"])
|
||||
writer.writerow(["Q2", "A2"])
|
||||
|
||||
result = _run(capsys, ["import", "--file", csv_path, "--collection", "MyDeck"])
|
||||
assert result["imported"] == 2
|
||||
|
||||
list_result = _run(capsys, ["list"])
|
||||
assert all(c["collection"] == "MyDeck" for c in list_result["cards"])
|
||||
|
||||
def test_import_skips_empty_rows(self, capsys, tmp_path):
|
||||
csv_path = str(tmp_path / "sparse.csv")
|
||||
with open(csv_path, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["Q1", "A1"])
|
||||
writer.writerow(["", ""]) # empty
|
||||
writer.writerow(["Q2"]) # only one column
|
||||
writer.writerow(["Q3", "A3"])
|
||||
|
||||
result = _run(capsys, ["import", "--file", csv_path, "--collection", "Test"])
|
||||
assert result["imported"] == 2
|
||||
|
||||
def test_import_nonexistent_file(self, capsys, tmp_path):
|
||||
with pytest.raises(SystemExit):
|
||||
_run(capsys, ["import", "--file", str(tmp_path / "nope.csv"), "--collection", "X"])
|
||||
|
||||
|
||||
# ── Quiz Batch Add ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestQuizBatchAdd:
|
||||
def test_add_quiz_creates_cards(self, capsys):
|
||||
questions = json.dumps([
|
||||
{"question": "Q1?", "answer": "A1"},
|
||||
{"question": "Q2?", "answer": "A2"},
|
||||
])
|
||||
result = _run(capsys, ["add-quiz", "--video-id", "abc123", "--questions", questions, "--collection", "Quiz - Test"])
|
||||
assert result["ok"] is True
|
||||
assert result["created_count"] == 2
|
||||
for card in result["cards"]:
|
||||
assert card["video_id"] == "abc123"
|
||||
assert card["collection"] == "Quiz - Test"
|
||||
|
||||
def test_add_quiz_deduplicates_by_video_id(self, capsys):
|
||||
questions = json.dumps([{"question": "Q?", "answer": "A"}])
|
||||
_run(capsys, ["add-quiz", "--video-id", "dup1", "--questions", questions])
|
||||
result = _run(capsys, ["add-quiz", "--video-id", "dup1", "--questions", questions])
|
||||
assert result["ok"] is True
|
||||
assert result["skipped"] is True
|
||||
assert result["reason"] == "duplicate_video_id"
|
||||
# Only 1 card total (not 2)
|
||||
list_result = _run(capsys, ["list"])
|
||||
assert list_result["count"] == 1
|
||||
|
||||
def test_add_quiz_invalid_json(self, capsys):
|
||||
with pytest.raises(SystemExit):
|
||||
_run(capsys, ["add-quiz", "--video-id", "x", "--questions", "not json"])
|
||||
|
||||
|
||||
# ── Statistics ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestStats:
|
||||
def test_stats_empty(self, capsys):
|
||||
result = _run(capsys, ["stats"])
|
||||
assert result["total"] == 0
|
||||
assert result["learning"] == 0
|
||||
assert result["retired"] == 0
|
||||
assert result["due_now"] == 0
|
||||
|
||||
def test_stats_counts(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"])
|
||||
_run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C1"])
|
||||
_run(capsys, ["add", "--question", "Q3", "--answer", "A3", "--collection", "C2"])
|
||||
|
||||
# Retire one
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
_run(capsys, ["rate", "--id", card_id, "--rating", "retire"])
|
||||
|
||||
result = _run(capsys, ["stats"])
|
||||
assert result["total"] == 3
|
||||
assert result["learning"] == 2
|
||||
assert result["retired"] == 1
|
||||
assert result["due_now"] == 2 # 2 learning cards still due
|
||||
assert result["collections"] == {"C1": 2, "C2": 1}
|
||||
|
||||
|
||||
# ── Edge Cases ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_empty_deck_operations(self, capsys):
|
||||
"""Operations on empty deck shouldn't crash."""
|
||||
result = _run(capsys, ["due"])
|
||||
assert result["count"] == 0
|
||||
result = _run(capsys, ["list"])
|
||||
assert result["count"] == 0
|
||||
result = _run(capsys, ["stats"])
|
||||
assert result["total"] == 0
|
||||
|
||||
def test_corrupt_json_recovery(self, capsys):
|
||||
"""Corrupt JSON file should be treated as empty."""
|
||||
memento_cards.DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with open(memento_cards.CARDS_FILE, "w") as f:
|
||||
f.write("{corrupted json...")
|
||||
result = _run(capsys, ["list"])
|
||||
assert result["count"] == 0
|
||||
# Can still add
|
||||
result = _run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
assert result["ok"] is True
|
||||
|
||||
def test_missing_cards_key_recovery(self, capsys):
|
||||
"""JSON without 'cards' key should be treated as empty."""
|
||||
memento_cards.DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with open(memento_cards.CARDS_FILE, "w") as f:
|
||||
json.dump({"version": 1}, f)
|
||||
result = _run(capsys, ["list"])
|
||||
assert result["count"] == 0
|
||||
|
||||
def test_atomic_write_creates_dir(self, capsys):
|
||||
"""Data dir is created automatically if missing."""
|
||||
import shutil
|
||||
if memento_cards.DATA_DIR.exists():
|
||||
shutil.rmtree(memento_cards.DATA_DIR)
|
||||
result = _run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
assert result["ok"] is True
|
||||
assert memento_cards.CARDS_FILE.exists()
|
||||
|
||||
def test_delete_collection_empty(self, capsys):
|
||||
"""Deleting a nonexistent collection succeeds with 0 deleted."""
|
||||
result = _run(capsys, ["delete-collection", "--collection", "Nope"])
|
||||
assert result["ok"] is True
|
||||
assert result["deleted_count"] == 0
|
||||
|
||||
|
||||
# ── User Answer Tracking ────────────────────────────────────────────────────
|
||||
|
||||
class TestUserAnswer:
|
||||
def test_rate_stores_user_answer(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
result = _run(capsys, ["rate", "--id", card_id, "--rating", "easy",
|
||||
"--user-answer", "my answer"])
|
||||
assert result["card"]["last_user_answer"] == "my answer"
|
||||
|
||||
def test_rate_without_user_answer_keeps_null(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
result = _run(capsys, ["rate", "--id", card_id, "--rating", "easy"])
|
||||
assert result["card"]["last_user_answer"] is None
|
||||
|
||||
def test_new_card_has_last_user_answer_null(self, capsys):
|
||||
result = _run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
assert result["card"]["last_user_answer"] is None
|
||||
|
||||
def test_user_answer_persists_in_list(self, capsys):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
_run(capsys, ["rate", "--id", card_id, "--rating", "easy",
|
||||
"--user-answer", "my answer"])
|
||||
result = _run(capsys, ["list"])
|
||||
assert result["cards"][0]["last_user_answer"] == "my answer"
|
||||
|
||||
def test_export_excludes_user_answer(self, capsys, tmp_path):
|
||||
_run(capsys, ["add", "--question", "Q", "--answer", "A"])
|
||||
card_id = _run(capsys, ["list"])["cards"][0]["id"]
|
||||
_run(capsys, ["rate", "--id", card_id, "--rating", "easy",
|
||||
"--user-answer", "my answer"])
|
||||
csv_path = str(tmp_path / "export.csv")
|
||||
_run(capsys, ["export", "--output", csv_path])
|
||||
with open(csv_path) as f:
|
||||
rows = list(csv.reader(f))
|
||||
# CSV stays 3-column (question, answer, collection) — user_answer is internal only
|
||||
assert len(rows[0]) == 3
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Tests for optional-skills/productivity/memento-flashcards/scripts/youtube_quiz.py"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
SCRIPTS_DIR = Path(__file__).resolve().parents[2] / "optional-skills" / "productivity" / "memento-flashcards" / "scripts"
|
||||
sys.path.insert(0, str(SCRIPTS_DIR))
|
||||
|
||||
import youtube_quiz
|
||||
|
||||
|
||||
def _run(capsys, argv: list[str]) -> dict:
|
||||
"""Run main() with given argv and return parsed JSON output."""
|
||||
with mock.patch("sys.argv", ["youtube_quiz"] + argv):
|
||||
youtube_quiz.main()
|
||||
captured = capsys.readouterr()
|
||||
return json.loads(captured.out)
|
||||
|
||||
|
||||
class TestNormalizeSegments:
|
||||
def test_basic(self):
|
||||
segments = [{"text": "hello "}, {"text": " world"}]
|
||||
assert youtube_quiz._normalize_segments(segments) == "hello world"
|
||||
|
||||
def test_empty_segments(self):
|
||||
assert youtube_quiz._normalize_segments([]) == ""
|
||||
|
||||
def test_whitespace_only(self):
|
||||
assert youtube_quiz._normalize_segments([{"text": " "}, {"text": " "}]) == ""
|
||||
|
||||
def test_collapses_multiple_spaces(self):
|
||||
segments = [{"text": "a b"}, {"text": "c d"}]
|
||||
assert youtube_quiz._normalize_segments(segments) == "a b c d"
|
||||
|
||||
|
||||
class TestFetchMissingDependency:
|
||||
def test_missing_youtube_transcript_api(self, capsys, monkeypatch):
|
||||
"""When youtube-transcript-api is not installed, report the error."""
|
||||
import builtins
|
||||
real_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "youtube_transcript_api":
|
||||
raise ImportError("No module named 'youtube_transcript_api'")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", mock_import)
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
_run(capsys, ["fetch", "test123"])
|
||||
|
||||
captured = capsys.readouterr()
|
||||
result = json.loads(captured.out)
|
||||
assert result["ok"] is False
|
||||
assert result["error"] == "missing_dependency"
|
||||
assert "pip install" in result["message"]
|
||||
|
||||
|
||||
class TestFetchWithMockedAPI:
|
||||
def _make_mock_module(self, segments=None, raise_exc=None):
|
||||
"""Create a mock youtube_transcript_api module."""
|
||||
mock_module = mock.MagicMock()
|
||||
|
||||
mock_api_instance = mock.MagicMock()
|
||||
mock_module.YouTubeTranscriptApi.return_value = mock_api_instance
|
||||
|
||||
if raise_exc:
|
||||
mock_api_instance.fetch.side_effect = raise_exc
|
||||
else:
|
||||
raw_data = segments or [{"text": "Hello world"}]
|
||||
result = mock.MagicMock()
|
||||
result.to_raw_data.return_value = raw_data
|
||||
mock_api_instance.fetch.return_value = result
|
||||
|
||||
return mock_module
|
||||
|
||||
def test_successful_fetch(self, capsys):
|
||||
mock_mod = self._make_mock_module(
|
||||
segments=[{"text": "This is a test"}, {"text": "transcript segment"}]
|
||||
)
|
||||
with mock.patch.dict("sys.modules", {"youtube_transcript_api": mock_mod}):
|
||||
result = _run(capsys, ["fetch", "abc123"])
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["video_id"] == "abc123"
|
||||
assert "This is a test" in result["transcript"]
|
||||
assert "transcript segment" in result["transcript"]
|
||||
|
||||
def test_fetch_error(self, capsys):
|
||||
mock_mod = self._make_mock_module(raise_exc=Exception("Video unavailable"))
|
||||
with mock.patch.dict("sys.modules", {"youtube_transcript_api": mock_mod}):
|
||||
with pytest.raises(SystemExit):
|
||||
_run(capsys, ["fetch", "bad_id"])
|
||||
|
||||
captured = capsys.readouterr()
|
||||
result = json.loads(captured.out)
|
||||
assert result["ok"] is False
|
||||
assert result["error"] == "transcript_unavailable"
|
||||
|
||||
def test_empty_transcript(self, capsys):
|
||||
mock_mod = self._make_mock_module(segments=[{"text": ""}, {"text": " "}])
|
||||
with mock.patch.dict("sys.modules", {"youtube_transcript_api": mock_mod}):
|
||||
with pytest.raises(SystemExit):
|
||||
_run(capsys, ["fetch", "empty_vid"])
|
||||
|
||||
captured = capsys.readouterr()
|
||||
result = json.loads(captured.out)
|
||||
assert result["ok"] is False
|
||||
assert result["error"] == "empty_transcript"
|
||||
|
||||
def test_segments_without_to_raw_data(self, capsys):
|
||||
"""Handle plain list segments (no to_raw_data method)."""
|
||||
mock_mod = mock.MagicMock()
|
||||
mock_api = mock.MagicMock()
|
||||
mock_mod.YouTubeTranscriptApi.return_value = mock_api
|
||||
# Return a plain list (no to_raw_data attribute)
|
||||
mock_api.fetch.return_value = [{"text": "plain list"}]
|
||||
|
||||
with mock.patch.dict("sys.modules", {"youtube_transcript_api": mock_mod}):
|
||||
result = _run(capsys, ["fetch", "plain123"])
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["transcript"] == "plain list"
|
||||
@@ -0,0 +1,252 @@
|
||||
"""Tests for agent resilience features inspired by Ironclaw PRs.
|
||||
|
||||
Feature 1: Discard truncated tool calls on finish_reason=length (#1632)
|
||||
Feature 2: Empty response recovery (#1677 + #1720)
|
||||
Feature 3: Sanitize tool error results (#1639)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure repo root is importable
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _mock_tool_call(name="test_tool", args='{"key": "value"}', tc_id="tc_1"):
|
||||
return SimpleNamespace(
|
||||
id=tc_id,
|
||||
function=SimpleNamespace(name=name, arguments=args),
|
||||
type="function",
|
||||
)
|
||||
|
||||
|
||||
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None, usage=None):
|
||||
msg = SimpleNamespace(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_content=None,
|
||||
reasoning=None,
|
||||
)
|
||||
choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
|
||||
resp = SimpleNamespace(choices=[choice], model="test/model")
|
||||
resp.usage = SimpleNamespace(**usage) if usage else None
|
||||
return resp
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Feature 3: Sanitize tool error results
|
||||
# =========================================================================
|
||||
|
||||
class TestSanitizeToolError:
|
||||
"""Test _sanitize_tool_error helper function in model_tools.py."""
|
||||
|
||||
def test_import(self):
|
||||
"""Verify the sanitize function can be imported."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
assert callable(_sanitize_tool_error)
|
||||
|
||||
def test_truncation(self):
|
||||
"""Error messages longer than 2000 chars are truncated."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
long_msg = "x" * 5000
|
||||
result = _sanitize_tool_error(long_msg)
|
||||
# Account for the [TOOL_ERROR] prefix
|
||||
assert len(result) <= 2000 + len("[TOOL_ERROR] ")
|
||||
assert result.endswith("...")
|
||||
|
||||
def test_xml_tag_stripping(self):
|
||||
"""XML-like boundary tags are stripped from errors."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
error = "<tool_call>Error: file not found</tool_call>"
|
||||
result = _sanitize_tool_error(error)
|
||||
assert "<tool_call>" not in result
|
||||
assert "</tool_call>" not in result
|
||||
assert "file not found" in result
|
||||
|
||||
def test_system_tag_stripping(self):
|
||||
"""System/assistant/user tags are stripped."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
error = "<system>Permission denied</system>"
|
||||
result = _sanitize_tool_error(error)
|
||||
assert "<system>" not in result
|
||||
assert "Permission denied" in result
|
||||
|
||||
def test_code_fence_stripping(self):
|
||||
"""Markdown code fences are stripped."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
error = "```json\n{\"error\": \"bad\"}\n```"
|
||||
result = _sanitize_tool_error(error)
|
||||
assert "```" not in result
|
||||
|
||||
def test_cdata_stripping(self):
|
||||
"""CDATA sections are stripped."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
error = "Error: <![CDATA[some internal data]]> happened"
|
||||
result = _sanitize_tool_error(error)
|
||||
assert "CDATA" not in result
|
||||
assert "happened" in result
|
||||
|
||||
def test_error_format_prefix(self):
|
||||
"""Error is wrapped with [TOOL_ERROR] prefix."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
result = _sanitize_tool_error("something went wrong")
|
||||
assert result.startswith("[TOOL_ERROR]")
|
||||
assert "something went wrong" in result
|
||||
|
||||
def test_short_error_preserved(self):
|
||||
"""Short, clean errors are preserved intact (with prefix)."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
result = _sanitize_tool_error("File not found: /tmp/test.txt")
|
||||
assert result == "[TOOL_ERROR] File not found: /tmp/test.txt"
|
||||
|
||||
def test_handle_function_call_uses_sanitizer(self):
|
||||
"""handle_function_call sanitizes error messages from exceptions."""
|
||||
from model_tools import handle_function_call, _sanitize_tool_error
|
||||
# The registry returns its own error for unknown tools (not via the
|
||||
# except block). Verify the sanitizer is called in the except path
|
||||
# by directly testing what would happen.
|
||||
raw_error = "Error executing bad_tool: <system>Internal traceback</system>"
|
||||
sanitized = _sanitize_tool_error(raw_error)
|
||||
result_json = json.dumps({"error": sanitized}, ensure_ascii=False)
|
||||
parsed = json.loads(result_json)
|
||||
assert "[TOOL_ERROR]" in parsed["error"]
|
||||
assert "<system>" not in parsed["error"]
|
||||
|
||||
def test_mixed_tags_and_long_error(self):
|
||||
"""Complex error with tags AND length > 2000."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
error = "<result>" + ("a" * 3000) + "</result>"
|
||||
result = _sanitize_tool_error(error)
|
||||
assert "<result>" not in result
|
||||
assert "</result>" not in result
|
||||
assert len(result) <= 2020 # prefix + 2000 + ...
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Feature 1: Discard truncated tool calls on finish_reason=length
|
||||
# =========================================================================
|
||||
|
||||
class TestTruncatedToolCallDiscard:
|
||||
"""Test that truncated tool calls (finish_reason=length) are discarded."""
|
||||
|
||||
def test_truncated_tool_calls_message_content(self):
|
||||
"""Verify the truncation nudge message text is correct."""
|
||||
expected_nudge = (
|
||||
'Your previous response was truncated due to context length limits. '
|
||||
'The tool calls were discarded. Please summarize your progress so '
|
||||
'far and continue with a shorter response.'
|
||||
)
|
||||
# This is the message that should be injected into the conversation
|
||||
assert "truncated" in expected_nudge.lower()
|
||||
assert "discarded" in expected_nudge.lower()
|
||||
|
||||
def test_tools_temporarily_disabled_attribute(self):
|
||||
"""Verify the _tools_temporarily_disabled attribute pattern works."""
|
||||
# Test the attribute access pattern used in the implementation
|
||||
obj = SimpleNamespace()
|
||||
assert getattr(obj, '_tools_temporarily_disabled', False) is False
|
||||
obj._tools_temporarily_disabled = True
|
||||
assert getattr(obj, '_tools_temporarily_disabled', False) is True
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Feature 2: Empty response recovery
|
||||
# =========================================================================
|
||||
|
||||
class TestEmptyResponseRecovery:
|
||||
"""Test empty response recovery behavior."""
|
||||
|
||||
def test_empty_response_nudge_text(self):
|
||||
"""Verify the nudge message for empty responses."""
|
||||
nudge = "Your previous response was empty. Please continue with the task."
|
||||
assert "empty" in nudge.lower()
|
||||
assert "continue" in nudge.lower()
|
||||
|
||||
def test_prior_meaningful_output_detection(self):
|
||||
"""Test logic for detecting prior meaningful output in messages."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Here is a detailed response about your question."},
|
||||
{"role": "user", "content": "Thanks, continue"},
|
||||
]
|
||||
# Check that we can find prior assistant output
|
||||
has_prior = any(
|
||||
isinstance(m, dict)
|
||||
and m.get("role") == "assistant"
|
||||
and m.get("content")
|
||||
and len(m["content"].strip()) > 0
|
||||
for m in messages
|
||||
)
|
||||
assert has_prior is True
|
||||
|
||||
def test_no_prior_meaningful_output(self):
|
||||
"""Test when no prior meaningful assistant output exists."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
has_prior = any(
|
||||
isinstance(m, dict)
|
||||
and m.get("role") == "assistant"
|
||||
and m.get("content")
|
||||
and len(m["content"].strip()) > 0
|
||||
for m in messages
|
||||
)
|
||||
assert has_prior is False
|
||||
|
||||
def test_think_block_only_not_meaningful(self):
|
||||
"""Responses with only think blocks should not count as meaningful."""
|
||||
messages = [
|
||||
{"role": "assistant", "content": "<think>Internal reasoning only</think>"},
|
||||
]
|
||||
# The agent uses _has_content_after_think_block to check this
|
||||
# For our test, verify the pattern: content that's only a think block
|
||||
content = messages[0]["content"]
|
||||
stripped = re.sub(
|
||||
r'<(?:REASONING_SCRATCHPAD|think|reasoning)>.*?</(?:REASONING_SCRATCHPAD|think|reasoning)>',
|
||||
'', content, flags=re.DOTALL
|
||||
).strip()
|
||||
assert stripped == "" # No meaningful content after stripping think blocks
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Integration-style tests for sanitize_tool_error in handle_function_call
|
||||
# =========================================================================
|
||||
|
||||
class TestHandleFunctionCallSanitization:
|
||||
"""Test that handle_function_call properly sanitizes errors."""
|
||||
|
||||
def test_registry_dispatch_error_sanitized(self):
|
||||
"""When registry.dispatch raises, the error should be sanitized."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
|
||||
# Simulate what happens in the except block
|
||||
error = Exception("Connection refused: <system>Internal error</system> " + "x" * 3000)
|
||||
raw_error = f"Error executing test_tool: {str(error)}"
|
||||
sanitized = _sanitize_tool_error(raw_error)
|
||||
|
||||
result_json = json.dumps({"error": sanitized}, ensure_ascii=False)
|
||||
parsed = json.loads(result_json)
|
||||
|
||||
assert "[TOOL_ERROR]" in parsed["error"]
|
||||
assert "<system>" not in parsed["error"]
|
||||
# Truncated
|
||||
assert len(parsed["error"]) <= 2020
|
||||
|
||||
def test_normal_error_readable(self):
|
||||
"""Normal short errors should remain readable."""
|
||||
from model_tools import _sanitize_tool_error
|
||||
result = _sanitize_tool_error("Error executing write_file: Permission denied")
|
||||
assert "Permission denied" in result
|
||||
assert result.startswith("[TOOL_ERROR]")
|
||||
@@ -25,6 +25,8 @@ def _make_agent_with_compressor() -> AIAgent:
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
# Context compressor with primary model values
|
||||
compressor = ContextCompressor(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,156 @@
|
||||
"""Tests for ordered provider fallback chain (salvage of PR #1761).
|
||||
|
||||
Extends the single-fallback tests in test_fallback_model.py to cover
|
||||
the new list-based ``fallback_providers`` config format and chain
|
||||
advancement through multiple providers.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_agent(fallback_model=None):
|
||||
"""Create a minimal AIAgent with optional fallback config."""
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=[]),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
fallback_model=fallback_model,
|
||||
)
|
||||
agent.client = MagicMock()
|
||||
return agent
|
||||
|
||||
|
||||
def _mock_client(base_url="https://openrouter.ai/api/v1", api_key="fb-key"):
|
||||
mock = MagicMock()
|
||||
mock.base_url = base_url
|
||||
mock.api_key = api_key
|
||||
return mock
|
||||
|
||||
|
||||
# ── Chain initialisation ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFallbackChainInit:
|
||||
def test_no_fallback(self):
|
||||
agent = _make_agent(fallback_model=None)
|
||||
assert agent._fallback_chain == []
|
||||
assert agent._fallback_index == 0
|
||||
assert agent._fallback_model is None
|
||||
|
||||
def test_single_dict_backwards_compat(self):
|
||||
fb = {"provider": "openai", "model": "gpt-4o"}
|
||||
agent = _make_agent(fallback_model=fb)
|
||||
assert agent._fallback_chain == [fb]
|
||||
assert agent._fallback_model == fb
|
||||
|
||||
def test_list_of_providers(self):
|
||||
fbs = [
|
||||
{"provider": "openai", "model": "gpt-4o"},
|
||||
{"provider": "zai", "model": "glm-4.7"},
|
||||
]
|
||||
agent = _make_agent(fallback_model=fbs)
|
||||
assert len(agent._fallback_chain) == 2
|
||||
assert agent._fallback_model == fbs[0]
|
||||
|
||||
def test_invalid_entries_filtered(self):
|
||||
fbs = [
|
||||
{"provider": "openai", "model": "gpt-4o"},
|
||||
{"provider": "", "model": "glm-4.7"},
|
||||
{"provider": "zai"},
|
||||
"not-a-dict",
|
||||
]
|
||||
agent = _make_agent(fallback_model=fbs)
|
||||
assert len(agent._fallback_chain) == 1
|
||||
assert agent._fallback_chain[0]["provider"] == "openai"
|
||||
|
||||
def test_empty_list(self):
|
||||
agent = _make_agent(fallback_model=[])
|
||||
assert agent._fallback_chain == []
|
||||
assert agent._fallback_model is None
|
||||
|
||||
def test_invalid_dict_no_provider(self):
|
||||
agent = _make_agent(fallback_model={"model": "gpt-4o"})
|
||||
assert agent._fallback_chain == []
|
||||
|
||||
|
||||
# ── Chain advancement ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFallbackChainAdvancement:
|
||||
def test_exhausted_returns_false(self):
|
||||
agent = _make_agent(fallback_model=None)
|
||||
assert agent._try_activate_fallback() is False
|
||||
|
||||
def test_advances_index(self):
|
||||
fbs = [
|
||||
{"provider": "openai", "model": "gpt-4o"},
|
||||
{"provider": "zai", "model": "glm-4.7"},
|
||||
]
|
||||
agent = _make_agent(fallback_model=fbs)
|
||||
with patch("agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(_mock_client(), "gpt-4o")):
|
||||
assert agent._try_activate_fallback() is True
|
||||
assert agent._fallback_index == 1
|
||||
assert agent.model == "gpt-4o"
|
||||
assert agent._fallback_activated is True
|
||||
|
||||
def test_second_fallback_works(self):
|
||||
fbs = [
|
||||
{"provider": "openai", "model": "gpt-4o"},
|
||||
{"provider": "zai", "model": "glm-4.7"},
|
||||
]
|
||||
agent = _make_agent(fallback_model=fbs)
|
||||
with patch("agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(_mock_client(), "resolved")):
|
||||
assert agent._try_activate_fallback() is True
|
||||
assert agent.model == "gpt-4o"
|
||||
assert agent._try_activate_fallback() is True
|
||||
assert agent.model == "glm-4.7"
|
||||
assert agent._fallback_index == 2
|
||||
|
||||
def test_all_exhausted_returns_false(self):
|
||||
fbs = [{"provider": "openai", "model": "gpt-4o"}]
|
||||
agent = _make_agent(fallback_model=fbs)
|
||||
with patch("agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(_mock_client(), "gpt-4o")):
|
||||
assert agent._try_activate_fallback() is True
|
||||
assert agent._try_activate_fallback() is False
|
||||
|
||||
def test_skips_unconfigured_provider_to_next(self):
|
||||
"""If resolve_provider_client returns None, skip to next in chain."""
|
||||
fbs = [
|
||||
{"provider": "broken", "model": "nope"},
|
||||
{"provider": "openai", "model": "gpt-4o"},
|
||||
]
|
||||
agent = _make_agent(fallback_model=fbs)
|
||||
with patch("agent.auxiliary_client.resolve_provider_client") as mock_rpc:
|
||||
mock_rpc.side_effect = [
|
||||
(None, None), # broken provider
|
||||
(_mock_client(), "gpt-4o"), # fallback succeeds
|
||||
]
|
||||
assert agent._try_activate_fallback() is True
|
||||
assert agent.model == "gpt-4o"
|
||||
assert agent._fallback_index == 2
|
||||
|
||||
def test_skips_provider_that_raises_to_next(self):
|
||||
"""If resolve_provider_client raises, skip to next in chain."""
|
||||
fbs = [
|
||||
{"provider": "broken", "model": "nope"},
|
||||
{"provider": "openai", "model": "gpt-4o"},
|
||||
]
|
||||
agent = _make_agent(fallback_model=fbs)
|
||||
with patch("agent.auxiliary_client.resolve_provider_client") as mock_rpc:
|
||||
mock_rpc.side_effect = [
|
||||
RuntimeError("auth failed"),
|
||||
(_mock_client(), "gpt-4o"),
|
||||
]
|
||||
assert agent._try_activate_fallback() is True
|
||||
assert agent.model == "gpt-4o"
|
||||
@@ -2507,6 +2507,8 @@ class TestFallbackAnthropicProvider:
|
||||
def test_fallback_to_anthropic_sets_api_mode(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-20250514"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://api.anthropic.com/v1"
|
||||
@@ -2528,6 +2530,8 @@ class TestFallbackAnthropicProvider:
|
||||
def test_fallback_to_anthropic_enables_prompt_caching(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-20250514"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://api.anthropic.com/v1"
|
||||
@@ -2545,6 +2549,8 @@ class TestFallbackAnthropicProvider:
|
||||
def test_fallback_to_openrouter_uses_openai_client(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://openrouter.ai/api/v1"
|
||||
@@ -3238,6 +3244,8 @@ class TestFallbackSetsOAuthFlag:
|
||||
def test_fallback_to_anthropic_oauth_sets_flag(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-6"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://api.anthropic.com/v1"
|
||||
@@ -3259,6 +3267,8 @@ class TestFallbackSetsOAuthFlag:
|
||||
def test_fallback_to_anthropic_api_key_clears_flag(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-6"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://api.anthropic.com/v1"
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Tests for credential file passthrough registry (tools/credential_files.py)."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.credential_files import (
|
||||
clear_credential_files,
|
||||
get_credential_file_mounts,
|
||||
register_credential_file,
|
||||
register_credential_files,
|
||||
reset_config_cache,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_registry():
|
||||
"""Reset registry between tests."""
|
||||
clear_credential_files()
|
||||
reset_config_cache()
|
||||
yield
|
||||
clear_credential_files()
|
||||
reset_config_cache()
|
||||
|
||||
|
||||
class TestRegisterCredentialFile:
|
||||
def test_registers_existing_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "token.json").write_text('{"token": "abc"}')
|
||||
|
||||
result = register_credential_file("token.json")
|
||||
|
||||
assert result is True
|
||||
mounts = get_credential_file_mounts()
|
||||
assert len(mounts) == 1
|
||||
assert mounts[0]["host_path"] == str(tmp_path / "token.json")
|
||||
assert mounts[0]["container_path"] == "/root/.hermes/token.json"
|
||||
|
||||
def test_skips_missing_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
result = register_credential_file("nonexistent.json")
|
||||
|
||||
assert result is False
|
||||
assert get_credential_file_mounts() == []
|
||||
|
||||
def test_custom_container_base(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "cred.json").write_text("{}")
|
||||
|
||||
register_credential_file("cred.json", container_base="/home/user/.hermes")
|
||||
|
||||
mounts = get_credential_file_mounts()
|
||||
assert mounts[0]["container_path"] == "/home/user/.hermes/cred.json"
|
||||
|
||||
def test_deduplicates_by_container_path(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "token.json").write_text("{}")
|
||||
|
||||
register_credential_file("token.json")
|
||||
register_credential_file("token.json")
|
||||
|
||||
mounts = get_credential_file_mounts()
|
||||
assert len(mounts) == 1
|
||||
|
||||
|
||||
class TestRegisterCredentialFiles:
|
||||
def test_string_entries(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "a.json").write_text("{}")
|
||||
(tmp_path / "b.json").write_text("{}")
|
||||
|
||||
missing = register_credential_files(["a.json", "b.json"])
|
||||
|
||||
assert missing == []
|
||||
assert len(get_credential_file_mounts()) == 2
|
||||
|
||||
def test_dict_entries(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "token.json").write_text("{}")
|
||||
|
||||
missing = register_credential_files([
|
||||
{"path": "token.json", "description": "OAuth token"},
|
||||
])
|
||||
|
||||
assert missing == []
|
||||
assert len(get_credential_file_mounts()) == 1
|
||||
|
||||
def test_returns_missing_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "exists.json").write_text("{}")
|
||||
|
||||
missing = register_credential_files([
|
||||
"exists.json",
|
||||
"missing.json",
|
||||
{"path": "also_missing.json"},
|
||||
])
|
||||
|
||||
assert missing == ["missing.json", "also_missing.json"]
|
||||
assert len(get_credential_file_mounts()) == 1
|
||||
|
||||
def test_empty_list(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
assert register_credential_files([]) == []
|
||||
|
||||
|
||||
class TestConfigCredentialFiles:
|
||||
def test_loads_from_config(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "oauth.json").write_text("{}")
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"terminal:\n credential_files:\n - oauth.json\n"
|
||||
)
|
||||
|
||||
mounts = get_credential_file_mounts()
|
||||
|
||||
assert len(mounts) == 1
|
||||
assert mounts[0]["host_path"] == str(tmp_path / "oauth.json")
|
||||
|
||||
def test_config_skips_missing_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"terminal:\n credential_files:\n - nonexistent.json\n"
|
||||
)
|
||||
|
||||
mounts = get_credential_file_mounts()
|
||||
assert mounts == []
|
||||
|
||||
def test_combines_skill_and_config(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "skill_token.json").write_text("{}")
|
||||
(tmp_path / "config_token.json").write_text("{}")
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"terminal:\n credential_files:\n - config_token.json\n"
|
||||
)
|
||||
|
||||
register_credential_file("skill_token.json")
|
||||
mounts = get_credential_file_mounts()
|
||||
|
||||
assert len(mounts) == 2
|
||||
paths = {m["container_path"] for m in mounts}
|
||||
assert "/root/.hermes/skill_token.json" in paths
|
||||
assert "/root/.hermes/config_token.json" in paths
|
||||
|
||||
|
||||
class TestGetMountsRechecksExistence:
|
||||
def test_removed_file_excluded_from_mounts(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
token = tmp_path / "token.json"
|
||||
token.write_text("{}")
|
||||
|
||||
register_credential_file("token.json")
|
||||
assert len(get_credential_file_mounts()) == 1
|
||||
|
||||
# Delete the file after registration
|
||||
token.unlink()
|
||||
assert get_credential_file_mounts() == []
|
||||
@@ -1,11 +1,86 @@
|
||||
"""Regression tests for per-call Honcho tool session routing."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tools import honcho_tools
|
||||
|
||||
|
||||
class TestCheckHonchoAvailable:
|
||||
"""Tests for _check_honcho_available (banner + runtime gating)."""
|
||||
|
||||
def setup_method(self):
|
||||
self.orig_manager = honcho_tools._session_manager
|
||||
self.orig_key = honcho_tools._session_key
|
||||
|
||||
def teardown_method(self):
|
||||
honcho_tools._session_manager = self.orig_manager
|
||||
honcho_tools._session_key = self.orig_key
|
||||
|
||||
def test_returns_true_when_session_active(self):
|
||||
"""Fast path: session context already injected (mid-conversation)."""
|
||||
honcho_tools._session_manager = MagicMock()
|
||||
honcho_tools._session_key = "test-key"
|
||||
assert honcho_tools._check_honcho_available() is True
|
||||
|
||||
def test_returns_true_when_configured_but_no_session(self):
|
||||
"""Slow path: honcho configured but agent not started yet (banner time)."""
|
||||
honcho_tools._session_manager = None
|
||||
honcho_tools._session_key = None
|
||||
|
||||
@dataclass
|
||||
class FakeConfig:
|
||||
enabled: bool = True
|
||||
api_key: str = "test-key"
|
||||
base_url: str = None
|
||||
|
||||
with patch("tools.honcho_tools.HonchoClientConfig", create=True):
|
||||
with patch(
|
||||
"honcho_integration.client.HonchoClientConfig"
|
||||
) as mock_cls:
|
||||
mock_cls.from_global_config.return_value = FakeConfig()
|
||||
assert honcho_tools._check_honcho_available() is True
|
||||
|
||||
def test_returns_false_when_not_configured(self):
|
||||
"""No session, no config: tool genuinely unavailable."""
|
||||
honcho_tools._session_manager = None
|
||||
honcho_tools._session_key = None
|
||||
|
||||
@dataclass
|
||||
class FakeConfig:
|
||||
enabled: bool = False
|
||||
api_key: str = None
|
||||
base_url: str = None
|
||||
|
||||
with patch(
|
||||
"honcho_integration.client.HonchoClientConfig"
|
||||
) as mock_cls:
|
||||
mock_cls.from_global_config.return_value = FakeConfig()
|
||||
assert honcho_tools._check_honcho_available() is False
|
||||
|
||||
def test_returns_false_when_import_fails(self):
|
||||
"""Graceful fallback when honcho_integration not installed."""
|
||||
import sys
|
||||
|
||||
honcho_tools._session_manager = None
|
||||
honcho_tools._session_key = None
|
||||
|
||||
# Hide honcho_integration from the import system to simulate
|
||||
# an environment where the package is not installed.
|
||||
hidden = {
|
||||
k: sys.modules.pop(k)
|
||||
for k in list(sys.modules)
|
||||
if k.startswith("honcho_integration")
|
||||
}
|
||||
try:
|
||||
with patch.dict(sys.modules, {"honcho_integration": None,
|
||||
"honcho_integration.client": None}):
|
||||
assert honcho_tools._check_honcho_available() is False
|
||||
finally:
|
||||
sys.modules.update(hidden)
|
||||
|
||||
|
||||
class TestHonchoToolSessionContext:
|
||||
def setup_method(self):
|
||||
self.orig_manager = honcho_tools._session_manager
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""Tests for MCP dynamic tool discovery (notifications/tools/list_changed)."""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.mcp_tool import MCPServerTask, _register_server_tools
|
||||
from tools.registry import ToolRegistry
|
||||
|
||||
|
||||
def _make_mcp_tool(name: str, desc: str = ""):
|
||||
return SimpleNamespace(name=name, description=desc, inputSchema=None)
|
||||
|
||||
|
||||
class TestRegisterServerTools:
|
||||
"""Tests for the extracted _register_server_tools helper."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry(self):
|
||||
return ToolRegistry()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_toolsets(self):
|
||||
return {
|
||||
"hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
|
||||
"hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []},
|
||||
"custom-toolset": {"tools": [], "description": "Other", "includes": []},
|
||||
}
|
||||
|
||||
def test_injects_hermes_toolsets(self, mock_registry, mock_toolsets):
|
||||
"""Tools are injected into hermes-* toolsets but not custom ones."""
|
||||
server = MCPServerTask("my_srv")
|
||||
server._tools = [_make_mcp_tool("my_tool", "desc")]
|
||||
server.session = MagicMock()
|
||||
|
||||
with patch("tools.registry.registry", mock_registry), \
|
||||
patch("toolsets.create_custom_toolset"), \
|
||||
patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True):
|
||||
|
||||
registered = _register_server_tools("my_srv", server, {})
|
||||
|
||||
assert "mcp_my_srv_my_tool" in registered
|
||||
assert "mcp_my_srv_my_tool" in mock_registry.get_all_tool_names()
|
||||
|
||||
# Injected into hermes-* toolsets
|
||||
assert "mcp_my_srv_my_tool" in mock_toolsets["hermes-cli"]["tools"]
|
||||
assert "mcp_my_srv_my_tool" in mock_toolsets["hermes-telegram"]["tools"]
|
||||
# NOT into non-hermes toolsets
|
||||
assert "mcp_my_srv_my_tool" not in mock_toolsets["custom-toolset"]["tools"]
|
||||
|
||||
|
||||
class TestRefreshTools:
|
||||
"""Tests for MCPServerTask._refresh_tools nuke-and-repave cycle."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry(self):
|
||||
return ToolRegistry()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_toolsets(self):
|
||||
return {
|
||||
"hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
|
||||
"hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []},
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nuke_and_repave(self, mock_registry, mock_toolsets):
|
||||
"""Old tools are removed and new tools registered on refresh."""
|
||||
server = MCPServerTask("live_srv")
|
||||
server._refresh_lock = asyncio.Lock()
|
||||
server._config = {}
|
||||
|
||||
# Seed initial state: one old tool registered
|
||||
mock_registry.register(
|
||||
name="mcp_live_srv_old_tool", toolset="mcp-live_srv", schema={},
|
||||
handler=lambda x: x, check_fn=lambda: True, is_async=False,
|
||||
description="", emoji="",
|
||||
)
|
||||
server._registered_tool_names = ["mcp_live_srv_old_tool"]
|
||||
mock_toolsets["hermes-cli"]["tools"].append("mcp_live_srv_old_tool")
|
||||
|
||||
# New tool list from server
|
||||
new_tool = _make_mcp_tool("new_tool", "new behavior")
|
||||
server.session = SimpleNamespace(
|
||||
list_tools=AsyncMock(
|
||||
return_value=SimpleNamespace(tools=[new_tool])
|
||||
)
|
||||
)
|
||||
|
||||
with patch("tools.registry.registry", mock_registry), \
|
||||
patch("toolsets.create_custom_toolset"), \
|
||||
patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True):
|
||||
|
||||
await server._refresh_tools()
|
||||
|
||||
# Old tool completely gone
|
||||
assert "mcp_live_srv_old_tool" not in mock_registry.get_all_tool_names()
|
||||
assert "mcp_live_srv_old_tool" not in mock_toolsets["hermes-cli"]["tools"]
|
||||
|
||||
# New tool registered
|
||||
assert "mcp_live_srv_new_tool" in mock_registry.get_all_tool_names()
|
||||
assert "mcp_live_srv_new_tool" in mock_toolsets["hermes-cli"]["tools"]
|
||||
assert server._registered_tool_names == ["mcp_live_srv_new_tool"]
|
||||
|
||||
|
||||
class TestMessageHandler:
|
||||
"""Tests for MCPServerTask._make_message_handler dispatch."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatches_tool_list_changed(self):
|
||||
from tools.mcp_tool import _MCP_NOTIFICATION_TYPES
|
||||
if not _MCP_NOTIFICATION_TYPES:
|
||||
pytest.skip("MCP SDK ToolListChangedNotification not available")
|
||||
|
||||
from mcp.types import ServerNotification, ToolListChangedNotification
|
||||
|
||||
server = MCPServerTask("notif_srv")
|
||||
with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh:
|
||||
handler = server._make_message_handler()
|
||||
notification = ServerNotification(
|
||||
root=ToolListChangedNotification(method="notifications/tools/list_changed")
|
||||
)
|
||||
await handler(notification)
|
||||
mock_refresh.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignores_exceptions_and_other_messages(self):
|
||||
server = MCPServerTask("notif_srv")
|
||||
with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh:
|
||||
handler = server._make_message_handler()
|
||||
# Exceptions should not trigger refresh
|
||||
await handler(RuntimeError("connection dead"))
|
||||
# Unknown message types should not trigger refresh
|
||||
await handler({"jsonrpc": "2.0", "result": "ok"})
|
||||
mock_refresh.assert_not_awaited()
|
||||
|
||||
|
||||
class TestDeregister:
|
||||
"""Tests for ToolRegistry.deregister."""
|
||||
|
||||
def test_removes_tool(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x)
|
||||
assert "foo" in reg.get_all_tool_names()
|
||||
reg.deregister("foo")
|
||||
assert "foo" not in reg.get_all_tool_names()
|
||||
|
||||
def test_cleans_up_toolset_check(self):
|
||||
reg = ToolRegistry()
|
||||
check = lambda: True # noqa: E731
|
||||
reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=check)
|
||||
assert reg.is_toolset_available("ts1")
|
||||
reg.deregister("foo")
|
||||
# Toolset check should be gone since no tools remain
|
||||
assert "ts1" not in reg._toolset_checks
|
||||
|
||||
def test_preserves_toolset_check_if_other_tools_remain(self):
|
||||
reg = ToolRegistry()
|
||||
check = lambda: True # noqa: E731
|
||||
reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=check)
|
||||
reg.register(name="bar", toolset="ts1", schema={}, handler=lambda x: x)
|
||||
reg.deregister("foo")
|
||||
# bar still in ts1, so check should remain
|
||||
assert "ts1" in reg._toolset_checks
|
||||
|
||||
def test_noop_for_unknown_tool(self):
|
||||
reg = ToolRegistry()
|
||||
reg.deregister("nonexistent") # Should not raise
|
||||
@@ -0,0 +1,334 @@
|
||||
"""Tests for _send_mattermost, _send_matrix, _send_homeassistant, _send_dingtalk."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from tools.send_message_tool import (
|
||||
_send_dingtalk,
|
||||
_send_homeassistant,
|
||||
_send_mattermost,
|
||||
_send_matrix,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_aiohttp_resp(status, json_data=None, text_data=None):
|
||||
"""Build a minimal async-context-manager mock for an aiohttp response."""
|
||||
resp = AsyncMock()
|
||||
resp.status = status
|
||||
resp.json = AsyncMock(return_value=json_data or {})
|
||||
resp.text = AsyncMock(return_value=text_data or "")
|
||||
return resp
|
||||
|
||||
|
||||
def _make_aiohttp_session(resp):
|
||||
"""Wrap a response mock in a session mock that supports async-with for post/put."""
|
||||
request_ctx = MagicMock()
|
||||
request_ctx.__aenter__ = AsyncMock(return_value=resp)
|
||||
request_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = MagicMock()
|
||||
session.post = MagicMock(return_value=request_ctx)
|
||||
session.put = MagicMock(return_value=request_ctx)
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__aenter__ = AsyncMock(return_value=session)
|
||||
session_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
return session_ctx, session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_mattermost
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendMattermost:
|
||||
def test_success(self):
|
||||
resp = _make_aiohttp_resp(201, json_data={"id": "post123"})
|
||||
session_ctx, session = _make_aiohttp_session(resp)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx), \
|
||||
patch.dict(os.environ, {"MATTERMOST_URL": "", "MATTERMOST_TOKEN": ""}, clear=False):
|
||||
extra = {"url": "https://mm.example.com"}
|
||||
result = asyncio.run(_send_mattermost("tok-abc", extra, "channel1", "hello"))
|
||||
|
||||
assert result == {"success": True, "platform": "mattermost", "chat_id": "channel1", "message_id": "post123"}
|
||||
session.post.assert_called_once()
|
||||
call_kwargs = session.post.call_args
|
||||
assert call_kwargs[0][0] == "https://mm.example.com/api/v4/posts"
|
||||
assert call_kwargs[1]["headers"]["Authorization"] == "Bearer tok-abc"
|
||||
assert call_kwargs[1]["json"] == {"channel_id": "channel1", "message": "hello"}
|
||||
|
||||
def test_http_error(self):
|
||||
resp = _make_aiohttp_resp(400, text_data="Bad Request")
|
||||
session_ctx, _ = _make_aiohttp_session(resp)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx):
|
||||
result = asyncio.run(_send_mattermost(
|
||||
"tok", {"url": "https://mm.example.com"}, "ch", "hi"
|
||||
))
|
||||
|
||||
assert "error" in result
|
||||
assert "400" in result["error"]
|
||||
assert "Bad Request" in result["error"]
|
||||
|
||||
def test_missing_config(self):
|
||||
with patch.dict(os.environ, {"MATTERMOST_URL": "", "MATTERMOST_TOKEN": ""}, clear=False):
|
||||
result = asyncio.run(_send_mattermost("", {}, "ch", "hi"))
|
||||
|
||||
assert "error" in result
|
||||
assert "MATTERMOST_URL" in result["error"] or "not configured" in result["error"]
|
||||
|
||||
def test_env_var_fallback(self):
|
||||
resp = _make_aiohttp_resp(200, json_data={"id": "p99"})
|
||||
session_ctx, session = _make_aiohttp_session(resp)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx), \
|
||||
patch.dict(os.environ, {"MATTERMOST_URL": "https://mm.env.com", "MATTERMOST_TOKEN": "env-tok"}, clear=False):
|
||||
result = asyncio.run(_send_mattermost("", {}, "ch", "hi"))
|
||||
|
||||
assert result["success"] is True
|
||||
call_kwargs = session.post.call_args
|
||||
assert "https://mm.env.com" in call_kwargs[0][0]
|
||||
assert call_kwargs[1]["headers"]["Authorization"] == "Bearer env-tok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_matrix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendMatrix:
|
||||
def test_success(self):
|
||||
resp = _make_aiohttp_resp(200, json_data={"event_id": "$abc123"})
|
||||
session_ctx, session = _make_aiohttp_session(resp)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx), \
|
||||
patch.dict(os.environ, {"MATRIX_HOMESERVER": "", "MATRIX_ACCESS_TOKEN": ""}, clear=False):
|
||||
extra = {"homeserver": "https://matrix.example.com"}
|
||||
result = asyncio.run(_send_matrix("syt_tok", extra, "!room:example.com", "hello matrix"))
|
||||
|
||||
assert result == {
|
||||
"success": True,
|
||||
"platform": "matrix",
|
||||
"chat_id": "!room:example.com",
|
||||
"message_id": "$abc123",
|
||||
}
|
||||
session.put.assert_called_once()
|
||||
call_kwargs = session.put.call_args
|
||||
url = call_kwargs[0][0]
|
||||
assert url.startswith("https://matrix.example.com/_matrix/client/v3/rooms/!room:example.com/send/m.room.message/")
|
||||
assert call_kwargs[1]["headers"]["Authorization"] == "Bearer syt_tok"
|
||||
assert call_kwargs[1]["json"] == {"msgtype": "m.text", "body": "hello matrix"}
|
||||
|
||||
def test_http_error(self):
|
||||
resp = _make_aiohttp_resp(403, text_data="Forbidden")
|
||||
session_ctx, _ = _make_aiohttp_session(resp)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx):
|
||||
result = asyncio.run(_send_matrix(
|
||||
"tok", {"homeserver": "https://matrix.example.com"},
|
||||
"!room:example.com", "hi"
|
||||
))
|
||||
|
||||
assert "error" in result
|
||||
assert "403" in result["error"]
|
||||
assert "Forbidden" in result["error"]
|
||||
|
||||
def test_missing_config(self):
|
||||
with patch.dict(os.environ, {"MATRIX_HOMESERVER": "", "MATRIX_ACCESS_TOKEN": ""}, clear=False):
|
||||
result = asyncio.run(_send_matrix("", {}, "!room:example.com", "hi"))
|
||||
|
||||
assert "error" in result
|
||||
assert "MATRIX_HOMESERVER" in result["error"] or "not configured" in result["error"]
|
||||
|
||||
def test_env_var_fallback(self):
|
||||
resp = _make_aiohttp_resp(200, json_data={"event_id": "$ev1"})
|
||||
session_ctx, session = _make_aiohttp_session(resp)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx), \
|
||||
patch.dict(os.environ, {
|
||||
"MATRIX_HOMESERVER": "https://matrix.env.com",
|
||||
"MATRIX_ACCESS_TOKEN": "env-tok",
|
||||
}, clear=False):
|
||||
result = asyncio.run(_send_matrix("", {}, "!r:env.com", "hi"))
|
||||
|
||||
assert result["success"] is True
|
||||
url = session.put.call_args[0][0]
|
||||
assert "matrix.env.com" in url
|
||||
|
||||
def test_txn_id_is_unique_across_calls(self):
|
||||
"""Each call should generate a distinct transaction ID in the URL."""
|
||||
txn_ids = []
|
||||
|
||||
def capture(*args, **kwargs):
|
||||
url = args[0]
|
||||
txn_ids.append(url.rsplit("/", 1)[-1])
|
||||
ctx = MagicMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=_make_aiohttp_resp(200, json_data={"event_id": "$x"}))
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
return ctx
|
||||
|
||||
session = MagicMock()
|
||||
session.put = capture
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__aenter__ = AsyncMock(return_value=session)
|
||||
session_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
extra = {"homeserver": "https://matrix.example.com"}
|
||||
|
||||
import time
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx):
|
||||
asyncio.run(_send_matrix("tok", extra, "!r:example.com", "first"))
|
||||
time.sleep(0.002)
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx):
|
||||
asyncio.run(_send_matrix("tok", extra, "!r:example.com", "second"))
|
||||
|
||||
assert len(txn_ids) == 2
|
||||
assert txn_ids[0] != txn_ids[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_homeassistant
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendHomeAssistant:
|
||||
def test_success(self):
|
||||
resp = _make_aiohttp_resp(200)
|
||||
session_ctx, session = _make_aiohttp_session(resp)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx), \
|
||||
patch.dict(os.environ, {"HASS_URL": "", "HASS_TOKEN": ""}, clear=False):
|
||||
extra = {"url": "https://hass.example.com"}
|
||||
result = asyncio.run(_send_homeassistant("hass-tok", extra, "mobile_app_phone", "alert!"))
|
||||
|
||||
assert result == {"success": True, "platform": "homeassistant", "chat_id": "mobile_app_phone"}
|
||||
session.post.assert_called_once()
|
||||
call_kwargs = session.post.call_args
|
||||
assert call_kwargs[0][0] == "https://hass.example.com/api/services/notify/notify"
|
||||
assert call_kwargs[1]["headers"]["Authorization"] == "Bearer hass-tok"
|
||||
assert call_kwargs[1]["json"] == {"message": "alert!", "target": "mobile_app_phone"}
|
||||
|
||||
def test_http_error(self):
|
||||
resp = _make_aiohttp_resp(401, text_data="Unauthorized")
|
||||
session_ctx, _ = _make_aiohttp_session(resp)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx):
|
||||
result = asyncio.run(_send_homeassistant(
|
||||
"bad-tok", {"url": "https://hass.example.com"},
|
||||
"target", "msg"
|
||||
))
|
||||
|
||||
assert "error" in result
|
||||
assert "401" in result["error"]
|
||||
assert "Unauthorized" in result["error"]
|
||||
|
||||
def test_missing_config(self):
|
||||
with patch.dict(os.environ, {"HASS_URL": "", "HASS_TOKEN": ""}, clear=False):
|
||||
result = asyncio.run(_send_homeassistant("", {}, "target", "msg"))
|
||||
|
||||
assert "error" in result
|
||||
assert "HASS_URL" in result["error"] or "not configured" in result["error"]
|
||||
|
||||
def test_env_var_fallback(self):
|
||||
resp = _make_aiohttp_resp(200)
|
||||
session_ctx, session = _make_aiohttp_session(resp)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session_ctx), \
|
||||
patch.dict(os.environ, {"HASS_URL": "https://hass.env.com", "HASS_TOKEN": "env-tok"}, clear=False):
|
||||
result = asyncio.run(_send_homeassistant("", {}, "notify_target", "hi"))
|
||||
|
||||
assert result["success"] is True
|
||||
url = session.post.call_args[0][0]
|
||||
assert "hass.env.com" in url
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_dingtalk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendDingtalk:
|
||||
def _make_httpx_resp(self, status_code=200, json_data=None):
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.json = MagicMock(return_value=json_data or {"errcode": 0, "errmsg": "ok"})
|
||||
resp.raise_for_status = MagicMock()
|
||||
return resp
|
||||
|
||||
def _make_httpx_client(self, resp):
|
||||
client = AsyncMock()
|
||||
client.post = AsyncMock(return_value=resp)
|
||||
client_ctx = MagicMock()
|
||||
client_ctx.__aenter__ = AsyncMock(return_value=client)
|
||||
client_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
return client_ctx, client
|
||||
|
||||
def test_success(self):
|
||||
resp = self._make_httpx_resp(json_data={"errcode": 0, "errmsg": "ok"})
|
||||
client_ctx, client = self._make_httpx_client(resp)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=client_ctx):
|
||||
extra = {"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=abc"}
|
||||
result = asyncio.run(_send_dingtalk(extra, "ignored", "hello dingtalk"))
|
||||
|
||||
assert result == {"success": True, "platform": "dingtalk", "chat_id": "ignored"}
|
||||
client.post.assert_awaited_once()
|
||||
call_kwargs = client.post.await_args
|
||||
assert call_kwargs[0][0] == "https://oapi.dingtalk.com/robot/send?access_token=abc"
|
||||
assert call_kwargs[1]["json"] == {"msgtype": "text", "text": {"content": "hello dingtalk"}}
|
||||
|
||||
def test_api_error_in_response_body(self):
|
||||
"""DingTalk always returns HTTP 200 but signals errors via errcode."""
|
||||
resp = self._make_httpx_resp(json_data={"errcode": 310000, "errmsg": "sign not match"})
|
||||
client_ctx, _ = self._make_httpx_client(resp)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=client_ctx):
|
||||
result = asyncio.run(_send_dingtalk(
|
||||
{"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=bad"},
|
||||
"ch", "hi"
|
||||
))
|
||||
|
||||
assert "error" in result
|
||||
assert "sign not match" in result["error"]
|
||||
|
||||
def test_http_error(self):
|
||||
"""If raise_for_status throws, the error is caught and returned."""
|
||||
resp = self._make_httpx_resp(status_code=429)
|
||||
resp.raise_for_status = MagicMock(side_effect=Exception("429 Too Many Requests"))
|
||||
client_ctx, _ = self._make_httpx_client(resp)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=client_ctx):
|
||||
result = asyncio.run(_send_dingtalk(
|
||||
{"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=tok"},
|
||||
"ch", "hi"
|
||||
))
|
||||
|
||||
assert "error" in result
|
||||
assert "DingTalk send failed" in result["error"]
|
||||
|
||||
def test_missing_config(self):
|
||||
with patch.dict(os.environ, {"DINGTALK_WEBHOOK_URL": ""}, clear=False):
|
||||
result = asyncio.run(_send_dingtalk({}, "ch", "hi"))
|
||||
|
||||
assert "error" in result
|
||||
assert "DINGTALK_WEBHOOK_URL" in result["error"] or "not configured" in result["error"]
|
||||
|
||||
def test_env_var_fallback(self):
|
||||
resp = self._make_httpx_resp(json_data={"errcode": 0, "errmsg": "ok"})
|
||||
client_ctx, client = self._make_httpx_client(resp)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=client_ctx), \
|
||||
patch.dict(os.environ, {"DINGTALK_WEBHOOK_URL": "https://oapi.dingtalk.com/robot/send?access_token=env"}, clear=False):
|
||||
result = asyncio.run(_send_dingtalk({}, "ch", "hi"))
|
||||
|
||||
assert result["success"] is True
|
||||
call_kwargs = client.post.await_args
|
||||
assert "access_token=env" in call_kwargs[0][0]
|
||||
@@ -96,6 +96,7 @@ class TestGetProviderFallbackPriority:
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "groq"
|
||||
@@ -130,9 +131,10 @@ class TestExplicitProviderRespected:
|
||||
def test_explicit_local_no_fallback_to_openai(self, monkeypatch):
|
||||
"""GH-1774: provider=local must not silently fall back to openai
|
||||
even when an OpenAI API key is set."""
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key-here")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "***")
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
result = _get_provider({"provider": "local"})
|
||||
@@ -141,6 +143,7 @@ class TestExplicitProviderRespected:
|
||||
def test_explicit_local_no_fallback_to_groq(self, monkeypatch):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
result = _get_provider({"provider": "local"})
|
||||
@@ -181,6 +184,7 @@ class TestExplicitProviderRespected:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key")
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
# Empty dict = no explicit provider, uses DEFAULT_PROVIDER auto-detect
|
||||
@@ -191,6 +195,7 @@ class TestExplicitProviderRespected:
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
result = _get_provider({})
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
"""Credential file passthrough registry for remote terminal backends.
|
||||
|
||||
Skills that declare ``required_credential_files`` in their frontmatter need
|
||||
those files available inside sandboxed execution environments (Modal, Docker).
|
||||
By default remote backends create bare containers with no host files.
|
||||
|
||||
This module provides a session-scoped registry so skill-declared credential
|
||||
files (and user-configured overrides) are mounted into remote sandboxes.
|
||||
|
||||
Two sources feed the registry:
|
||||
|
||||
1. **Skill declarations** — when a skill is loaded via ``skill_view``, its
|
||||
``required_credential_files`` entries are registered here if the files
|
||||
exist on the host.
|
||||
2. **User config** — ``terminal.credential_files`` in config.yaml lets users
|
||||
explicitly list additional files to mount.
|
||||
|
||||
Remote backends (``tools/environments/modal.py``, ``docker.py``) call
|
||||
:func:`get_credential_file_mounts` at sandbox creation time.
|
||||
|
||||
Each registered entry is a dict::
|
||||
|
||||
{
|
||||
"host_path": "/home/user/.hermes/google_token.json",
|
||||
"container_path": "/root/.hermes/google_token.json",
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Session-scoped list of credential files to mount.
|
||||
# Key: container_path (deduplicated), Value: host_path
|
||||
_registered_files: Dict[str, str] = {}
|
||||
|
||||
# Cache for config-based file list (loaded once per process).
|
||||
_config_files: List[Dict[str, str]] | None = None
|
||||
|
||||
|
||||
def _resolve_hermes_home() -> Path:
|
||||
return Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def register_credential_file(
|
||||
relative_path: str,
|
||||
container_base: str = "/root/.hermes",
|
||||
) -> bool:
|
||||
"""Register a credential file for mounting into remote sandboxes.
|
||||
|
||||
*relative_path* is relative to ``HERMES_HOME`` (e.g. ``google_token.json``).
|
||||
Returns True if the file exists on the host and was registered.
|
||||
"""
|
||||
hermes_home = _resolve_hermes_home()
|
||||
host_path = hermes_home / relative_path
|
||||
if not host_path.is_file():
|
||||
logger.debug("credential_files: skipping %s (not found)", host_path)
|
||||
return False
|
||||
|
||||
container_path = f"{container_base.rstrip('/')}/{relative_path}"
|
||||
_registered_files[container_path] = str(host_path)
|
||||
logger.debug("credential_files: registered %s -> %s", host_path, container_path)
|
||||
return True
|
||||
|
||||
|
||||
def register_credential_files(
|
||||
entries: list,
|
||||
container_base: str = "/root/.hermes",
|
||||
) -> List[str]:
|
||||
"""Register multiple credential files from skill frontmatter entries.
|
||||
|
||||
Each entry is either a string (relative path) or a dict with a ``path``
|
||||
key. Returns the list of relative paths that were NOT found on the host
|
||||
(i.e. missing files).
|
||||
"""
|
||||
missing = []
|
||||
for entry in entries:
|
||||
if isinstance(entry, str):
|
||||
rel_path = entry.strip()
|
||||
elif isinstance(entry, dict):
|
||||
rel_path = (entry.get("path") or "").strip()
|
||||
else:
|
||||
continue
|
||||
if not rel_path:
|
||||
continue
|
||||
if not register_credential_file(rel_path, container_base):
|
||||
missing.append(rel_path)
|
||||
return missing
|
||||
|
||||
|
||||
def _load_config_files() -> List[Dict[str, str]]:
|
||||
"""Load ``terminal.credential_files`` from config.yaml (cached)."""
|
||||
global _config_files
|
||||
if _config_files is not None:
|
||||
return _config_files
|
||||
|
||||
result: List[Dict[str, str]] = []
|
||||
try:
|
||||
hermes_home = _resolve_hermes_home()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
if config_path.exists():
|
||||
import yaml
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
cred_files = cfg.get("terminal", {}).get("credential_files")
|
||||
if isinstance(cred_files, list):
|
||||
for item in cred_files:
|
||||
if isinstance(item, str) and item.strip():
|
||||
host_path = hermes_home / item.strip()
|
||||
if host_path.is_file():
|
||||
container_path = f"/root/.hermes/{item.strip()}"
|
||||
result.append({
|
||||
"host_path": str(host_path),
|
||||
"container_path": container_path,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug("Could not read terminal.credential_files from config: %s", e)
|
||||
|
||||
_config_files = result
|
||||
return _config_files
|
||||
|
||||
|
||||
def get_credential_file_mounts() -> List[Dict[str, str]]:
|
||||
"""Return all credential files that should be mounted into remote sandboxes.
|
||||
|
||||
Each item has ``host_path`` and ``container_path`` keys.
|
||||
Combines skill-registered files and user config.
|
||||
"""
|
||||
mounts: Dict[str, str] = {}
|
||||
|
||||
# Skill-registered files
|
||||
for container_path, host_path in _registered_files.items():
|
||||
# Re-check existence (file may have been deleted since registration)
|
||||
if Path(host_path).is_file():
|
||||
mounts[container_path] = host_path
|
||||
|
||||
# Config-based files
|
||||
for entry in _load_config_files():
|
||||
cp = entry["container_path"]
|
||||
if cp not in mounts and Path(entry["host_path"]).is_file():
|
||||
mounts[cp] = entry["host_path"]
|
||||
|
||||
return [
|
||||
{"host_path": hp, "container_path": cp}
|
||||
for cp, hp in mounts.items()
|
||||
]
|
||||
|
||||
|
||||
def clear_credential_files() -> None:
|
||||
"""Reset the skill-scoped registry (e.g. on session reset)."""
|
||||
_registered_files.clear()
|
||||
|
||||
|
||||
def reset_config_cache() -> None:
|
||||
"""Force re-read of config on next access (for testing)."""
|
||||
global _config_files
|
||||
_config_files = None
|
||||
@@ -312,6 +312,24 @@ class DockerEnvironment(BaseEnvironment):
|
||||
elif workspace_explicitly_mounted:
|
||||
logger.debug("Skipping docker cwd mount: /workspace already mounted by user config")
|
||||
|
||||
# Mount credential files (OAuth tokens, etc.) declared by skills.
|
||||
# Read-only so the container can authenticate but not modify host creds.
|
||||
try:
|
||||
from tools.credential_files import get_credential_file_mounts
|
||||
|
||||
for mount_entry in get_credential_file_mounts():
|
||||
volume_args.extend([
|
||||
"-v",
|
||||
f"{mount_entry['host_path']}:{mount_entry['container_path']}:ro",
|
||||
])
|
||||
logger.info(
|
||||
"Docker: mounting credential %s -> %s",
|
||||
mount_entry["host_path"],
|
||||
mount_entry["container_path"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Docker: could not load credential file mounts: %s", e)
|
||||
|
||||
logger.info(f"Docker volume_args: {volume_args}")
|
||||
all_run_args = list(_SECURITY_ARGS) + writable_args + resource_args + volume_args
|
||||
logger.info(f"Docker run_args: {all_run_args}")
|
||||
@@ -406,8 +424,17 @@ class DockerEnvironment(BaseEnvironment):
|
||||
if effective_stdin is not None:
|
||||
cmd.append("-i")
|
||||
cmd.extend(["-w", work_dir])
|
||||
hermes_env = _load_hermes_env_vars() if self._forward_env else {}
|
||||
for key in self._forward_env:
|
||||
# Combine explicit docker_forward_env with skill-declared env_passthrough
|
||||
# vars so skills that declare required_environment_variables (e.g. Notion)
|
||||
# have their keys forwarded into the container automatically.
|
||||
forward_keys = set(self._forward_env)
|
||||
try:
|
||||
from tools.env_passthrough import get_all_passthrough
|
||||
forward_keys |= get_all_passthrough()
|
||||
except Exception:
|
||||
pass
|
||||
hermes_env = _load_hermes_env_vars() if forward_keys else {}
|
||||
for key in sorted(forward_keys):
|
||||
value = os.getenv(key)
|
||||
if value is None:
|
||||
value = hermes_env.get(key)
|
||||
|
||||
@@ -137,6 +137,28 @@ class ModalEnvironment(BaseEnvironment):
|
||||
],
|
||||
)
|
||||
|
||||
# Mount credential files (OAuth tokens, etc.) declared by skills.
|
||||
# These are read-only copies so the sandbox can authenticate with
|
||||
# external services but can't modify the host's credentials.
|
||||
cred_mounts = []
|
||||
try:
|
||||
from tools.credential_files import get_credential_file_mounts
|
||||
|
||||
for mount_entry in get_credential_file_mounts():
|
||||
cred_mounts.append(
|
||||
_modal.Mount.from_local_file(
|
||||
mount_entry["host_path"],
|
||||
remote_path=mount_entry["container_path"],
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Modal: mounting credential %s -> %s",
|
||||
mount_entry["host_path"],
|
||||
mount_entry["container_path"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Modal: could not load credential file mounts: %s", e)
|
||||
|
||||
# Start the async worker thread and create sandbox on it
|
||||
# so all gRPC channels are bound to the worker's event loop.
|
||||
self._worker.start()
|
||||
@@ -145,23 +167,90 @@ class ModalEnvironment(BaseEnvironment):
|
||||
app = await _modal.App.lookup.aio(
|
||||
"hermes-agent", create_if_missing=True
|
||||
)
|
||||
create_kwargs = dict(sandbox_kwargs)
|
||||
if cred_mounts:
|
||||
existing_mounts = list(create_kwargs.pop("mounts", []))
|
||||
existing_mounts.extend(cred_mounts)
|
||||
create_kwargs["mounts"] = existing_mounts
|
||||
sandbox = await _modal.Sandbox.create.aio(
|
||||
"sleep", "infinity",
|
||||
image=effective_image,
|
||||
app=app,
|
||||
timeout=int(sandbox_kwargs.pop("timeout", 3600)),
|
||||
**sandbox_kwargs,
|
||||
timeout=int(create_kwargs.pop("timeout", 3600)),
|
||||
**create_kwargs,
|
||||
)
|
||||
return app, sandbox
|
||||
|
||||
self._app, self._sandbox = self._worker.run_coroutine(
|
||||
_create_sandbox(), timeout=300
|
||||
)
|
||||
# Track synced credential files to avoid redundant pushes.
|
||||
# Key: container_path, Value: (mtime, size) of last synced version.
|
||||
self._synced_creds: Dict[str, tuple] = {}
|
||||
logger.info("Modal: sandbox created (task=%s)", self._task_id)
|
||||
|
||||
def _sync_credential_files(self) -> None:
|
||||
"""Push credential files into the running sandbox.
|
||||
|
||||
Mounts are set at sandbox creation, but credentials may be created
|
||||
later (e.g. OAuth setup mid-session). This writes the current file
|
||||
content into the sandbox via exec(), so new/updated credentials are
|
||||
available without recreating the sandbox.
|
||||
"""
|
||||
try:
|
||||
from tools.credential_files import get_credential_file_mounts
|
||||
|
||||
mounts = get_credential_file_mounts()
|
||||
if not mounts:
|
||||
return
|
||||
|
||||
for entry in mounts:
|
||||
host_path = entry["host_path"]
|
||||
container_path = entry["container_path"]
|
||||
hp = Path(host_path)
|
||||
try:
|
||||
stat = hp.stat()
|
||||
file_key = (stat.st_mtime, stat.st_size)
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
# Skip if already synced with same mtime+size
|
||||
if self._synced_creds.get(container_path) == file_key:
|
||||
continue
|
||||
|
||||
try:
|
||||
content = hp.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Write via base64 to avoid shell escaping issues with JSON
|
||||
import base64
|
||||
b64 = base64.b64encode(content.encode("utf-8")).decode("ascii")
|
||||
container_dir = str(Path(container_path).parent)
|
||||
cmd = (
|
||||
f"mkdir -p {shlex.quote(container_dir)} && "
|
||||
f"echo {shlex.quote(b64)} | base64 -d > {shlex.quote(container_path)}"
|
||||
)
|
||||
|
||||
_cp = container_path # capture for closure
|
||||
|
||||
async def _write():
|
||||
proc = await self._sandbox.exec.aio("bash", "-c", cmd)
|
||||
await proc.wait.aio()
|
||||
|
||||
self._worker.run_coroutine(_write(), timeout=15)
|
||||
self._synced_creds[container_path] = file_key
|
||||
logger.debug("Modal: synced credential %s -> %s", host_path, container_path)
|
||||
except Exception as e:
|
||||
logger.debug("Modal: credential file sync failed: %s", e)
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
# Sync credential files before each command so mid-session
|
||||
# OAuth setups are picked up without requiring a restart.
|
||||
self._sync_credential_files()
|
||||
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
|
||||
+17
-2
@@ -45,8 +45,23 @@ def clear_session_context() -> None:
|
||||
# ── Availability check ──
|
||||
|
||||
def _check_honcho_available() -> bool:
|
||||
"""Tool is only available when Honcho is active."""
|
||||
return _session_manager is not None and _session_key is not None
|
||||
"""Tool is available when Honcho is active OR configured.
|
||||
|
||||
At banner time the session context hasn't been injected yet, but if
|
||||
a valid config exists the tools *will* activate once the agent starts.
|
||||
Returning True for "configured" prevents the banner from marking
|
||||
honcho tools as red/disabled when they're actually going to work.
|
||||
"""
|
||||
# Fast path: session already active (mid-conversation)
|
||||
if _session_manager is not None and _session_key is not None:
|
||||
return True
|
||||
# Slow path: check if Honcho is configured (banner time)
|
||||
try:
|
||||
from honcho_integration.client import HonchoClientConfig
|
||||
cfg = HonchoClientConfig.from_global_config()
|
||||
return cfg.enabled and bool(cfg.api_key or cfg.base_url)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_session_context(**kwargs):
|
||||
|
||||
+141
-17
@@ -70,6 +70,7 @@ Thread safety:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
@@ -89,6 +90,8 @@ logger = logging.getLogger(__name__)
|
||||
_MCP_AVAILABLE = False
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
_MCP_SAMPLING_TYPES = False
|
||||
_MCP_NOTIFICATION_TYPES = False
|
||||
_MCP_MESSAGE_HANDLER_SUPPORTED = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
@@ -119,9 +122,39 @@ try:
|
||||
_MCP_SAMPLING_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
||||
# Notification types for dynamic tool discovery (tools/list_changed)
|
||||
try:
|
||||
from mcp.types import (
|
||||
ServerNotification,
|
||||
ToolListChangedNotification,
|
||||
PromptListChangedNotification,
|
||||
ResourceListChangedNotification,
|
||||
)
|
||||
_MCP_NOTIFICATION_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP notification types not available -- dynamic tool discovery disabled")
|
||||
except ImportError:
|
||||
logger.debug("mcp package not installed -- MCP tool support disabled")
|
||||
|
||||
|
||||
def _check_message_handler_support() -> bool:
|
||||
"""Check if ClientSession accepts ``message_handler`` kwarg.
|
||||
|
||||
Inspects the constructor signature for backward compatibility with older
|
||||
MCP SDK versions that don't support notification handlers.
|
||||
"""
|
||||
if not _MCP_AVAILABLE:
|
||||
return False
|
||||
try:
|
||||
return "message_handler" in inspect.signature(ClientSession).parameters
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
_MCP_MESSAGE_HANDLER_SUPPORTED = _check_message_handler_support()
|
||||
if _MCP_AVAILABLE and not _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
logger.debug("MCP SDK does not support message_handler -- dynamic tool discovery disabled")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -697,7 +730,7 @@ class MCPServerTask:
|
||||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_sampling", "_registered_tool_names", "_auth_type",
|
||||
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
@@ -713,11 +746,80 @@ class MCPServerTask:
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
self._registered_tool_names: list[str] = []
|
||||
self._auth_type: str = ""
|
||||
self._refresh_lock = asyncio.Lock()
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
return "url" in self._config
|
||||
|
||||
# ----- Dynamic tool discovery (notifications/tools/list_changed) -----
|
||||
|
||||
def _make_message_handler(self):
|
||||
"""Build a ``message_handler`` callback for ``ClientSession``.
|
||||
|
||||
Dispatches on notification type. Only ``ToolListChangedNotification``
|
||||
triggers a refresh; prompt and resource change notifications are
|
||||
logged as stubs for future work.
|
||||
"""
|
||||
async def _handler(message):
|
||||
try:
|
||||
if isinstance(message, Exception):
|
||||
logger.debug("MCP message handler (%s): exception: %s", self.name, message)
|
||||
return
|
||||
if _MCP_NOTIFICATION_TYPES and isinstance(message, ServerNotification):
|
||||
match message.root:
|
||||
case ToolListChangedNotification():
|
||||
logger.info(
|
||||
"MCP server '%s': received tools/list_changed notification",
|
||||
self.name,
|
||||
)
|
||||
await self._refresh_tools()
|
||||
case PromptListChangedNotification():
|
||||
logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
|
||||
case ResourceListChangedNotification():
|
||||
logger.debug("MCP server '%s': resources/list_changed (ignored)", self.name)
|
||||
case _:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Error in MCP message handler for '%s'", self.name)
|
||||
return _handler
|
||||
|
||||
async def _refresh_tools(self):
|
||||
"""Re-fetch tools from the server and update the registry.
|
||||
|
||||
Called when the server sends ``notifications/tools/list_changed``.
|
||||
The lock prevents overlapping refreshes from rapid-fire notifications.
|
||||
After the initial ``await`` (list_tools), all mutations are synchronous
|
||||
— atomic from the event loop's perspective.
|
||||
"""
|
||||
from tools.registry import registry
|
||||
from toolsets import TOOLSETS
|
||||
|
||||
async with self._refresh_lock:
|
||||
# 1. Fetch current tool list from server
|
||||
tools_result = await self.session.list_tools()
|
||||
new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
||||
|
||||
# 2. Remove old tools from hermes-* umbrella toolsets
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if ts_name.startswith("hermes-"):
|
||||
ts["tools"] = [t for t in ts["tools"] if t not in self._registered_tool_names]
|
||||
|
||||
# 3. Deregister old tools from the central registry
|
||||
for prefixed_name in self._registered_tool_names:
|
||||
registry.deregister(prefixed_name)
|
||||
|
||||
# 4. Re-register with fresh tool list
|
||||
self._tools = new_mcp_tools
|
||||
self._registered_tool_names = _register_server_tools(
|
||||
self.name, self, self._config
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"MCP server '%s': dynamically refreshed %d tool(s)",
|
||||
self.name, len(self._registered_tool_names),
|
||||
)
|
||||
|
||||
async def _run_stdio(self, config: dict):
|
||||
"""Run the server using stdio transport."""
|
||||
command = config.get("command")
|
||||
@@ -738,6 +840,8 @@ class MCPServerTask:
|
||||
)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
sampling_kwargs["message_handler"] = self._make_message_handler()
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
@@ -769,6 +873,8 @@ class MCPServerTask:
|
||||
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
sampling_kwargs["message_handler"] = self._make_message_handler()
|
||||
|
||||
if _MCP_NEW_HTTP:
|
||||
# New API (mcp >= 1.24.0): build an explicit httpx.AsyncClient
|
||||
@@ -1522,24 +1628,19 @@ def _existing_tool_names() -> List[str]:
|
||||
return names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
"""Connect to a single MCP server, discover tools, and register them.
|
||||
def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> List[str]:
|
||||
"""Register tools from an already-connected server into the registry.
|
||||
|
||||
Also registers utility tools for MCP Resources and Prompts support
|
||||
(list_resources, read_resource, list_prompts, get_prompt).
|
||||
Handles include/exclude filtering, utility tools, toolset creation,
|
||||
and hermes-* umbrella toolset injection.
|
||||
|
||||
Returns list of registered tool names.
|
||||
Used by both initial discovery and dynamic refresh (list_changed).
|
||||
|
||||
Returns:
|
||||
List of registered prefixed tool names.
|
||||
"""
|
||||
from tools.registry import registry
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
server = await asyncio.wait_for(
|
||||
_connect_server(name, config),
|
||||
timeout=connect_timeout,
|
||||
)
|
||||
with _lock:
|
||||
_servers[name] = server
|
||||
from toolsets import create_custom_toolset, TOOLSETS
|
||||
|
||||
registered_names: List[str] = []
|
||||
toolset_name = f"mcp-{name}"
|
||||
@@ -1625,8 +1726,6 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
)
|
||||
registered_names.append(util_name)
|
||||
|
||||
server._registered_tool_names = list(registered_names)
|
||||
|
||||
# Create a custom toolset so these tools are discoverable
|
||||
if registered_names:
|
||||
create_custom_toolset(
|
||||
@@ -1634,6 +1733,31 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
description=f"MCP tools from {name} server",
|
||||
tools=registered_names,
|
||||
)
|
||||
# Inject into hermes-* umbrella toolsets for default behavior
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if ts_name.startswith("hermes-"):
|
||||
for tool_name in registered_names:
|
||||
if tool_name not in ts["tools"]:
|
||||
ts["tools"].append(tool_name)
|
||||
|
||||
return registered_names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
"""Connect to a single MCP server, discover tools, and register them.
|
||||
|
||||
Returns list of registered tool names.
|
||||
"""
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
server = await asyncio.wait_for(
|
||||
_connect_server(name, config),
|
||||
timeout=connect_timeout,
|
||||
)
|
||||
with _lock:
|
||||
_servers[name] = server
|
||||
|
||||
registered_names = _register_server_tools(name, server, config)
|
||||
server._registered_tool_names = list(registered_names)
|
||||
|
||||
transport_type = "HTTP" if "url" in config else "stdio"
|
||||
logger.info(
|
||||
|
||||
+29
-1
@@ -87,6 +87,23 @@ class ToolRegistry:
|
||||
if check_fn and toolset not in self._toolset_checks:
|
||||
self._toolset_checks[toolset] = check_fn
|
||||
|
||||
def deregister(self, name: str) -> None:
|
||||
"""Remove a tool from the registry.
|
||||
|
||||
Also cleans up the toolset check if no other tools remain in the
|
||||
same toolset. Used by MCP dynamic tool discovery to nuke-and-repave
|
||||
when a server sends ``notifications/tools/list_changed``.
|
||||
"""
|
||||
entry = self._tools.pop(name, None)
|
||||
if entry is None:
|
||||
return
|
||||
# Drop the toolset check if this was the last tool in that toolset
|
||||
if entry.toolset in self._toolset_checks and not any(
|
||||
e.toolset == entry.toolset for e in self._tools.values()
|
||||
):
|
||||
self._toolset_checks.pop(entry.toolset, None)
|
||||
logger.debug("Deregistered tool: %s", name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Schema retrieval
|
||||
# ------------------------------------------------------------------
|
||||
@@ -115,7 +132,9 @@ class ToolRegistry:
|
||||
if not quiet:
|
||||
logger.debug("Tool %s unavailable (check failed)", name)
|
||||
continue
|
||||
result.append({"type": "function", "function": entry.schema})
|
||||
# Ensure schema always has a "name" field — use entry.name as fallback
|
||||
schema_with_name = {**entry.schema, "name": entry.name}
|
||||
result.append({"type": "function", "function": schema_with_name})
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -149,6 +168,15 @@ class ToolRegistry:
|
||||
"""Return sorted list of all registered tool names."""
|
||||
return sorted(self._tools.keys())
|
||||
|
||||
def get_schema(self, name: str) -> Optional[dict]:
|
||||
"""Return a tool's raw schema dict, bypassing check_fn filtering.
|
||||
|
||||
Useful for token estimation and introspection where availability
|
||||
doesn't matter — only the schema content does.
|
||||
"""
|
||||
entry = self._tools.get(name)
|
||||
return entry.schema if entry else None
|
||||
|
||||
def get_toolset_for_tool(self, name: str) -> Optional[str]:
|
||||
"""Return the toolset a tool belongs to, or None."""
|
||||
entry = self._tools.get(name)
|
||||
|
||||
@@ -343,6 +343,14 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
||||
result = await _send_email(pconfig.extra, chat_id, chunk)
|
||||
elif platform == Platform.SMS:
|
||||
result = await _send_sms(pconfig.api_key, chat_id, chunk)
|
||||
elif platform == Platform.MATTERMOST:
|
||||
result = await _send_mattermost(pconfig.token, pconfig.extra, chat_id, chunk)
|
||||
elif platform == Platform.MATRIX:
|
||||
result = await _send_matrix(pconfig.token, pconfig.extra, chat_id, chunk)
|
||||
elif platform == Platform.HOMEASSISTANT:
|
||||
result = await _send_homeassistant(pconfig.token, pconfig.extra, chat_id, chunk)
|
||||
elif platform == Platform.DINGTALK:
|
||||
result = await _send_dingtalk(pconfig.extra, chat_id, chunk)
|
||||
else:
|
||||
result = {"error": f"Direct sending not yet implemented for {platform.value}"}
|
||||
|
||||
@@ -666,6 +674,109 @@ async def _send_sms(auth_token, chat_id, message):
|
||||
return {"error": f"SMS send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_mattermost(token, extra, chat_id, message):
|
||||
"""Send via Mattermost REST API."""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
|
||||
try:
|
||||
base_url = (extra.get("url") or os.getenv("MATTERMOST_URL", "")).rstrip("/")
|
||||
token = token or os.getenv("MATTERMOST_TOKEN", "")
|
||||
if not base_url or not token:
|
||||
return {"error": "Mattermost not configured (MATTERMOST_URL, MATTERMOST_TOKEN required)"}
|
||||
url = f"{base_url}/api/v4/posts"
|
||||
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
|
||||
async with session.post(url, headers=headers, json={"channel_id": chat_id, "message": message}) as resp:
|
||||
if resp.status not in (200, 201):
|
||||
body = await resp.text()
|
||||
return {"error": f"Mattermost API error ({resp.status}): {body}"}
|
||||
data = await resp.json()
|
||||
return {"success": True, "platform": "mattermost", "chat_id": chat_id, "message_id": data.get("id")}
|
||||
except Exception as e:
|
||||
return {"error": f"Mattermost send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_matrix(token, extra, chat_id, message):
|
||||
"""Send via Matrix Client-Server API."""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
|
||||
try:
|
||||
homeserver = (extra.get("homeserver") or os.getenv("MATRIX_HOMESERVER", "")).rstrip("/")
|
||||
token = token or os.getenv("MATRIX_ACCESS_TOKEN", "")
|
||||
if not homeserver or not token:
|
||||
return {"error": "Matrix not configured (MATRIX_HOMESERVER, MATRIX_ACCESS_TOKEN required)"}
|
||||
txn_id = f"hermes_{int(time.time() * 1000)}"
|
||||
url = f"{homeserver}/_matrix/client/v3/rooms/{chat_id}/send/m.room.message/{txn_id}"
|
||||
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
|
||||
async with session.put(url, headers=headers, json={"msgtype": "m.text", "body": message}) as resp:
|
||||
if resp.status not in (200, 201):
|
||||
body = await resp.text()
|
||||
return {"error": f"Matrix API error ({resp.status}): {body}"}
|
||||
data = await resp.json()
|
||||
return {"success": True, "platform": "matrix", "chat_id": chat_id, "message_id": data.get("event_id")}
|
||||
except Exception as e:
|
||||
return {"error": f"Matrix send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_homeassistant(token, extra, chat_id, message):
|
||||
"""Send via Home Assistant notify service."""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
|
||||
try:
|
||||
hass_url = (extra.get("url") or os.getenv("HASS_URL", "")).rstrip("/")
|
||||
token = token or os.getenv("HASS_TOKEN", "")
|
||||
if not hass_url or not token:
|
||||
return {"error": "Home Assistant not configured (HASS_URL, HASS_TOKEN required)"}
|
||||
url = f"{hass_url}/api/services/notify/notify"
|
||||
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
|
||||
async with session.post(url, headers=headers, json={"message": message, "target": chat_id}) as resp:
|
||||
if resp.status not in (200, 201):
|
||||
body = await resp.text()
|
||||
return {"error": f"Home Assistant API error ({resp.status}): {body}"}
|
||||
return {"success": True, "platform": "homeassistant", "chat_id": chat_id}
|
||||
except Exception as e:
|
||||
return {"error": f"Home Assistant send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_dingtalk(extra, chat_id, message):
|
||||
"""Send via DingTalk robot webhook.
|
||||
|
||||
Note: The gateway's DingTalk adapter uses per-session webhook URLs from
|
||||
incoming messages (dingtalk-stream SDK). For cross-platform send_message
|
||||
delivery we use a static robot webhook URL instead, which must be
|
||||
configured via ``DINGTALK_WEBHOOK_URL`` env var or ``webhook_url`` in the
|
||||
platform's extra config.
|
||||
"""
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
return {"error": "httpx not installed"}
|
||||
try:
|
||||
webhook_url = extra.get("webhook_url") or os.getenv("DINGTALK_WEBHOOK_URL", "")
|
||||
if not webhook_url:
|
||||
return {"error": "DingTalk not configured. Set DINGTALK_WEBHOOK_URL env var or webhook_url in dingtalk platform extra config."}
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
webhook_url,
|
||||
json={"msgtype": "text", "text": {"content": message}},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if data.get("errcode", 0) != 0:
|
||||
return {"error": f"DingTalk API error: {data.get('errmsg', 'unknown')}"}
|
||||
return {"success": True, "platform": "dingtalk", "chat_id": chat_id}
|
||||
except Exception as e:
|
||||
return {"error": f"DingTalk send failed: {e}"}
|
||||
|
||||
|
||||
def _check_send_message():
|
||||
"""Gate send_message on gateway running (always available on messaging platforms)."""
|
||||
platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
|
||||
+131
-67
@@ -494,7 +494,7 @@ def _is_skill_disabled(name: str, platform: str = None) -> bool:
|
||||
|
||||
|
||||
def _find_all_skills(*, skip_disabled: bool = False) -> List[Dict[str, Any]]:
|
||||
"""Recursively find all skills in ~/.hermes/skills/.
|
||||
"""Recursively find all skills in ~/.hermes/skills/ and external dirs.
|
||||
|
||||
Args:
|
||||
skip_disabled: If True, return ALL skills regardless of disabled
|
||||
@@ -504,59 +504,68 @@ def _find_all_skills(*, skip_disabled: bool = False) -> List[Dict[str, Any]]:
|
||||
Returns:
|
||||
List of skill metadata dicts (name, description, category).
|
||||
"""
|
||||
skills = []
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
|
||||
if not SKILLS_DIR.exists():
|
||||
return skills
|
||||
skills = []
|
||||
seen_names: set = set()
|
||||
|
||||
# Load disabled set once (not per-skill)
|
||||
disabled = set() if skip_disabled else _get_disabled_skill_names()
|
||||
|
||||
# Scan local dir first, then external dirs (local takes precedence)
|
||||
dirs_to_scan = []
|
||||
if SKILLS_DIR.exists():
|
||||
dirs_to_scan.append(SKILLS_DIR)
|
||||
dirs_to_scan.extend(get_external_skills_dirs())
|
||||
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
if any(part in _EXCLUDED_SKILL_DIRS for part in skill_md.parts):
|
||||
continue
|
||||
|
||||
skill_dir = skill_md.parent
|
||||
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")[:4000]
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
|
||||
if not skill_matches_platform(frontmatter):
|
||||
for scan_dir in dirs_to_scan:
|
||||
for skill_md in scan_dir.rglob("SKILL.md"):
|
||||
if any(part in _EXCLUDED_SKILL_DIRS for part in skill_md.parts):
|
||||
continue
|
||||
|
||||
name = frontmatter.get("name", skill_dir.name)[:MAX_NAME_LENGTH]
|
||||
if name in disabled:
|
||||
skill_dir = skill_md.parent
|
||||
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")[:4000]
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
|
||||
if not skill_matches_platform(frontmatter):
|
||||
continue
|
||||
|
||||
name = frontmatter.get("name", skill_dir.name)[:MAX_NAME_LENGTH]
|
||||
if name in seen_names:
|
||||
continue
|
||||
if name in disabled:
|
||||
continue
|
||||
|
||||
description = frontmatter.get("description", "")
|
||||
if not description:
|
||||
for line in body.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
description = line
|
||||
break
|
||||
|
||||
if len(description) > MAX_DESCRIPTION_LENGTH:
|
||||
description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..."
|
||||
|
||||
category = _get_category_from_path(skill_md)
|
||||
|
||||
seen_names.add(name)
|
||||
skills.append({
|
||||
"name": name,
|
||||
"description": description,
|
||||
"category": category,
|
||||
})
|
||||
|
||||
except (UnicodeDecodeError, PermissionError) as e:
|
||||
logger.debug("Failed to read skill file %s: %s", skill_md, e)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Skipping skill at %s: failed to parse: %s", skill_md, e, exc_info=True
|
||||
)
|
||||
continue
|
||||
|
||||
description = frontmatter.get("description", "")
|
||||
if not description:
|
||||
for line in body.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
description = line
|
||||
break
|
||||
|
||||
if len(description) > MAX_DESCRIPTION_LENGTH:
|
||||
description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..."
|
||||
|
||||
category = _get_category_from_path(skill_md)
|
||||
|
||||
skills.append({
|
||||
"name": name,
|
||||
"description": description,
|
||||
"category": category,
|
||||
})
|
||||
|
||||
except (UnicodeDecodeError, PermissionError) as e:
|
||||
logger.debug("Failed to read skill file %s: %s", skill_md, e)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Skipping skill at %s: failed to parse: %s", skill_md, e, exc_info=True
|
||||
)
|
||||
continue
|
||||
|
||||
return skills
|
||||
|
||||
@@ -756,7 +765,15 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
||||
JSON string with skill content or error message
|
||||
"""
|
||||
try:
|
||||
if not SKILLS_DIR.exists():
|
||||
from agent.skill_utils import get_external_skills_dirs
|
||||
|
||||
# Build list of all skill directories to search
|
||||
all_dirs = []
|
||||
if SKILLS_DIR.exists():
|
||||
all_dirs.append(SKILLS_DIR)
|
||||
all_dirs.extend(get_external_skills_dirs())
|
||||
|
||||
if not all_dirs:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
@@ -768,27 +785,37 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
||||
skill_dir = None
|
||||
skill_md = None
|
||||
|
||||
# Try direct path first (e.g., "mlops/axolotl")
|
||||
direct_path = SKILLS_DIR / name
|
||||
if direct_path.is_dir() and (direct_path / "SKILL.md").exists():
|
||||
skill_dir = direct_path
|
||||
skill_md = direct_path / "SKILL.md"
|
||||
elif direct_path.with_suffix(".md").exists():
|
||||
skill_md = direct_path.with_suffix(".md")
|
||||
# Search all dirs: local first, then external (first match wins)
|
||||
for search_dir in all_dirs:
|
||||
# Try direct path first (e.g., "mlops/axolotl")
|
||||
direct_path = search_dir / name
|
||||
if direct_path.is_dir() and (direct_path / "SKILL.md").exists():
|
||||
skill_dir = direct_path
|
||||
skill_md = direct_path / "SKILL.md"
|
||||
break
|
||||
elif direct_path.with_suffix(".md").exists():
|
||||
skill_md = direct_path.with_suffix(".md")
|
||||
break
|
||||
|
||||
# Search by directory name
|
||||
# Search by directory name across all dirs
|
||||
if not skill_md:
|
||||
for found_skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
if found_skill_md.parent.name == name:
|
||||
skill_dir = found_skill_md.parent
|
||||
skill_md = found_skill_md
|
||||
for search_dir in all_dirs:
|
||||
for found_skill_md in search_dir.rglob("SKILL.md"):
|
||||
if found_skill_md.parent.name == name:
|
||||
skill_dir = found_skill_md.parent
|
||||
skill_md = found_skill_md
|
||||
break
|
||||
if skill_md:
|
||||
break
|
||||
|
||||
# Legacy: flat .md files
|
||||
if not skill_md:
|
||||
for found_md in SKILLS_DIR.rglob(f"{name}.md"):
|
||||
if found_md.name != "SKILL.md":
|
||||
skill_md = found_md
|
||||
for search_dir in all_dirs:
|
||||
for found_md in search_dir.rglob(f"{name}.md"):
|
||||
if found_md.name != "SKILL.md":
|
||||
skill_md = found_md
|
||||
break
|
||||
if skill_md:
|
||||
break
|
||||
|
||||
if not skill_md or not skill_md.exists():
|
||||
@@ -815,12 +842,21 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Security: warn if skill is loaded from outside the trusted skills directory
|
||||
# Security: warn if skill is loaded from outside trusted directories
|
||||
# (local skills dir + configured external_dirs are all trusted)
|
||||
_outside_skills_dir = True
|
||||
_trusted_dirs = [SKILLS_DIR.resolve()]
|
||||
try:
|
||||
skill_md.resolve().relative_to(SKILLS_DIR.resolve())
|
||||
_outside_skills_dir = False
|
||||
except ValueError:
|
||||
_outside_skills_dir = True
|
||||
_trusted_dirs.extend(d.resolve() for d in all_dirs[1:])
|
||||
except Exception:
|
||||
pass
|
||||
for _td in _trusted_dirs:
|
||||
try:
|
||||
skill_md.resolve().relative_to(_td)
|
||||
_outside_skills_dir = False
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Security: detect common prompt injection patterns
|
||||
_INJECTION_PATTERNS = [
|
||||
@@ -1058,7 +1094,11 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
||||
if script_files:
|
||||
linked_files["scripts"] = script_files
|
||||
|
||||
rel_path = str(skill_md.relative_to(SKILLS_DIR))
|
||||
try:
|
||||
rel_path = str(skill_md.relative_to(SKILLS_DIR))
|
||||
except ValueError:
|
||||
# External skill — use path relative to the skill's own parent dir
|
||||
rel_path = str(skill_md.relative_to(skill_md.parent.parent)) if skill_md.parent.parent else skill_md.name
|
||||
skill_name = frontmatter.get(
|
||||
"name", skill_md.stem if not skill_dir else skill_dir.name
|
||||
)
|
||||
@@ -1106,6 +1146,27 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Register credential files for mounting into remote sandboxes
|
||||
# (Modal, Docker). Files that exist on the host are registered;
|
||||
# missing ones are added to the setup_needed indicators.
|
||||
required_cred_files_raw = frontmatter.get("required_credential_files", [])
|
||||
if not isinstance(required_cred_files_raw, list):
|
||||
required_cred_files_raw = []
|
||||
missing_cred_files: list = []
|
||||
if required_cred_files_raw:
|
||||
try:
|
||||
from tools.credential_files import register_credential_files
|
||||
|
||||
missing_cred_files = register_credential_files(required_cred_files_raw)
|
||||
if missing_cred_files:
|
||||
setup_needed = True
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not register credential files for skill %s",
|
||||
skill_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"name": skill_name,
|
||||
@@ -1121,6 +1182,7 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
||||
"required_environment_variables": required_env_vars,
|
||||
"required_commands": [],
|
||||
"missing_required_environment_variables": remaining_missing_required_envs,
|
||||
"missing_credential_files": missing_cred_files,
|
||||
"missing_required_commands": [],
|
||||
"setup_needed": setup_needed,
|
||||
"setup_skipped": capture_result["setup_skipped"],
|
||||
@@ -1139,6 +1201,8 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
||||
if setup_needed:
|
||||
missing_items = [
|
||||
f"env ${env_name}" for env_name in remaining_missing_required_envs
|
||||
] + [
|
||||
f"file {path}" for path in missing_cred_files
|
||||
]
|
||||
setup_note = _build_setup_note(
|
||||
SkillReadinessStatus.SETUP_NEEDED,
|
||||
|
||||
@@ -48,6 +48,7 @@ logger = logging.getLogger(__name__)
|
||||
# long-running subprocesses immediately instead of blocking until timeout.
|
||||
# ---------------------------------------------------------------------------
|
||||
from tools.interrupt import is_interrupted, _interrupt_event # noqa: F401 — re-exported
|
||||
# display_hermes_home imported lazily at call site (stale-module safety during hermes update)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -157,7 +158,8 @@ def _handle_sudo_failure(output: str, env_type: str) -> str:
|
||||
|
||||
for failure in sudo_failures:
|
||||
if failure in output:
|
||||
return output + "\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to ~/.hermes/.env on the agent machine."
|
||||
from hermes_constants import display_hermes_home as _dhh
|
||||
return output + f"\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to {_dhh()}/.env on the agent machine."
|
||||
|
||||
return output
|
||||
|
||||
@@ -1283,7 +1285,8 @@ if __name__ == "__main__":
|
||||
print(f" TERMINAL_MODAL_IMAGE: {os.getenv('TERMINAL_MODAL_IMAGE', default_img)}")
|
||||
print(f" TERMINAL_DAYTONA_IMAGE: {os.getenv('TERMINAL_DAYTONA_IMAGE', default_img)}")
|
||||
print(f" TERMINAL_CWD: {os.getenv('TERMINAL_CWD', os.getcwd())}")
|
||||
print(f" TERMINAL_SANDBOX_DIR: {os.getenv('TERMINAL_SANDBOX_DIR', '~/.hermes/sandboxes')}")
|
||||
from hermes_constants import display_hermes_home as _dhh
|
||||
print(f" TERMINAL_SANDBOX_DIR: {os.getenv('TERMINAL_SANDBOX_DIR', f'{_dhh()}/sandboxes')}")
|
||||
print(f" TERMINAL_TIMEOUT: {os.getenv('TERMINAL_TIMEOUT', '60')}")
|
||||
print(f" TERMINAL_LIFETIME_SECONDS: {os.getenv('TERMINAL_LIFETIME_SECONDS', '300')}")
|
||||
|
||||
|
||||
+1
-1
@@ -832,7 +832,7 @@ TTS_SCHEMA = {
|
||||
},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/cache/audio/<timestamp>.mp3"
|
||||
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/<timestamp>.mp3"
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
|
||||
@@ -90,6 +90,7 @@ pytest tests/ -v
|
||||
- **Comments**: Only when explaining non-obvious intent, trade-offs, or API quirks
|
||||
- **Error handling**: Catch specific exceptions. Use `logger.warning()`/`logger.error()` with `exc_info=True` for unexpected errors
|
||||
- **Cross-platform**: Never assume Unix (see below)
|
||||
- **Profile-safe paths**: Never hardcode `~/.hermes` — use `get_hermes_home()` from `hermes_constants` for code paths and `display_hermes_home()` for user-facing messages. See [AGENTS.md](https://github.com/NousResearch/hermes-agent/blob/main/AGENTS.md#profiles-multi-instance-support) for full rules.
|
||||
|
||||
## Cross-Platform Compatibility
|
||||
|
||||
|
||||
@@ -168,11 +168,38 @@ required_environment_variables:
|
||||
The user can skip setup and keep loading the skill. Hermes never exposes the raw secret value to the model. Gateway and messaging sessions show local setup guidance instead of collecting secrets in-band.
|
||||
|
||||
:::tip Sandbox Passthrough
|
||||
When your skill is loaded, any declared `required_environment_variables` that are set are **automatically passed through** to `execute_code` and `terminal` sandboxes. Your skill's scripts can access `$TENOR_API_KEY` (or `os.environ["TENOR_API_KEY"]` in Python) without the user needing to configure anything extra. See [Environment Variable Passthrough](/docs/user-guide/security#environment-variable-passthrough) for details.
|
||||
When your skill is loaded, any declared `required_environment_variables` that are set are **automatically passed through** to `execute_code` and `terminal` sandboxes — including remote backends like Docker and Modal. Your skill's scripts can access `$TENOR_API_KEY` (or `os.environ["TENOR_API_KEY"]` in Python) without the user needing to configure anything extra. See [Environment Variable Passthrough](/docs/user-guide/security#environment-variable-passthrough) for details.
|
||||
:::
|
||||
|
||||
Legacy `prerequisites.env_vars` remains supported as a backward-compatible alias.
|
||||
|
||||
### Credential File Requirements (OAuth tokens, etc.)
|
||||
|
||||
Skills that use OAuth or file-based credentials can declare files that need to be mounted into remote sandboxes. This is for credentials stored as **files** (not env vars) — typically OAuth token files produced by a setup script.
|
||||
|
||||
```yaml
|
||||
required_credential_files:
|
||||
- path: google_token.json
|
||||
description: Google OAuth2 token (created by setup script)
|
||||
- path: google_client_secret.json
|
||||
description: Google OAuth2 client credentials
|
||||
```
|
||||
|
||||
Each entry supports:
|
||||
- `path` (required) — file path relative to `~/.hermes/`
|
||||
- `description` (optional) — explains what the file is and how it's created
|
||||
|
||||
When loaded, Hermes checks if these files exist. Missing files trigger `setup_needed`. Existing files are automatically:
|
||||
- **Mounted into Docker** containers as read-only bind mounts
|
||||
- **Synced into Modal** sandboxes (at creation + before each command, so mid-session OAuth works)
|
||||
- Available on **local** backend without any special handling
|
||||
|
||||
:::tip When to use which
|
||||
Use `required_environment_variables` for simple API keys and tokens (strings stored in `~/.hermes/.env`). Use `required_credential_files` for OAuth token files, client secrets, service account JSON, certificates, or any credential that's a file on disk.
|
||||
:::
|
||||
|
||||
See the `skills/productivity/google-workspace/SKILL.md` for a complete example using both.
|
||||
|
||||
## Skill Guidelines
|
||||
|
||||
### No External Dependencies
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user