Compare commits

..

10 Commits

Author SHA1 Message Date
teknium1
832a27c17e fix(custom-endpoint): verify /models and suggest working /v1 base URL 2026-03-15 20:08:03 -07:00
Teknium
a56937735e fix(telegram): escape chunk indicators in MarkdownV2 (#1478) 2026-03-15 19:27:15 -07:00
Teknium
7148534401 fix(gateway): make /status report live state and tokens (#1476) 2026-03-15 19:18:58 -07:00
Teknium
4e91b0240b fix(honcho): correct seed_ai_identity to use session.add_messages() (#1475)
The seed_ai_identity method was calling assistant_peer.add_message() which
doesn't exist on the Honcho SDK's Peer class. Fixed to use the correct
pattern: session.add_messages([peer.message(content)]), matching the
existing message sync code at line 294.

Discovered and fixed by Yuqi (Hermes Agent), Angello's AI companion.

Co-authored-by: Angello Picasso <angello.picasso@devsu.com>
2026-03-15 19:07:57 -07:00
Teknium
5e92a4ce5a fix: auto-reload MCP tools when mcp_servers config changes without restart (#1474)
Fixes #1036

After adding an MCP server to config.yaml, users had to restart Hermes
before the new tools became visible — even though /reload-mcp existed.

Add _check_config_mcp_changes() called from process_loop every 5s:
- stat() config.yaml for mtime changes (fast path, no YAML parse)
- On mtime change, parse and compare mcp_servers section
- If mcp_servers changed, auto-trigger _reload_mcp() and notify user
- Skip check while agent is running to avoid interrupting tool calls
- Throttled to CONFIG_WATCH_INTERVAL=5s to avoid busy-polling

/reload-mcp still works for manual force-reload.

Tests: 6 new tests in TestMCPConfigWatch, all passed

Co-authored-by: teyrebaz33 <hakanerten02@hotmail.com>
2026-03-15 19:03:34 -07:00
Teknium
471c663fdf fix(cli): silence tirith prefetch install warnings at startup (#1452) 2026-03-15 18:07:03 -07:00
Teknium
64d333204b Merge pull request #1242 from NousResearch/fix/file-tool-log-noise
fix: reduce file tool log noise
2026-03-15 11:11:18 -07:00
Teknium
c44af43840 Merge pull request #1401 from NousResearch/hermes/hermes-eca4a640
test: protect atomic temp cleanup on interrupts
2026-03-15 11:10:41 -07:00
teknium1
b117bbc125 test: cover atomic temp cleanup on interrupts
- add regression coverage for BaseException cleanup in atomic_json_write
- add dedicated atomic_yaml_write tests, including interrupt cleanup
- document why BaseException is intentional in both helpers
2026-03-14 22:31:51 -07:00
teknium1
b59da08730 fix: reduce file tool log noise
- treat git diff --cached --quiet rc=1 as an expected checkpoint state
  instead of logging it as an error
- downgrade expected write PermissionError/EROFS/EACCES failures out of
  error logging while keeping unexpected exceptions at error level
- add regression tests for both logging behaviors
2026-03-13 22:14:00 -07:00
29 changed files with 1010 additions and 949 deletions

61
cli.py
View File

@@ -3484,6 +3484,56 @@ class HermesCLI:
except Exception as e:
print(f" Error generating insights: {e}")
def _check_config_mcp_changes(self) -> None:
"""Detect mcp_servers changes in config.yaml and auto-reload MCP connections.
Called from process_loop every CONFIG_WATCH_INTERVAL seconds.
Compares config.yaml mtime + mcp_servers section against the last
known state. When a change is detected, triggers _reload_mcp() and
informs the user so they know the tool list has been refreshed.
"""
import time
import yaml as _yaml
CONFIG_WATCH_INTERVAL = 5.0 # seconds between config.yaml stat() calls
now = time.monotonic()
if now - self._last_config_check < CONFIG_WATCH_INTERVAL:
return
self._last_config_check = now
from hermes_cli.config import get_config_path as _get_config_path
cfg_path = _get_config_path()
if not cfg_path.exists():
return
try:
mtime = cfg_path.stat().st_mtime
except OSError:
return
if mtime == self._config_mtime:
return # File unchanged — fast path
# File changed — check whether mcp_servers section changed
self._config_mtime = mtime
try:
with open(cfg_path, encoding="utf-8") as f:
new_cfg = _yaml.safe_load(f) or {}
except Exception:
return
new_mcp = new_cfg.get("mcp_servers") or {}
if new_mcp == self._config_mcp_servers:
return # mcp_servers unchanged (some other section was edited)
self._config_mcp_servers = new_mcp
# Notify user and reload
print()
print("🔄 MCP server config changed — reloading connections...")
with self._busy_command(self._slow_command_status("/reload-mcp")):
self._reload_mcp()
def _reload_mcp(self):
"""Reload MCP servers: disconnect all, re-read config.yaml, reconnect.
@@ -4749,6 +4799,12 @@ class HermesCLI:
self._interrupt_queue = queue.Queue() # For messages typed while agent is running
self._should_exit = False
self._last_ctrl_c_time = 0 # Track double Ctrl+C for force exit
# Config file watcher — detect mcp_servers changes and auto-reload
from hermes_cli.config import get_config_path as _get_config_path
_cfg_path = _get_config_path()
self._config_mtime: float = _cfg_path.stat().st_mtime if _cfg_path.exists() else 0.0
self._config_mcp_servers: dict = self.config.get("mcp_servers") or {}
self._last_config_check: float = 0.0 # monotonic time of last check
# Clarify tool state: interactive question/answer with the user.
# When the agent calls the clarify tool, _clarify_state is set and
@@ -4797,7 +4853,7 @@ class HermesCLI:
# Ensure tirith security scanner is available (downloads if needed)
try:
from tools.tirith_security import ensure_installed
ensure_installed()
ensure_installed(log_failures=False)
except Exception:
pass # Non-fatal — fail-open at scan time if unavailable
@@ -5682,6 +5738,9 @@ class HermesCLI:
try:
user_input = self._pending_input.get(timeout=0.1)
except queue.Empty:
# Periodic config watcher — auto-reload MCP on mcp_servers change
if not self._agent_running:
self._check_config_mcp_changes()
continue
if not user_input:

View File

@@ -322,6 +322,14 @@ class TelegramAdapter(BasePlatformAdapter):
# Format and split message if needed
formatted = self.format_message(content)
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
if len(chunks) > 1:
# truncate_message appends a raw " (1/2)" suffix. Escape the
# MarkdownV2-special parentheses so Telegram doesn't reject the
# chunk and fall back to plain text.
chunks = [
re.sub(r" \((\d+)/(\d+)\)$", r" \\(\1/\2\\)", chunk)
for chunk in chunks
]
message_ids = []
thread_id = metadata.get("thread_id") if metadata else None

View File

@@ -305,7 +305,7 @@ class GatewayRunner:
# Ensure tirith security scanner is available (downloads if needed)
try:
from tools.tirith_security import ensure_installed
ensure_installed()
ensure_installed(log_failures=False)
except Exception:
pass # Non-fatal — fail-open at scan time if unavailable
@@ -1114,6 +1114,9 @@ class GatewayRunner:
# let the adapter-level batching/queueing logic absorb them.
_quick_key = build_session_key(source)
if _quick_key in self._running_agents:
if event.get_command() == "status":
return await self._handle_status_command(event)
if event.message_type == MessageType.PHOTO:
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
adapter = self.adapters.get(source.platform)
@@ -1822,6 +1825,8 @@ class GatewayRunner:
# Update session with actual prompt token count and model from the agent
self.session_store.update_session(
session_entry.session_key,
input_tokens=agent_result.get("input_tokens", 0),
output_tokens=agent_result.get("output_tokens", 0),
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
model=agent_result.get("model"),
)
@@ -4171,11 +4176,15 @@ class GatewayRunner:
# Return final response, or a message if something went wrong
final_response = result.get("final_response")
# Extract last actual prompt token count from the agent's compressor
# Extract actual token counts from the agent instance used for this run
_last_prompt_toks = 0
_input_toks = 0
_output_toks = 0
_agent = agent_holder[0]
if _agent and hasattr(_agent, "context_compressor"):
_last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0)
_input_toks = getattr(_agent, "session_prompt_tokens", 0)
_output_toks = getattr(_agent, "session_completion_tokens", 0)
_resolved_model = getattr(_agent, "model", None) if _agent else None
if not final_response:
@@ -4187,6 +4196,8 @@ class GatewayRunner:
"tools": tools_holder[0] or [],
"history_offset": len(agent_history),
"last_prompt_tokens": _last_prompt_toks,
"input_tokens": _input_toks,
"output_tokens": _output_toks,
"model": _resolved_model,
}
@@ -4250,6 +4261,8 @@ class GatewayRunner:
"tools": tools_holder[0] or [],
"history_offset": len(agent_history),
"last_prompt_tokens": _last_prompt_toks,
"input_tokens": _input_toks,
"output_tokens": _output_toks,
"model": _resolved_model,
"session_id": effective_session_id,
}

View File

@@ -1112,8 +1112,32 @@ def _model_flow_custom(config):
effective_key = api_key or current_key
from hermes_cli.models import probe_api_models
probe = probe_api_models(effective_key, effective_url)
if probe.get("used_fallback") and probe.get("resolved_base_url"):
print(
f"Warning: endpoint verification worked at {probe['resolved_base_url']}/models, "
f"not the exact URL you entered. Saving the working base URL instead."
)
effective_url = probe["resolved_base_url"]
if base_url:
base_url = effective_url
elif probe.get("models") is not None:
print(
f"Verified endpoint via {probe.get('probed_url')} "
f"({len(probe.get('models') or [])} model(s) visible)"
)
else:
print(
f"Warning: could not verify this endpoint via {probe.get('probed_url')}. "
f"Hermes will still save it."
)
if probe.get("suggested_base_url"):
print(f" If this server expects /v1, try base URL: {probe['suggested_base_url']}")
if base_url:
save_env_value("OPENAI_BASE_URL", base_url)
save_env_value("OPENAI_BASE_URL", effective_url)
if api_key:
save_env_value("OPENAI_API_KEY", api_key)

View File

@@ -308,6 +308,62 @@ def _fetch_anthropic_models(timeout: float = 5.0) -> Optional[list[str]]:
return None
def probe_api_models(
api_key: Optional[str],
base_url: Optional[str],
timeout: float = 5.0,
) -> dict[str, Any]:
"""Probe an OpenAI-compatible ``/models`` endpoint with light URL heuristics."""
normalized = (base_url or "").strip().rstrip("/")
if not normalized:
return {
"models": None,
"probed_url": None,
"resolved_base_url": "",
"suggested_base_url": None,
"used_fallback": False,
}
if normalized.endswith("/v1"):
alternate_base = normalized[:-3].rstrip("/")
else:
alternate_base = normalized + "/v1"
candidates: list[tuple[str, bool]] = [(normalized, False)]
if alternate_base and alternate_base != normalized:
candidates.append((alternate_base, True))
tried: list[str] = []
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
for candidate_base, is_fallback in candidates:
url = candidate_base.rstrip("/") + "/models"
tried.append(url)
req = urllib.request.Request(url, headers=headers)
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read().decode())
return {
"models": [m.get("id", "") for m in data.get("data", [])],
"probed_url": url,
"resolved_base_url": candidate_base.rstrip("/"),
"suggested_base_url": alternate_base if alternate_base != candidate_base else normalized,
"used_fallback": is_fallback,
}
except Exception:
continue
return {
"models": None,
"probed_url": tried[-1] if tried else normalized.rstrip("/") + "/models",
"resolved_base_url": normalized,
"suggested_base_url": alternate_base if alternate_base != normalized else None,
"used_fallback": False,
}
def fetch_api_models(
api_key: Optional[str],
base_url: Optional[str],
@@ -318,22 +374,7 @@ def fetch_api_models(
Returns a list of model ID strings, or ``None`` if the endpoint could not
be reached (network error, timeout, auth failure, etc.).
"""
if not base_url:
return None
url = base_url.rstrip("/") + "/models"
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
req = urllib.request.Request(url, headers=headers)
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read().decode())
# Standard OpenAI format: {"data": [{"id": "model-name", ...}, ...]}
return [m.get("id", "") for m in data.get("data", [])]
except Exception:
return None
return probe_api_models(api_key, base_url, timeout=timeout).get("models")
def validate_requested_model(
@@ -376,13 +417,53 @@ def validate_requested_model(
"message": "Model names cannot contain spaces.",
}
# Custom endpoints can serve any model — skip validation
if normalized == "custom":
probe = probe_api_models(api_key, base_url)
api_models = probe.get("models")
if api_models is not None:
if requested in set(api_models):
return {
"accepted": True,
"persist": True,
"recognized": True,
"message": None,
}
suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5)
suggestion_text = ""
if suggestions:
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
message = (
f"Note: `{requested}` was not found in this custom endpoint's model listing "
f"({probe.get('probed_url')}). It may still work if the server supports hidden or aliased models."
f"{suggestion_text}"
)
if probe.get("used_fallback"):
message += (
f"\n Endpoint verification succeeded after trying `{probe.get('resolved_base_url')}`. "
f"Consider saving that as your base URL."
)
return {
"accepted": True,
"persist": True,
"recognized": False,
"message": message,
}
message = (
f"Note: could not reach this custom endpoint's model listing at `{probe.get('probed_url')}`. "
f"Hermes will still save `{requested}`, but the endpoint should expose `/models` for verification."
)
if probe.get("suggested_base_url"):
message += f"\n If this server expects `/v1`, try base URL: `{probe.get('suggested_base_url')}`"
return {
"accepted": True,
"persist": True,
"recognized": False,
"message": None,
"message": message,
}
# Probe the live API to check if the model actually exists

View File

@@ -933,11 +933,35 @@ def setup_model_provider(config: dict):
base_url = prompt(
" API base URL (e.g., https://api.example.com/v1)", current_url
)
).strip()
api_key = prompt(" API key", password=True)
model_name = prompt(" Model name (e.g., gpt-4, claude-3-opus)", current_model)
if base_url:
from hermes_cli.models import probe_api_models
probe = probe_api_models(api_key, base_url)
if probe.get("used_fallback") and probe.get("resolved_base_url"):
print_warning(
f"Endpoint verification worked at {probe['resolved_base_url']}/models, "
f"not the exact URL you entered. Saving the working base URL instead."
)
base_url = probe["resolved_base_url"]
elif probe.get("models") is not None:
print_success(
f"Verified endpoint via {probe.get('probed_url')} "
f"({len(probe.get('models') or [])} model(s) visible)"
)
else:
print_warning(
f"Could not verify this endpoint via {probe.get('probed_url')}. "
f"Hermes will still save it."
)
if probe.get("suggested_base_url"):
print_info(
f" If this server expects /v1, try base URL: {probe['suggested_base_url']}"
)
save_env_value("OPENAI_BASE_URL", base_url)
if api_key:
save_env_value("OPENAI_API_KEY", api_key)

View File

@@ -927,6 +927,11 @@ class HonchoSessionManager:
return False
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
honcho_session = self._sessions_cache.get(session.honcho_session_id)
if not honcho_session:
logger.warning("No Honcho session cached for '%s', skipping AI seed", session_key)
return False
try:
wrapped = (
f"<ai_identity_seed>\n"
@@ -935,7 +940,7 @@ class HonchoSessionManager:
f"{content.strip()}\n"
f"</ai_identity_seed>"
)
assistant_peer.add_message("assistant", wrapped)
honcho_session.add_messages([assistant_peer.message(wrapped)])
logger.info("Seeded AI identity from '%s' into %s", source, session_key)
return True
except Exception as e:

View File

@@ -0,0 +1,133 @@
"""Tests for gateway /status behavior and token persistence."""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
def _make_source() -> SessionSource:
return SessionSource(
platform=Platform.TELEGRAM,
user_id="u1",
chat_id="c1",
user_name="tester",
chat_type="dm",
)
def _make_event(text: str) -> MessageEvent:
return MessageEvent(
text=text,
source=_make_source(),
message_id="m1",
)
def _make_runner(session_entry: SessionEntry):
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
)
adapter = MagicMock()
adapter.send = AsyncMock()
runner.adapters = {Platform.TELEGRAM: adapter}
runner._voice_mode = {}
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
runner.session_store = MagicMock()
runner.session_store.get_or_create_session.return_value = session_entry
runner.session_store.load_transcript.return_value = []
runner.session_store.has_any_sessions.return_value = True
runner.session_store.append_to_transcript = MagicMock()
runner.session_store.rewrite_transcript = MagicMock()
runner.session_store.update_session = MagicMock()
runner._running_agents = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._session_db = None
runner._reasoning_config = None
runner._provider_routing = {}
runner._fallback_model = None
runner._show_reasoning = False
runner._is_user_authorized = lambda _source: True
runner._set_session_env = lambda _context: None
runner._should_send_voice_reply = lambda *_args, **_kwargs: False
runner._send_voice_reply = AsyncMock()
runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
runner._emit_gateway_run_progress = AsyncMock()
return runner
@pytest.mark.asyncio
async def test_status_command_reports_running_agent_without_interrupt(monkeypatch):
session_entry = SessionEntry(
session_key=build_session_key(_make_source()),
session_id="sess-1",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
total_tokens=321,
)
runner = _make_runner(session_entry)
running_agent = MagicMock()
runner._running_agents[build_session_key(_make_source())] = running_agent
result = await runner._handle_message(_make_event("/status"))
assert "**Tokens:** 321" in result
assert "**Agent Running:** Yes ⚡" in result
running_agent.interrupt.assert_not_called()
assert runner._pending_messages == {}
@pytest.mark.asyncio
async def test_handle_message_persists_agent_token_counts(monkeypatch):
import gateway.run as gateway_run
session_entry = SessionEntry(
session_key=build_session_key(_make_source()),
session_id="sess-1",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
)
runner = _make_runner(session_entry)
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
runner._run_agent = AsyncMock(
return_value={
"final_response": "ok",
"messages": [],
"tools": [],
"history_offset": 0,
"last_prompt_tokens": 80,
"input_tokens": 120,
"output_tokens": 45,
"model": "openai/test-model",
}
)
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
monkeypatch.setattr(
"agent.model_metadata.get_model_context_length",
lambda *_args, **_kwargs: 100000,
)
result = await runner._handle_message(_make_event("hello"))
assert result == "ok"
runner.session_store.update_session.assert_called_once_with(
session_entry.session_key,
input_tokens=120,
output_tokens=45,
last_prompt_tokens=80,
model="openai/test-model",
)

View File

@@ -7,7 +7,7 @@ or corrupt user-visible content.
import re
import sys
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -392,3 +392,27 @@ class TestStripMdv2:
def test_empty_string(self):
assert _strip_mdv2("") == ""
@pytest.mark.asyncio
async def test_send_escapes_chunk_indicator_for_markdownv2(adapter):
adapter.MAX_MESSAGE_LENGTH = 80
adapter._bot = MagicMock()
sent_texts = []
async def _fake_send_message(**kwargs):
sent_texts.append(kwargs["text"])
msg = MagicMock()
msg.message_id = len(sent_texts)
return msg
adapter._bot.send_message = AsyncMock(side_effect=_fake_send_message)
content = ("**bold** chunk content " * 12).strip()
result = await adapter.send("123", content)
assert result.success is True
assert len(sent_texts) > 1
assert re.search(r" \\\([0-9]+/[0-9]+\\\)$", sent_texts[0])
assert re.search(r" \\\([0-9]+/[0-9]+\\\)$", sent_texts[-1])

View File

@@ -7,6 +7,7 @@ from hermes_cli.models import (
fetch_api_models,
normalize_provider,
parse_model_input,
probe_api_models,
provider_label,
provider_model_ids,
validate_requested_model,
@@ -26,7 +27,15 @@ FAKE_API_MODELS = [
def _validate(model, provider="openrouter", api_models=FAKE_API_MODELS, **kw):
"""Shortcut: call validate_requested_model with mocked API."""
with patch("hermes_cli.models.fetch_api_models", return_value=api_models):
probe_payload = {
"models": api_models,
"probed_url": "http://localhost:11434/v1/models",
"resolved_base_url": kw.get("base_url", "") or "http://localhost:11434/v1",
"suggested_base_url": None,
"used_fallback": False,
}
with patch("hermes_cli.models.fetch_api_models", return_value=api_models), \
patch("hermes_cli.models.probe_api_models", return_value=probe_payload):
return validate_requested_model(model, provider, **kw)
@@ -147,6 +156,33 @@ class TestFetchApiModels:
with patch("hermes_cli.models.urllib.request.urlopen", side_effect=Exception("timeout")):
assert fetch_api_models("key", "https://example.com/v1") is None
def test_probe_api_models_tries_v1_fallback(self):
class _Resp:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def read(self):
return b'{"data": [{"id": "local-model"}]}'
calls = []
def _fake_urlopen(req, timeout=5.0):
calls.append(req.full_url)
if req.full_url.endswith("/v1/models"):
return _Resp()
raise Exception("404")
with patch("hermes_cli.models.urllib.request.urlopen", side_effect=_fake_urlopen):
probe = probe_api_models("key", "http://localhost:8000")
assert calls == ["http://localhost:8000/models", "http://localhost:8000/v1/models"]
assert probe["models"] == ["local-model"]
assert probe["resolved_base_url"] == "http://localhost:8000/v1"
assert probe["used_fallback"] is True
# -- validate — format checks -----------------------------------------------
@@ -191,6 +227,7 @@ class TestValidateApiFound:
)
assert result["accepted"] is True
assert result["persist"] is True
assert result["recognized"] is True
# -- validate — API not found ------------------------------------------------
@@ -232,3 +269,26 @@ class TestValidateApiFallback:
result = _validate("some-model", provider="totally-unknown", api_models=None)
assert result["accepted"] is True
assert result["persist"] is True
def test_custom_endpoint_warns_with_probed_url_and_v1_hint(self):
with patch(
"hermes_cli.models.probe_api_models",
return_value={
"models": None,
"probed_url": "http://localhost:8000/v1/models",
"resolved_base_url": "http://localhost:8000",
"suggested_base_url": "http://localhost:8000/v1",
"used_fallback": False,
},
):
result = validate_requested_model(
"qwen3",
"custom",
api_key="local-key",
base_url="http://localhost:8000",
)
assert result["accepted"] is True
assert result["persist"] is True
assert "http://localhost:8000/v1/models" in result["message"]
assert "http://localhost:8000/v1" in result["message"]

View File

@@ -75,6 +75,58 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
assert calls["count"] == 1
def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
_clear_provider_env(monkeypatch)
config = load_config()
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
return 3 # Custom endpoint
if question == "Configure vision:":
return len(choices) - 1 # Skip
raise AssertionError(f"Unexpected prompt_choice call: {question}")
def fake_prompt(message, current=None, **kwargs):
if "API base URL" in message:
return "http://localhost:8000"
if "API key" in message:
return "local-key"
if "Model name" in message:
return "llm"
return ""
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
monkeypatch.setattr(
"hermes_cli.models.probe_api_models",
lambda api_key, base_url: {
"models": ["llm"],
"probed_url": "http://localhost:8000/v1/models",
"resolved_base_url": "http://localhost:8000/v1",
"suggested_base_url": "http://localhost:8000/v1",
"used_fallback": True,
},
)
setup_model_provider(config)
save_config(config)
env = _read_env(tmp_path)
reloaded = load_config()
assert env.get("OPENAI_BASE_URL") == "http://localhost:8000/v1"
assert env.get("OPENAI_API_KEY") == "local-key"
assert reloaded["model"]["provider"] == "custom"
assert reloaded["model"]["base_url"] == "http://localhost:8000/v1"
assert reloaded["model"]["default"] == "llm"
def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tmp_path, monkeypatch):
"""Keep-current should respect config-backed providers, not fall back to OpenRouter."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))

View File

@@ -68,6 +68,22 @@ class TestAtomicJsonWrite:
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
assert len(tmp_files) == 0
def test_cleans_up_temp_file_on_baseexception(self, tmp_path):
class SimulatedAbort(BaseException):
pass
target = tmp_path / "data.json"
original = {"preserved": True}
target.write_text(json.dumps(original), encoding="utf-8")
with patch("utils.json.dump", side_effect=SimulatedAbort):
with pytest.raises(SimulatedAbort):
atomic_json_write(target, {"new": True})
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
assert len(tmp_files) == 0
assert json.loads(target.read_text(encoding="utf-8")) == original
def test_accepts_string_path(self, tmp_path):
target = str(tmp_path / "string_path.json")
atomic_json_write(target, {"string": True})

View File

@@ -0,0 +1,44 @@
"""Tests for utils.atomic_yaml_write — crash-safe YAML file writes."""
from pathlib import Path
from unittest.mock import patch
import pytest
import yaml
from utils import atomic_yaml_write
class TestAtomicYamlWrite:
def test_writes_valid_yaml(self, tmp_path):
target = tmp_path / "data.yaml"
data = {"key": "value", "nested": {"a": 1}}
atomic_yaml_write(target, data)
assert yaml.safe_load(target.read_text(encoding="utf-8")) == data
def test_cleans_up_temp_file_on_baseexception(self, tmp_path):
class SimulatedAbort(BaseException):
pass
target = tmp_path / "data.yaml"
original = {"preserved": True}
target.write_text(yaml.safe_dump(original), encoding="utf-8")
with patch("utils.yaml.dump", side_effect=SimulatedAbort):
with pytest.raises(SimulatedAbort):
atomic_yaml_write(target, {"new": True})
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
assert len(tmp_files) == 0
assert yaml.safe_load(target.read_text(encoding="utf-8")) == original
def test_appends_extra_content(self, tmp_path):
target = tmp_path / "data.yaml"
atomic_yaml_write(target, {"key": "value"}, extra_content="\n# comment\n")
text = target.read_text(encoding="utf-8")
assert "key: value" in text
assert "# comment" in text

View File

@@ -0,0 +1,103 @@
"""Tests for automatic MCP reload when config.yaml mcp_servers section changes."""
import time
from pathlib import Path
from unittest.mock import MagicMock, patch
def _make_cli(tmp_path, mcp_servers=None):
"""Create a minimal HermesCLI instance with mocked config."""
import cli as cli_mod
obj = object.__new__(cli_mod.HermesCLI)
obj.config = {"mcp_servers": mcp_servers or {}}
obj._agent_running = False
obj._last_config_check = 0.0
obj._config_mcp_servers = mcp_servers or {}
cfg_file = tmp_path / "config.yaml"
cfg_file.write_text("mcp_servers: {}\n")
obj._config_mtime = cfg_file.stat().st_mtime
obj._reload_mcp = MagicMock()
obj._busy_command = MagicMock()
obj._busy_command.return_value.__enter__ = MagicMock(return_value=None)
obj._busy_command.return_value.__exit__ = MagicMock(return_value=False)
obj._slow_command_status = MagicMock(return_value="reloading...")
return obj, cfg_file
class TestMCPConfigWatch:
def test_no_change_does_not_reload(self, tmp_path):
"""If mtime and mcp_servers unchanged, _reload_mcp is NOT called."""
obj, cfg_file = _make_cli(tmp_path)
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
obj._check_config_mcp_changes()
obj._reload_mcp.assert_not_called()
def test_mtime_change_with_same_mcp_servers_does_not_reload(self, tmp_path):
"""If file mtime changes but mcp_servers is identical, no reload."""
import yaml
obj, cfg_file = _make_cli(tmp_path, mcp_servers={"fs": {"command": "npx"}})
# Write same mcp_servers but touch the file
cfg_file.write_text(yaml.dump({"mcp_servers": {"fs": {"command": "npx"}}}))
# Force mtime to appear changed
obj._config_mtime = 0.0
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
obj._check_config_mcp_changes()
obj._reload_mcp.assert_not_called()
def test_new_mcp_server_triggers_reload(self, tmp_path):
"""Adding a new MCP server to config triggers auto-reload."""
import yaml
obj, cfg_file = _make_cli(tmp_path, mcp_servers={})
# Simulate user adding a new MCP server to config.yaml
cfg_file.write_text(yaml.dump({"mcp_servers": {"github": {"url": "https://mcp.github.com"}}}))
obj._config_mtime = 0.0 # force stale mtime
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
obj._check_config_mcp_changes()
obj._reload_mcp.assert_called_once()
def test_removed_mcp_server_triggers_reload(self, tmp_path):
"""Removing an MCP server from config triggers auto-reload."""
import yaml
obj, cfg_file = _make_cli(tmp_path, mcp_servers={"github": {"url": "https://mcp.github.com"}})
# Simulate user removing the server
cfg_file.write_text(yaml.dump({"mcp_servers": {}}))
obj._config_mtime = 0.0
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
obj._check_config_mcp_changes()
obj._reload_mcp.assert_called_once()
def test_interval_throttle_skips_check(self, tmp_path):
"""If called within CONFIG_WATCH_INTERVAL, stat() is skipped."""
obj, cfg_file = _make_cli(tmp_path)
obj._last_config_check = time.monotonic() # just checked
with patch("hermes_cli.config.get_config_path", return_value=cfg_file), \
patch.object(Path, "stat") as mock_stat:
obj._check_config_mcp_changes()
mock_stat.assert_not_called()
obj._reload_mcp.assert_not_called()
def test_missing_config_file_does_not_crash(self, tmp_path):
"""If config.yaml doesn't exist, _check_config_mcp_changes is a no-op."""
obj, cfg_file = _make_cli(tmp_path)
missing = tmp_path / "nonexistent.yaml"
with patch("hermes_cli.config.get_config_path", return_value=missing):
obj._check_config_mcp_changes() # should not raise
obj._reload_mcp.assert_not_called()

View File

@@ -336,4 +336,42 @@ def test_cmd_model_falls_back_to_auto_on_invalid_provider(monkeypatch, capsys):
assert "Warning:" in output
assert "falling back to auto provider detection" in output.lower()
assert "No change." in output
assert "No change." in output
def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys):
monkeypatch.setattr(
"hermes_cli.config.get_env_value",
lambda key: "" if key in {"OPENAI_BASE_URL", "OPENAI_API_KEY"} else "",
)
saved_env = {}
monkeypatch.setattr("hermes_cli.config.save_env_value", lambda key, value: saved_env.__setitem__(key, value))
monkeypatch.setattr("hermes_cli.auth._save_model_choice", lambda model: saved_env.__setitem__("MODEL", model))
monkeypatch.setattr("hermes_cli.auth.deactivate_provider", lambda: None)
monkeypatch.setattr("hermes_cli.main._save_custom_provider", lambda *args, **kwargs: None)
monkeypatch.setattr(
"hermes_cli.models.probe_api_models",
lambda api_key, base_url: {
"models": ["llm"],
"probed_url": "http://localhost:8000/v1/models",
"resolved_base_url": "http://localhost:8000/v1",
"suggested_base_url": "http://localhost:8000/v1",
"used_fallback": True,
},
)
monkeypatch.setattr(
"hermes_cli.config.load_config",
lambda: {"model": {"default": "", "provider": "custom", "base_url": ""}},
)
monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None)
answers = iter(["http://localhost:8000", "local-key", "llm"])
monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers))
hermes_main._model_flow_custom({})
output = capsys.readouterr().out
assert "Saving the working base URL instead" in output
assert saved_env["OPENAI_BASE_URL"] == "http://localhost:8000/v1"
assert saved_env["OPENAI_API_KEY"] == "local-key"
assert saved_env["MODEL"] == "llm"

View File

@@ -1,8 +1,10 @@
"""Tests for tools/checkpoint_manager.py — CheckpointManager."""
import logging
import os
import json
import shutil
import subprocess
import pytest
from pathlib import Path
from unittest.mock import patch
@@ -143,6 +145,12 @@ class TestTakeCheckpoint:
result = mgr.ensure_checkpoint(str(work_dir), "initial")
assert result is True
def test_successful_checkpoint_does_not_log_expected_diff_exit(self, mgr, work_dir, caplog):
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
result = mgr.ensure_checkpoint(str(work_dir), "initial")
assert result is True
assert not any("diff --cached --quiet" in r.getMessage() for r in caplog.records)
def test_dedup_same_turn(self, mgr, work_dir):
r1 = mgr.ensure_checkpoint(str(work_dir), "first")
r2 = mgr.ensure_checkpoint(str(work_dir), "second")
@@ -375,6 +383,26 @@ class TestErrorResilience:
result = mgr.ensure_checkpoint(str(work_dir), "test")
assert result is False
def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog):
completed = subprocess.CompletedProcess(
args=["git", "diff", "--cached", "--quiet"],
returncode=1,
stdout="",
stderr="",
)
with patch("tools.checkpoint_manager.subprocess.run", return_value=completed):
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
ok, stdout, stderr = _run_git(
["diff", "--cached", "--quiet"],
tmp_path / "shadow",
str(tmp_path / "work"),
allowed_returncodes={1},
)
assert ok is False
assert stdout == ""
assert stderr == ""
assert not caplog.records
def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch):
"""Checkpoint failures should never raise — they're silently logged."""
def broken_run_git(*args, **kwargs):

View File

@@ -5,6 +5,7 @@ handling without requiring a running terminal environment.
"""
import json
import logging
from unittest.mock import MagicMock, patch
from tools.file_tools import (
@@ -87,13 +88,26 @@ class TestWriteFileHandler:
mock_ops.write_file.assert_called_once_with("/tmp/out.txt", "hello world!\n")
@patch("tools.file_tools._get_file_ops")
def test_exception_returns_error_json(self, mock_get):
def test_permission_error_returns_error_json_without_error_log(self, mock_get, caplog):
mock_get.side_effect = PermissionError("read-only filesystem")
from tools.file_tools import write_file_tool
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
with caplog.at_level(logging.DEBUG, logger="tools.file_tools"):
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
assert "error" in result
assert "read-only" in result["error"]
assert any("write_file expected denial" in r.getMessage() for r in caplog.records)
assert not any(r.levelno >= logging.ERROR for r in caplog.records)
@patch("tools.file_tools._get_file_ops")
def test_unexpected_exception_still_logs_error(self, mock_get, caplog):
mock_get.side_effect = RuntimeError("boom")
from tools.file_tools import write_file_tool
with caplog.at_level(logging.ERROR, logger="tools.file_tools"):
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
assert result["error"] == "boom"
assert any("write_file error" in r.getMessage() for r in caplog.records)
class TestPatchHandler:

View File

@@ -26,7 +26,8 @@ def _make_fake_popen(captured: dict):
proc = MagicMock()
proc.poll.return_value = 0
proc.returncode = 0
proc.stdout = MagicMock(__iter__=lambda s: iter([]), __next__=lambda s: (_ for _ in ()).throw(StopIteration))
proc.stdout = iter([])
proc.stdout.close = lambda: None
proc.stdin = MagicMock()
return proc
return fake_popen

View File

@@ -1,152 +0,0 @@
"""Tests for the local persistent shell backend."""
import glob as glob_mod
import pytest
from tools.environments.local import LocalEnvironment
from tools.environments.persistent_shell import PersistentShellMixin
class TestLocalConfig:
def test_local_persistent_default_false(self, monkeypatch):
monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
from tools.terminal_tool import _get_env_config
assert _get_env_config()["local_persistent"] is False
def test_local_persistent_true(self, monkeypatch):
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true")
from tools.terminal_tool import _get_env_config
assert _get_env_config()["local_persistent"] is True
def test_local_persistent_yes(self, monkeypatch):
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes")
from tools.terminal_tool import _get_env_config
assert _get_env_config()["local_persistent"] is True
class TestMergeOutput:
def test_stdout_only(self):
assert PersistentShellMixin._merge_output("out", "") == "out"
def test_stderr_only(self):
assert PersistentShellMixin._merge_output("", "err") == "err"
def test_both(self):
assert PersistentShellMixin._merge_output("out", "err") == "out\nerr"
def test_empty(self):
assert PersistentShellMixin._merge_output("", "") == ""
def test_strips_trailing_newlines(self):
assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"
class TestLocalOneShotRegression:
def test_echo(self):
env = LocalEnvironment(persistent=False)
r = env.execute("echo hello")
assert r["returncode"] == 0
assert "hello" in r["output"]
env.cleanup()
def test_exit_code(self):
env = LocalEnvironment(persistent=False)
r = env.execute("exit 42")
assert r["returncode"] == 42
env.cleanup()
def test_state_does_not_persist(self):
env = LocalEnvironment(persistent=False)
env.execute("export HERMES_ONESHOT_LOCAL=yes")
r = env.execute("echo $HERMES_ONESHOT_LOCAL")
assert r["output"].strip() == ""
env.cleanup()
class TestLocalPersistent:
@pytest.fixture
def env(self):
e = LocalEnvironment(persistent=True)
yield e
e.cleanup()
def test_echo(self, env):
r = env.execute("echo hello-persistent")
assert r["returncode"] == 0
assert "hello-persistent" in r["output"]
def test_env_var_persists(self, env):
env.execute("export HERMES_LOCAL_PERSIST_TEST=works")
r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST")
assert r["output"].strip() == "works"
def test_cwd_persists(self, env):
env.execute("cd /tmp")
r = env.execute("pwd")
assert r["output"].strip() == "/tmp"
def test_exit_code(self, env):
r = env.execute("(exit 42)")
assert r["returncode"] == 42
def test_stderr(self, env):
r = env.execute("echo oops >&2")
assert r["returncode"] == 0
assert "oops" in r["output"]
def test_multiline_output(self, env):
r = env.execute("echo a; echo b; echo c")
lines = r["output"].strip().splitlines()
assert lines == ["a", "b", "c"]
def test_timeout_then_recovery(self, env):
r = env.execute("sleep 999", timeout=2)
assert r["returncode"] in (124, 130)
r = env.execute("echo alive")
assert r["returncode"] == 0
assert "alive" in r["output"]
def test_large_output(self, env):
r = env.execute("seq 1 1000")
assert r["returncode"] == 0
lines = r["output"].strip().splitlines()
assert len(lines) == 1000
assert lines[0] == "1"
assert lines[-1] == "1000"
def test_shell_variable_persists(self, env):
env.execute("MY_LOCAL_VAR=hello123")
r = env.execute("echo $MY_LOCAL_VAR")
assert r["output"].strip() == "hello123"
def test_cleanup_removes_temp_files(self, env):
env.execute("echo warmup")
prefix = env._temp_prefix
assert len(glob_mod.glob(f"{prefix}-*")) > 0
env.cleanup()
remaining = glob_mod.glob(f"{prefix}-*")
assert remaining == []
def test_state_does_not_leak_between_instances(self):
env1 = LocalEnvironment(persistent=True)
env2 = LocalEnvironment(persistent=True)
try:
env1.execute("export LEAK_TEST=from_env1")
r = env2.execute("echo $LEAK_TEST")
assert r["output"].strip() == ""
finally:
env1.cleanup()
env2.cleanup()
def test_special_characters_in_command(self, env):
r = env.execute("echo 'hello world'")
assert r["output"].strip() == "hello world"
def test_pipe_command(self, env):
r = env.execute("echo hello | tr 'h' 'H'")
assert r["output"].strip() == "Hello"
def test_multiple_commands_semicolon(self, env):
r = env.execute("X=42; echo $X")
assert r["output"].strip() == "42"

View File

@@ -1,167 +0,0 @@
"""Tests for the SSH remote execution environment backend."""
import json
import os
import subprocess
from unittest.mock import MagicMock
import pytest
from tools.environments.ssh import SSHEnvironment
_SSH_HOST = os.getenv("TERMINAL_SSH_HOST", "")
_SSH_USER = os.getenv("TERMINAL_SSH_USER", "")
_SSH_PORT = int(os.getenv("TERMINAL_SSH_PORT", "22"))
_SSH_KEY = os.getenv("TERMINAL_SSH_KEY", "")
_has_ssh = bool(_SSH_HOST and _SSH_USER)
requires_ssh = pytest.mark.skipif(
not _has_ssh,
reason="TERMINAL_SSH_HOST / TERMINAL_SSH_USER not set",
)
def _run(command, task_id="ssh_test", **kwargs):
from tools.terminal_tool import terminal_tool
return json.loads(terminal_tool(command, task_id=task_id, **kwargs))
def _cleanup(task_id="ssh_test"):
from tools.terminal_tool import cleanup_vm
cleanup_vm(task_id)
class TestBuildSSHCommand:
@pytest.fixture(autouse=True)
def _mock_connection(self, monkeypatch):
monkeypatch.setattr("tools.environments.ssh.subprocess.run",
lambda *a, **k: subprocess.CompletedProcess([], 0))
monkeypatch.setattr("tools.environments.ssh.subprocess.Popen",
lambda *a, **k: MagicMock(stdout=iter([]),
stderr=iter([]),
stdin=MagicMock()))
monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None)
def test_base_flags(self):
env = SSHEnvironment(host="h", user="u")
cmd = " ".join(env._build_ssh_command())
for flag in ("ControlMaster=auto", "ControlPersist=300",
"BatchMode=yes", "StrictHostKeyChecking=accept-new"):
assert flag in cmd
def test_custom_port(self):
env = SSHEnvironment(host="h", user="u", port=2222)
cmd = env._build_ssh_command()
assert "-p" in cmd and "2222" in cmd
def test_key_path(self):
env = SSHEnvironment(host="h", user="u", key_path="/k")
cmd = env._build_ssh_command()
assert "-i" in cmd and "/k" in cmd
def test_user_host_suffix(self):
env = SSHEnvironment(host="h", user="u")
assert env._build_ssh_command()[-1] == "u@h"
class TestTerminalToolConfig:
def test_ssh_persistent_default_false(self, monkeypatch):
monkeypatch.delenv("TERMINAL_SSH_PERSISTENT", raising=False)
from tools.terminal_tool import _get_env_config
assert _get_env_config()["ssh_persistent"] is False
def test_ssh_persistent_true(self, monkeypatch):
monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true")
from tools.terminal_tool import _get_env_config
assert _get_env_config()["ssh_persistent"] is True
def _setup_ssh_env(monkeypatch, persistent: bool):
monkeypatch.setenv("TERMINAL_ENV", "ssh")
monkeypatch.setenv("TERMINAL_SSH_HOST", _SSH_HOST)
monkeypatch.setenv("TERMINAL_SSH_USER", _SSH_USER)
monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true" if persistent else "false")
if _SSH_PORT != 22:
monkeypatch.setenv("TERMINAL_SSH_PORT", str(_SSH_PORT))
if _SSH_KEY:
monkeypatch.setenv("TERMINAL_SSH_KEY", _SSH_KEY)
@requires_ssh
class TestOneShotSSH:
@pytest.fixture(autouse=True)
def _setup(self, monkeypatch):
_setup_ssh_env(monkeypatch, persistent=False)
yield
_cleanup()
def test_echo(self):
r = _run("echo hello")
assert r["exit_code"] == 0
assert "hello" in r["output"]
def test_exit_code(self):
r = _run("exit 42")
assert r["exit_code"] == 42
def test_state_does_not_persist(self):
_run("export HERMES_ONESHOT_TEST=yes")
r = _run("echo $HERMES_ONESHOT_TEST")
assert r["output"].strip() == ""
@requires_ssh
class TestPersistentSSH:
@pytest.fixture(autouse=True)
def _setup(self, monkeypatch):
_setup_ssh_env(monkeypatch, persistent=True)
yield
_cleanup()
def test_echo(self):
r = _run("echo hello-persistent")
assert r["exit_code"] == 0
assert "hello-persistent" in r["output"]
def test_env_var_persists(self):
_run("export HERMES_PERSIST_TEST=works")
r = _run("echo $HERMES_PERSIST_TEST")
assert r["output"].strip() == "works"
def test_cwd_persists(self):
_run("cd /tmp")
r = _run("pwd")
assert r["output"].strip() == "/tmp"
def test_exit_code(self):
r = _run("(exit 42)")
assert r["exit_code"] == 42
def test_stderr(self):
r = _run("echo oops >&2")
assert r["exit_code"] == 0
assert "oops" in r["output"]
def test_multiline_output(self):
r = _run("echo a; echo b; echo c")
lines = r["output"].strip().splitlines()
assert lines == ["a", "b", "c"]
def test_timeout_then_recovery(self):
r = _run("sleep 999", timeout=2)
assert r["exit_code"] == 124
r = _run("echo alive")
assert r["exit_code"] == 0
assert "alive" in r["output"]
def test_large_output(self):
r = _run("seq 1 1000")
assert r["exit_code"] == 0
lines = r["output"].strip().splitlines()
assert len(lines) == 1000
assert lines[0] == "1"
assert lines[-1] == "1000"

View File

@@ -315,6 +315,23 @@ class TestEnsureInstalled:
mock_thread.start.assert_called_once()
_tirith_mod._resolved_path = None
@patch("tools.tirith_security._load_security_config")
def test_startup_prefetch_can_suppress_install_failure_logs(self, mock_cfg):
mock_cfg.return_value = {"tirith_enabled": True, "tirith_path": "tirith",
"tirith_timeout": 5, "tirith_fail_open": True}
_tirith_mod._resolved_path = None
with patch("tools.tirith_security.shutil.which", return_value=None), \
patch("tools.tirith_security._hermes_bin_dir", return_value="/nonexistent"), \
patch("tools.tirith_security._is_install_failed_on_disk", return_value=False), \
patch("tools.tirith_security.threading.Thread") as MockThread:
mock_thread = MagicMock()
MockThread.return_value = mock_thread
result = ensure_installed(log_failures=False)
assert result is None
assert MockThread.call_args.kwargs["kwargs"] == {"log_failures": False}
mock_thread.start.assert_called_once()
_tirith_mod._resolved_path = None
# ---------------------------------------------------------------------------
# Failed download caches the miss (Finding #1)
@@ -516,6 +533,22 @@ class TestCosignVerification:
assert path is None
assert reason == "cosign_missing"
@patch("tools.tirith_security.logger.debug")
@patch("tools.tirith_security.logger.warning")
@patch("tools.tirith_security.shutil.which", return_value=None)
@patch("tools.tirith_security._download_file")
@patch("tools.tirith_security._detect_target", return_value="aarch64-apple-darwin")
def test_install_quiet_mode_downgrades_cosign_missing_log(self, mock_target, mock_dl,
mock_which, mock_warning,
mock_debug):
"""Startup prefetch should not surface cosign-missing as a warning."""
from tools.tirith_security import _install_tirith
path, reason = _install_tirith(log_failures=False)
assert path is None
assert reason == "cosign_missing"
mock_warning.assert_not_called()
mock_debug.assert_called()
@patch("tools.tirith_security._verify_cosign", return_value=None)
@patch("tools.tirith_security.shutil.which", return_value="/usr/local/bin/cosign")
@patch("tools.tirith_security._download_file")

View File

@@ -92,10 +92,17 @@ def _run_git(
shadow_repo: Path,
working_dir: str,
timeout: int = _GIT_TIMEOUT,
allowed_returncodes: Optional[Set[int]] = None,
) -> tuple:
"""Run a git command against the shadow repo. Returns (ok, stdout, stderr)."""
"""Run a git command against the shadow repo. Returns (ok, stdout, stderr).
``allowed_returncodes`` suppresses error logging for known/expected non-zero
exits while preserving the normal ``ok = (returncode == 0)`` contract.
Example: ``git diff --cached --quiet`` returns 1 when changes exist.
"""
env = _git_env(shadow_repo, working_dir)
cmd = ["git"] + list(args)
allowed_returncodes = allowed_returncodes or set()
try:
result = subprocess.run(
cmd,
@@ -108,7 +115,7 @@ def _run_git(
ok = result.returncode == 0
stdout = result.stdout.strip()
stderr = result.stderr.strip()
if not ok:
if not ok and result.returncode not in allowed_returncodes:
logger.error(
"Git command failed: %s (rc=%d) stderr=%s",
" ".join(cmd), result.returncode, stderr,
@@ -381,7 +388,10 @@ class CheckpointManager:
# Check if there's anything to commit
ok_diff, diff_out, _ = _run_git(
["diff", "--cached", "--quiet"], shadow, working_dir,
["diff", "--cached", "--quiet"],
shadow,
working_dir,
allowed_returncodes={1},
)
if ok_diff:
# No changes to commit

View File

@@ -1,6 +1,5 @@
"""Local execution environment with interrupt support and non-blocking I/O."""
import glob
import os
import platform
import shutil
@@ -12,8 +11,6 @@ import time
_IS_WINDOWS = platform.system() == "Windows"
from tools.environments.base import BaseEnvironment
from tools.environments.persistent_shell import PersistentShellMixin
from tools.interrupt import is_interrupted
# Unique marker to isolate real command output from shell init/exit noise.
# printf (no trailing newline) keeps the boundaries clean for splitting.
@@ -247,25 +244,6 @@ def _clean_shell_noise(output: str) -> str:
return result
_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
def _make_run_env(env: dict) -> dict:
"""Build a run environment with a sane PATH and provider-var stripping."""
merged = dict(os.environ | env)
run_env = {}
for k, v in merged.items():
if k.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
real_key = k[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
run_env[real_key] = v
elif k not in _HERMES_PROVIDER_ENV_BLOCKLIST:
run_env[k] = v
existing_path = run_env.get("PATH", "")
if "/usr/bin" not in existing_path.split(":"):
run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
return run_env
def _extract_fenced_output(raw: str) -> str:
"""Extract real command output from between fence markers.
@@ -290,7 +268,7 @@ def _extract_fenced_output(raw: str) -> str:
return raw[start:last]
class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
class LocalEnvironment(BaseEnvironment):
"""Run commands directly on the host machine.
Features:
@@ -299,66 +277,24 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
- stdin_data support for piping content (bypasses ARG_MAX limits)
- sudo -S transform via SUDO_PASSWORD env var
- Uses interactive login shell so full user env is available
- Optional persistent shell mode (cwd/env vars survive across calls)
"""
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None,
persistent: bool = False):
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
self.persistent = persistent
if self.persistent:
self._init_persistent_shell()
@property
def _temp_prefix(self) -> str:
return f"/tmp/hermes-local-{self._session_id}"
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
from tools.terminal_tool import _interrupt_event
def _spawn_shell_process(self) -> subprocess.Popen:
user_shell = _find_bash()
run_env = _make_run_env(self.env)
return subprocess.Popen(
[user_shell, "-l"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
text=True,
env=run_env,
preexec_fn=None if _IS_WINDOWS else os.setsid,
)
def _read_temp_files(self, *paths: str) -> list[str]:
results = []
for path in paths:
if os.path.exists(path):
with open(path) as f:
results.append(f.read())
else:
results.append("")
return results
def _kill_shell_children(self):
if self._shell_pid is None:
return
try:
subprocess.run(
["pkill", "-P", str(self._shell_pid)],
capture_output=True, timeout=5,
)
except (subprocess.TimeoutExpired, FileNotFoundError):
pass
def _cleanup_temp_files(self):
for f in glob.glob(f"{self._temp_prefix}-*"):
if os.path.exists(f):
os.remove(f)
def _execute_oneshot(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
work_dir = cwd or self.cwd or os.getcwd()
effective_timeout = timeout or self.timeout
exec_command, sudo_stdin = self._prepare_command(command)
# Merge the sudo password (if any) with caller-supplied stdin_data.
# sudo -S reads exactly one line (the password) then passes the rest
# of stdin to the child, so prepending is safe even when stdin_data
# is also present.
if sudo_stdin is not None and stdin_data is not None:
effective_stdin = sudo_stdin + stdin_data
elif sudo_stdin is not None:
@@ -366,87 +302,110 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
else:
effective_stdin = stdin_data
user_shell = _find_bash()
fenced_cmd = (
f"printf '{_OUTPUT_FENCE}';"
f" {exec_command};"
f" __hermes_rc=$?;"
f" printf '{_OUTPUT_FENCE}';"
f" exit $__hermes_rc"
)
run_env = _make_run_env(self.env)
try:
# The fence wrapper uses bash syntax (semicolons, $?, printf).
# Always use bash for the wrapper — NOT $SHELL which could be
# fish, zsh, or another shell with incompatible syntax.
# The -lic flags source rc files so tools like nvm/pyenv work.
user_shell = _find_bash()
# Wrap with output fences so we can later extract the real
# command output and discard shell init/exit noise.
fenced_cmd = (
f"printf '{_OUTPUT_FENCE}';"
f" {exec_command};"
f" __hermes_rc=$?;"
f" printf '{_OUTPUT_FENCE}';"
f" exit $__hermes_rc"
)
# Ensure PATH always includes standard dirs — systemd services
# and some terminal multiplexers inherit a minimal PATH.
_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
# Strip Hermes-managed provider/tool/gateway vars so external CLIs
# are not silently misrouted or handed Hermes secrets. Callers that
# truly need a blocked var can opt in by prefixing the key with
# _HERMES_FORCE_ in self.env (e.g. _HERMES_FORCE_OPENAI_API_KEY).
run_env = _sanitize_subprocess_env(os.environ, self.env)
existing_path = run_env.get("PATH", "")
if "/usr/bin" not in existing_path.split(":"):
run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
proc = subprocess.Popen(
[user_shell, "-lic", fenced_cmd],
text=True,
cwd=work_dir,
env=run_env,
encoding="utf-8",
errors="replace",
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
preexec_fn=None if _IS_WINDOWS else os.setsid,
)
proc = subprocess.Popen(
[user_shell, "-lic", fenced_cmd],
text=True,
cwd=work_dir,
env=run_env,
encoding="utf-8",
errors="replace",
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
preexec_fn=None if _IS_WINDOWS else os.setsid,
)
if effective_stdin is not None:
def _write_stdin():
if effective_stdin is not None:
def _write_stdin():
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except (BrokenPipeError, OSError):
pass
threading.Thread(target=_write_stdin, daemon=True).start()
_output_chunks: list[str] = []
def _drain_stdout():
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except (BrokenPipeError, OSError):
for line in proc.stdout:
_output_chunks.append(line)
except ValueError:
pass
threading.Thread(target=_write_stdin, daemon=True).start()
finally:
try:
proc.stdout.close()
except Exception:
pass
_output_chunks: list[str] = []
reader = threading.Thread(target=_drain_stdout, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
def _drain_stdout():
try:
for line in proc.stdout:
_output_chunks.append(line)
except ValueError:
pass
finally:
try:
proc.stdout.close()
except Exception:
pass
while proc.poll() is None:
if _interrupt_event.is_set():
try:
if _IS_WINDOWS:
proc.terminate()
else:
pgid = os.getpgid(proc.pid)
os.killpg(pgid, signal.SIGTERM)
try:
proc.wait(timeout=1.0)
except subprocess.TimeoutExpired:
os.killpg(pgid, signal.SIGKILL)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
"returncode": 130,
}
if time.monotonic() > deadline:
try:
if _IS_WINDOWS:
proc.terminate()
else:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
reader = threading.Thread(target=_drain_stdout, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
reader.join(timeout=5)
output = _extract_fenced_output("".join(_output_chunks))
return {"output": output, "returncode": proc.returncode}
while proc.poll() is None:
if is_interrupted():
try:
if _IS_WINDOWS:
proc.terminate()
else:
pgid = os.getpgid(proc.pid)
os.killpg(pgid, signal.SIGTERM)
try:
proc.wait(timeout=1.0)
except subprocess.TimeoutExpired:
os.killpg(pgid, signal.SIGKILL)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
"returncode": 130,
}
if time.monotonic() > deadline:
try:
if _IS_WINDOWS:
proc.terminate()
else:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
except Exception as e:
return {"output": f"Execution error: {str(e)}", "returncode": 1}
reader.join(timeout=5)
output = _extract_fenced_output("".join(_output_chunks))
return {"output": output, "returncode": proc.returncode}
def cleanup(self):
pass

View File

@@ -1,272 +0,0 @@
"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells."""
import logging
import shlex
import subprocess
import threading
import time
import uuid
from abc import abstractmethod
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
class PersistentShellMixin:
"""Mixin that adds persistent shell capability to any BaseEnvironment.
Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``,
``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``.
"""
persistent: bool
@abstractmethod
def _spawn_shell_process(self) -> subprocess.Popen: ...
@abstractmethod
def _read_temp_files(self, *paths: str) -> list[str]: ...
@abstractmethod
def _kill_shell_children(self): ...
@abstractmethod
def _execute_oneshot(self, command: str, cwd: str, *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict: ...
@abstractmethod
def _cleanup_temp_files(self): ...
_session_id: str = ""
_poll_interval: float = 0.01
@property
def _temp_prefix(self) -> str:
return f"/tmp/hermes-persistent-{self._session_id}"
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def _init_persistent_shell(self):
self._shell_lock = threading.Lock()
self._shell_proc: subprocess.Popen | None = None
self._shell_alive: bool = False
self._shell_pid: int | None = None
self._session_id = uuid.uuid4().hex[:12]
p = self._temp_prefix
self._pshell_stdout = f"{p}-stdout"
self._pshell_stderr = f"{p}-stderr"
self._pshell_status = f"{p}-status"
self._pshell_cwd = f"{p}-cwd"
self._pshell_pid_file = f"{p}-pid"
self._shell_proc = self._spawn_shell_process()
self._shell_alive = True
self._drain_thread = threading.Thread(
target=self._drain_shell_output, daemon=True,
)
self._drain_thread.start()
init_script = (
f"export TERM=${{TERM:-dumb}}\n"
f"touch {self._pshell_stdout} {self._pshell_stderr} "
f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n"
f"echo $$ > {self._pshell_pid_file}\n"
f"pwd > {self._pshell_cwd}\n"
)
self._send_to_shell(init_script)
deadline = time.monotonic() + 3.0
while time.monotonic() < deadline:
pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip()
if pid_str.isdigit():
self._shell_pid = int(pid_str)
break
time.sleep(0.05)
else:
logger.warning("Could not read persistent shell PID")
self._shell_pid = None
if self._shell_pid:
logger.info(
"Persistent shell started (session=%s, pid=%d)",
self._session_id, self._shell_pid,
)
reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip()
if reported_cwd:
self.cwd = reported_cwd
def _cleanup_persistent_shell(self):
if self._shell_proc is None:
return
if self._session_id:
self._cleanup_temp_files()
try:
self._shell_proc.stdin.close()
except Exception:
pass
try:
self._shell_proc.terminate()
self._shell_proc.wait(timeout=3)
except subprocess.TimeoutExpired:
self._shell_proc.kill()
self._shell_alive = False
self._shell_proc = None
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
self._drain_thread.join(timeout=1.0)
# ------------------------------------------------------------------
# execute() / cleanup() — shared dispatcher, subclasses inherit
# ------------------------------------------------------------------
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if self.persistent:
return self._execute_persistent(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
def cleanup(self):
if self.persistent:
self._cleanup_persistent_shell()
# ------------------------------------------------------------------
# Shell I/O
# ------------------------------------------------------------------
def _drain_shell_output(self):
try:
for _ in self._shell_proc.stdout:
pass
except Exception:
pass
self._shell_alive = False
def _send_to_shell(self, text: str):
if not self._shell_alive or self._shell_proc is None:
return
try:
self._shell_proc.stdin.write(text)
self._shell_proc.stdin.flush()
except (BrokenPipeError, OSError):
self._shell_alive = False
def _read_persistent_output(self) -> tuple[str, int, str]:
stdout, stderr, status_raw, cwd = self._read_temp_files(
self._pshell_stdout, self._pshell_stderr,
self._pshell_status, self._pshell_cwd,
)
output = self._merge_output(stdout, stderr)
status = status_raw.strip()
if ":" in status:
status = status.split(":", 1)[1]
try:
exit_code = int(status.strip())
except ValueError:
exit_code = 1
return output, exit_code, cwd.strip()
# ------------------------------------------------------------------
# Execution
# ------------------------------------------------------------------
def _execute_persistent(self, command: str, cwd: str, *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if not self._shell_alive:
logger.info("Persistent shell died, restarting...")
self._init_persistent_shell()
exec_command, sudo_stdin = self._prepare_command(command)
effective_timeout = timeout or self.timeout
if stdin_data or sudo_stdin:
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
with self._shell_lock:
return self._execute_persistent_locked(
exec_command, cwd, effective_timeout,
)
def _execute_persistent_locked(self, command: str, cwd: str,
timeout: int) -> dict:
work_dir = cwd or self.cwd
cmd_id = uuid.uuid4().hex[:8]
truncate = (
f": > {self._pshell_stdout}\n"
f": > {self._pshell_stderr}\n"
f": > {self._pshell_status}\n"
)
self._send_to_shell(truncate)
escaped = command.replace("'", "'\\''")
ipc_script = (
f"cd {shlex.quote(work_dir)}\n"
f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n"
f"__EC=$?\n"
f"pwd > {self._pshell_cwd}\n"
f"echo {cmd_id}:$__EC > {self._pshell_status}\n"
)
self._send_to_shell(ipc_script)
deadline = time.monotonic() + timeout
poll_interval = self._poll_interval
while True:
if is_interrupted():
self._kill_shell_children()
output, _, _ = self._read_persistent_output()
return {
"output": output + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
self._kill_shell_children()
output, _, _ = self._read_persistent_output()
if output:
return {
"output": output + f"\n[Command timed out after {timeout}s]",
"returncode": 124,
}
return self._timeout_result(timeout)
if not self._shell_alive:
return {
"output": "Persistent shell died during execution",
"returncode": 1,
}
status_content = self._read_temp_files(self._pshell_status)[0].strip()
if status_content.startswith(cmd_id + ":"):
break
time.sleep(poll_interval)
output, exit_code, new_cwd = self._read_persistent_output()
if new_cwd:
self.cwd = new_cwd
return {"output": output, "returncode": exit_code}
@staticmethod
def _merge_output(stdout: str, stderr: str) -> str:
parts = []
if stdout.strip():
parts.append(stdout.rstrip("\n"))
if stderr.strip():
parts.append(stderr.rstrip("\n"))
return "\n".join(parts)

View File

@@ -8,13 +8,12 @@ import time
from pathlib import Path
from tools.environments.base import BaseEnvironment
from tools.environments.persistent_shell import PersistentShellMixin
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
class SSHEnvironment(BaseEnvironment):
"""Run commands on a remote machine over SSH.
Uses SSH ControlMaster for connection persistence so subsequent
@@ -23,33 +22,22 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
Foreground commands are interruptible: the local ssh process is killed
and a remote kill is attempted over the ControlMaster socket.
When ``persistent=True``, a single long-lived bash shell is kept alive
over SSH and state (cwd, env vars, shell variables) persists across
``execute()`` calls. Output capture uses file-based IPC on the remote
host (stdout/stderr/exit-code written to temp files, polled via fast
ControlMaster one-shot reads).
"""
def __init__(self, host: str, user: str, cwd: str = "~",
timeout: int = 60, port: int = 22, key_path: str = "",
persistent: bool = False):
timeout: int = 60, port: int = 22, key_path: str = ""):
super().__init__(cwd=cwd, timeout=timeout)
self.host = host
self.user = user
self.port = port
self.key_path = key_path
self.persistent = persistent
self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
self.control_dir.mkdir(parents=True, exist_ok=True)
self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock"
self._establish_connection()
if self.persistent:
self._init_persistent_shell()
def _build_ssh_command(self, extra_args: list | None = None) -> list:
def _build_ssh_command(self, extra_args: list = None) -> list:
cmd = ["ssh"]
cmd.extend(["-o", f"ControlPath={self.control_socket}"])
cmd.extend(["-o", "ControlMaster=auto"])
@@ -77,76 +65,15 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
except subprocess.TimeoutExpired:
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
_poll_interval: float = 0.15
@property
def _temp_prefix(self) -> str:
return f"/tmp/hermes-ssh-{self._session_id}"
def _spawn_shell_process(self) -> subprocess.Popen:
cmd = self._build_ssh_command()
cmd.append("bash -l")
return subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
text=True,
)
def _read_temp_files(self, *paths: str) -> list[str]:
if len(paths) == 1:
cmd = self._build_ssh_command()
cmd.append(f"cat {paths[0]} 2>/dev/null")
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=10,
)
return [result.stdout]
except (subprocess.TimeoutExpired, OSError):
return [""]
delim = f"__HERMES_SEP_{self._session_id}__"
script = "; ".join(
f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths
)
cmd = self._build_ssh_command()
cmd.append(script)
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=10,
)
parts = result.stdout.split(delim + "\n")
return [parts[i] if i < len(parts) else "" for i in range(len(paths))]
except (subprocess.TimeoutExpired, OSError):
return [""] * len(paths)
def _kill_shell_children(self):
if self._shell_pid is None:
return
cmd = self._build_ssh_command()
cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true")
try:
subprocess.run(cmd, capture_output=True, timeout=5)
except (subprocess.TimeoutExpired, OSError):
pass
def _cleanup_temp_files(self):
cmd = self._build_ssh_command()
cmd.append(f"rm -f {self._temp_prefix}-*")
try:
subprocess.run(cmd, capture_output=True, timeout=5)
except (subprocess.TimeoutExpired, OSError):
pass
def _execute_oneshot(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
work_dir = cwd or self.cwd
exec_command, sudo_stdin = self._prepare_command(command)
wrapped = f'cd {work_dir} && {exec_command}'
effective_timeout = timeout or self.timeout
# Merge sudo password (if any) with caller-supplied stdin_data.
if sudo_stdin is not None and stdin_data is not None:
effective_stdin = sudo_stdin + stdin_data
elif sudo_stdin is not None:
@@ -155,60 +82,66 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
effective_stdin = stdin_data
cmd = self._build_ssh_command()
cmd.append(wrapped)
cmd.extend(["bash", "-c", wrapped])
kwargs = self._build_run_kwargs(timeout, effective_stdin)
kwargs.pop("timeout", None)
_output_chunks = []
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
text=True,
)
try:
kwargs = self._build_run_kwargs(timeout, effective_stdin)
# Remove timeout from kwargs -- we handle it in the poll loop
kwargs.pop("timeout", None)
if effective_stdin:
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except (BrokenPipeError, OSError):
pass
_output_chunks = []
def _drain():
try:
for line in proc.stdout:
_output_chunks.append(line)
except Exception:
pass
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
text=True,
)
reader = threading.Thread(target=_drain, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
proc.terminate()
if effective_stdin:
try:
proc.wait(timeout=1)
except subprocess.TimeoutExpired:
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
proc.stdin.write(effective_stdin)
proc.stdin.close()
except Exception:
pass
reader.join(timeout=5)
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
def _drain():
try:
for line in proc.stdout:
_output_chunks.append(line)
except Exception:
pass
reader = threading.Thread(target=_drain, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
proc.terminate()
try:
proc.wait(timeout=1)
except subprocess.TimeoutExpired:
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
reader.join(timeout=5)
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
except Exception as e:
return {"output": f"SSH execution error: {str(e)}", "returncode": 1}
def cleanup(self):
super().cleanup()
if self.control_socket.exists():
try:
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
"""File Tools Module - LLM agent file manipulation tools."""
import errno
import json
import logging
import os
@@ -11,6 +12,18 @@ from agent.redact import redact_sensitive_text
logger = logging.getLogger(__name__)
_EXPECTED_WRITE_ERRNOS = {errno.EACCES, errno.EPERM, errno.EROFS}
def _is_expected_write_exception(exc: Exception) -> bool:
"""Return True for expected write denials that should not hit error logs."""
if isinstance(exc, PermissionError):
return True
if isinstance(exc, OSError) and exc.errno in _EXPECTED_WRITE_ERRNOS:
return True
return False
_file_ops_lock = threading.Lock()
_file_ops_cache: dict = {}
@@ -101,31 +114,12 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
"container_persistent": config.get("container_persistent", True),
"docker_volumes": config.get("docker_volumes", []),
}
ssh_config = None
if env_type == "ssh":
ssh_config = {
"host": config.get("ssh_host", ""),
"user": config.get("ssh_user", ""),
"port": config.get("ssh_port", 22),
"key": config.get("ssh_key", ""),
"persistent": config.get("ssh_persistent", False),
}
local_config = None
if env_type == "local":
local_config = {
"persistent": config.get("local_persistent", False),
}
terminal_env = _create_environment(
env_type=env_type,
image=image,
cwd=cwd,
timeout=config["timeout"],
ssh_config=ssh_config,
container_config=container_config,
local_config=local_config,
task_id=task_id,
)
@@ -257,7 +251,10 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
result = file_ops.write_file(path, content)
return json.dumps(result.to_dict(), ensure_ascii=False)
except Exception as e:
logger.error("write_file error: %s: %s", type(e).__name__, e)
if _is_expected_write_exception(e):
logger.debug("write_file expected denial: %s: %s", type(e).__name__, e)
else:
logger.error("write_file error: %s: %s", type(e).__name__, e, exc_info=True)
return json.dumps({"error": str(e)}, ensure_ascii=False)

View File

@@ -471,8 +471,6 @@ def _get_env_config() -> Dict[str, Any]:
# is running inside the container/remote).
if env_type == "local":
default_cwd = os.getcwd()
elif env_type == "ssh":
default_cwd = "~"
else:
default_cwd = "/root"
@@ -505,8 +503,6 @@ def _get_env_config() -> Dict[str, Any]:
"ssh_user": os.getenv("TERMINAL_SSH_USER", ""),
"ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"),
"ssh_key": os.getenv("TERMINAL_SSH_KEY", ""),
"ssh_persistent": os.getenv("TERMINAL_SSH_PERSISTENT", "false").lower() in ("true", "1", "yes"),
"local_persistent": os.getenv("TERMINAL_LOCAL_PERSISTENT", "false").lower() in ("true", "1", "yes"),
# Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh)
"container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"),
"container_memory": _parse_env_var("TERMINAL_CONTAINER_MEMORY", "5120"), # MB (default 5GB)
@@ -518,7 +514,6 @@ def _get_env_config() -> Dict[str, Any]:
def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
ssh_config: dict = None, container_config: dict = None,
local_config: dict = None,
task_id: str = "default"):
"""
Create an execution environment from mini-swe-agent.
@@ -543,9 +538,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
volumes = cc.get("docker_volumes", [])
if env_type == "local":
lc = local_config or {}
return _LocalEnvironment(cwd=cwd, timeout=timeout,
persistent=lc.get("persistent", False))
return _LocalEnvironment(cwd=cwd, timeout=timeout)
elif env_type == "docker":
return _DockerEnvironment(
@@ -601,7 +594,6 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
key_path=ssh_config.get("key", ""),
cwd=cwd,
timeout=timeout,
persistent=ssh_config.get("persistent", False),
)
else:
@@ -931,7 +923,6 @@ def terminal_tool(
"user": config.get("ssh_user", ""),
"port": config.get("ssh_port", 22),
"key": config.get("ssh_key", ""),
"persistent": config.get("ssh_persistent", False),
}
container_config = None
@@ -944,12 +935,6 @@ def terminal_tool(
"docker_volumes": config.get("docker_volumes", []),
}
local_config = None
if env_type == "local":
local_config = {
"persistent": config.get("local_persistent", False),
}
new_env = _create_environment(
env_type=env_type,
image=image,
@@ -957,7 +942,6 @@ def terminal_tool(
timeout=effective_timeout,
ssh_config=ssh_config,
container_config=container_config,
local_config=local_config,
task_id=effective_task_id,
)
except ImportError as e:

View File

@@ -279,7 +279,7 @@ def _verify_checksum(archive_path: str, checksums_path: str, archive_name: str)
return True
def _install_tirith() -> tuple[str | None, str]:
def _install_tirith(*, log_failures: bool = True) -> tuple[str | None, str]:
"""Download and install tirith to $HERMES_HOME/bin/tirith.
Verifies provenance via cosign and SHA-256 checksum.
@@ -287,6 +287,8 @@ def _install_tirith() -> tuple[str | None, str]:
failure_reason is a short tag used by the disk marker to decide if the
failure is retryable (e.g. "cosign_missing" clears when cosign appears).
"""
log = logger.warning if log_failures else logger.debug
target = _detect_target()
if not target:
logger.info("tirith auto-install: unsupported platform %s/%s",
@@ -309,7 +311,7 @@ def _install_tirith() -> tuple[str | None, str]:
_download_file(f"{base_url}/{archive_name}", archive_path)
_download_file(f"{base_url}/checksums.txt", checksums_path)
except Exception as exc:
logger.warning("tirith download failed: %s", exc)
log("tirith download failed: %s", exc)
return None, "download_failed"
# Cosign provenance verification is mandatory for auto-install.
@@ -320,25 +322,25 @@ def _install_tirith() -> tuple[str | None, str]:
_download_file(f"{base_url}/checksums.txt.sig", sig_path)
_download_file(f"{base_url}/checksums.txt.pem", cert_path)
except Exception as exc:
logger.warning("tirith install skipped: cosign artifacts unavailable (%s). "
"Install tirith manually or install cosign for auto-install.", exc)
log("tirith install skipped: cosign artifacts unavailable (%s). "
"Install tirith manually or install cosign for auto-install.", exc)
return None, "cosign_artifacts_unavailable"
# Check cosign availability before attempting verification so we can
# distinguish "not installed" (retryable) from "installed but broken."
if not shutil.which("cosign"):
logger.warning("tirith install skipped: cosign not found on PATH. "
"Install cosign for auto-install, or install tirith manually.")
log("tirith install skipped: cosign not found on PATH. "
"Install cosign for auto-install, or install tirith manually.")
return None, "cosign_missing"
cosign_result = _verify_cosign(checksums_path, sig_path, cert_path)
if cosign_result is not True:
# False = verification rejected, None = execution failure (timeout/OSError)
if cosign_result is None:
logger.warning("tirith install aborted: cosign execution failed")
log("tirith install aborted: cosign execution failed")
return None, "cosign_exec_failed"
else:
logger.warning("tirith install aborted: cosign provenance verification failed")
log("tirith install aborted: cosign provenance verification failed")
return None, "cosign_verification_failed"
if not _verify_checksum(archive_path, checksums_path, archive_name):
@@ -354,7 +356,7 @@ def _install_tirith() -> tuple[str | None, str]:
tar.extract(member, tmpdir)
break
else:
logger.warning("tirith binary not found in archive")
log("tirith binary not found in archive")
return None, "binary_not_in_archive"
src = os.path.join(tmpdir, "tirith")
@@ -473,7 +475,7 @@ def _resolve_tirith_path(configured_path: str) -> str:
return expanded
def _background_install():
def _background_install(*, log_failures: bool = True):
"""Background thread target: download and install tirith."""
global _resolved_path, _install_failure_reason
with _install_lock:
@@ -494,7 +496,7 @@ def _background_install():
_install_failure_reason = ""
return
installed, reason = _install_tirith()
installed, reason = _install_tirith(log_failures=log_failures)
if installed:
_resolved_path = installed
_install_failure_reason = ""
@@ -505,7 +507,7 @@ def _background_install():
_mark_install_failed(reason)
def ensure_installed():
def ensure_installed(*, log_failures: bool = True):
"""Ensure tirith is available, downloading in background if needed.
Quick PATH/local checks are synchronous; network download runs in a
@@ -578,7 +580,10 @@ def ensure_installed():
# Need to download — launch background thread so startup doesn't block
if _install_thread is None or not _install_thread.is_alive():
_install_thread = threading.Thread(
target=_background_install, daemon=True)
target=_background_install,
kwargs={"log_failures": log_failures},
daemon=True,
)
_install_thread.start()
return None # Not available yet; commands will fail-open until ready

View File

@@ -50,6 +50,8 @@ def atomic_json_write(
os.fsync(f.fileno())
os.replace(tmp_path, path)
except BaseException:
# Intentionally catch BaseException so temp-file cleanup still runs for
# KeyboardInterrupt/SystemExit before re-raising the original signal.
try:
os.unlink(tmp_path)
except OSError:
@@ -96,6 +98,8 @@ def atomic_yaml_write(
os.fsync(f.fileno())
os.replace(tmp_path, path)
except BaseException:
# Match atomic_json_write: cleanup must also happen for process-level
# interruptions before we re-raise them.
try:
os.unlink(tmp_path)
except OSError: