Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fcf64d5283 | |||
| 8bbafdf3a6 | |||
| 04ee0ec0bc | |||
| b7903bca41 | |||
| 20e94662cc | |||
| 6ed3f9ca80 |
@@ -1142,13 +1142,7 @@ def resolve_provider_client(
|
||||
if provider == "codex":
|
||||
provider = "openai-codex"
|
||||
if provider == "main":
|
||||
# Resolve to the user's actual main provider so named custom providers
|
||||
# and non-aggregator providers (DeepSeek, Alibaba, etc.) work correctly.
|
||||
main_prov = _read_main_provider()
|
||||
if main_prov and main_prov not in ("auto", "main", ""):
|
||||
provider = main_prov
|
||||
else:
|
||||
provider = "custom"
|
||||
provider = "custom"
|
||||
|
||||
# ── Auto: try all providers in priority order ────────────────────
|
||||
if provider == "auto":
|
||||
@@ -1244,28 +1238,6 @@ def resolve_provider_client(
|
||||
"but no endpoint credentials found")
|
||||
return None, None
|
||||
|
||||
# ── Named custom providers (config.yaml custom_providers list) ───
|
||||
try:
|
||||
from hermes_cli.runtime_provider import _get_named_custom_provider
|
||||
custom_entry = _get_named_custom_provider(provider)
|
||||
if custom_entry:
|
||||
custom_base = custom_entry.get("base_url", "").strip()
|
||||
custom_key = custom_entry.get("api_key", "").strip() or "no-key-required"
|
||||
if custom_base:
|
||||
final_model = model or _read_main_model() or "gpt-4o-mini"
|
||||
client = OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
logger.debug(
|
||||
"resolve_provider_client: named custom provider %r (%s)",
|
||||
provider, final_model)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
logger.warning(
|
||||
"resolve_provider_client: named custom provider %r has no base_url",
|
||||
provider)
|
||||
return None, None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# ── API-key providers from PROVIDER_REGISTRY ─────────────────────
|
||||
try:
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY, resolve_api_key_provider_credentials
|
||||
@@ -1386,11 +1358,6 @@ def _normalize_vision_provider(provider: Optional[str]) -> str:
|
||||
if provider == "codex":
|
||||
return "openai-codex"
|
||||
if provider == "main":
|
||||
# Resolve to actual main provider — named custom providers and
|
||||
# non-aggregator providers need to pass through as their real name.
|
||||
main_prov = _read_main_provider()
|
||||
if main_prov and main_prov not in ("auto", "main", ""):
|
||||
return main_prov
|
||||
return "custom"
|
||||
return provider
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ from agent.usage_pricing import (
|
||||
format_duration_compact,
|
||||
format_token_count_compact,
|
||||
)
|
||||
from hermes_cli.banner import _format_context_length, format_banner_version_label
|
||||
from hermes_cli.banner import _format_context_length
|
||||
|
||||
_COMMAND_SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏")
|
||||
|
||||
@@ -1036,44 +1036,21 @@ COMPACT_BANNER = """
|
||||
|
||||
def _build_compact_banner() -> str:
|
||||
"""Build a compact banner that fits the current terminal width."""
|
||||
try:
|
||||
from hermes_cli.skin_engine import get_active_skin
|
||||
_skin = get_active_skin()
|
||||
except Exception:
|
||||
_skin = None
|
||||
|
||||
skin_name = getattr(_skin, "name", "default") if _skin else "default"
|
||||
border_color = _skin.get_color("banner_border", "#FFD700") if _skin else "#FFD700"
|
||||
title_color = _skin.get_color("banner_title", "#FFBF00") if _skin else "#FFBF00"
|
||||
dim_color = _skin.get_color("banner_dim", "#B8860B") if _skin else "#B8860B"
|
||||
|
||||
if skin_name == "default":
|
||||
line1 = "⚕ NOUS HERMES - AI Agent Framework"
|
||||
tiny_line = "⚕ NOUS HERMES"
|
||||
else:
|
||||
agent_name = _skin.get_branding("agent_name", "Hermes Agent") if _skin else "Hermes Agent"
|
||||
line1 = f"{agent_name} - AI Agent Framework"
|
||||
tiny_line = agent_name
|
||||
|
||||
version_line = format_banner_version_label()
|
||||
|
||||
w = min(shutil.get_terminal_size().columns - 2, 88)
|
||||
w = min(shutil.get_terminal_size().columns - 2, 64)
|
||||
if w < 30:
|
||||
return f"\n[{title_color}]{tiny_line}[/] [dim {dim_color}]- Nous Research[/]\n"
|
||||
|
||||
return "\n[#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- Nous Research[/]\n"
|
||||
inner = w - 2 # inside the box border
|
||||
bar = "═" * w
|
||||
content_width = inner - 2
|
||||
|
||||
line1 = "⚕ NOUS HERMES - AI Agent Framework"
|
||||
line2 = "Messenger of the Digital Gods · Nous Research"
|
||||
# Truncate and pad to fit
|
||||
line1 = line1[:content_width].ljust(content_width)
|
||||
line2 = version_line[:content_width].ljust(content_width)
|
||||
|
||||
line1 = line1[:inner - 2].ljust(inner - 2)
|
||||
line2 = line2[:inner - 2].ljust(inner - 2)
|
||||
return (
|
||||
f"\n[bold {border_color}]╔{bar}╗[/]\n"
|
||||
f"[bold {border_color}]║[/] [{title_color}]{line1}[/] [bold {border_color}]║[/]\n"
|
||||
f"[bold {border_color}]║[/] [dim {dim_color}]{line2}[/] [bold {border_color}]║[/]\n"
|
||||
f"[bold {border_color}]╚{bar}╝[/]\n"
|
||||
f"\n[bold #FFD700]╔{bar}╗[/]\n"
|
||||
f"[bold #FFD700]║[/] [#FFBF00]{line1}[/] [bold #FFD700]║[/]\n"
|
||||
f"[bold #FFD700]║[/] [dim #B8860B]{line2}[/] [bold #FFD700]║[/]\n"
|
||||
f"[bold #FFD700]╚{bar}╝[/]\n"
|
||||
)
|
||||
|
||||
|
||||
@@ -2186,7 +2163,7 @@ class HermesCLI:
|
||||
)
|
||||
except Exception as exc:
|
||||
message = format_runtime_provider_error(exc)
|
||||
ChatConsole().print(f"[bold red]{message}[/]")
|
||||
self.console.print(f"[bold red]{message}[/]")
|
||||
return False
|
||||
|
||||
api_key = runtime.get("api_key")
|
||||
@@ -2401,7 +2378,7 @@ class HermesCLI:
|
||||
self._pending_title = None
|
||||
return True
|
||||
except Exception as e:
|
||||
ChatConsole().print(f"[bold red]Failed to initialize agent: {e}[/]")
|
||||
self.console.print(f"[bold red]Failed to initialize agent: {e}[/]")
|
||||
return False
|
||||
|
||||
def show_banner(self):
|
||||
@@ -4553,13 +4530,13 @@ class HermesCLI:
|
||||
if output:
|
||||
self.console.print(_rich_text_from_ansi(output))
|
||||
else:
|
||||
ChatConsole().print("[dim]Command returned no output[/]")
|
||||
self.console.print("[dim]Command returned no output[/]")
|
||||
except subprocess.TimeoutExpired:
|
||||
ChatConsole().print("[bold red]Quick command timed out (30s)[/]")
|
||||
self.console.print("[bold red]Quick command timed out (30s)[/]")
|
||||
except Exception as e:
|
||||
ChatConsole().print(f"[bold red]Quick command error: {e}[/]")
|
||||
self.console.print(f"[bold red]Quick command error: {e}[/]")
|
||||
else:
|
||||
ChatConsole().print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]")
|
||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]")
|
||||
elif qcmd.get("type") == "alias":
|
||||
target = qcmd.get("target", "").strip()
|
||||
if target:
|
||||
@@ -4568,9 +4545,9 @@ class HermesCLI:
|
||||
aliased_command = f"{target} {user_args}".strip()
|
||||
return self.process_command(aliased_command)
|
||||
else:
|
||||
ChatConsole().print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]")
|
||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]")
|
||||
else:
|
||||
ChatConsole().print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]")
|
||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]")
|
||||
# Check for plugin-registered slash commands
|
||||
elif base_cmd.lstrip("/") in _get_plugin_cmd_handler_names():
|
||||
from hermes_cli.plugins import get_plugin_command_handler
|
||||
@@ -4595,7 +4572,7 @@ class HermesCLI:
|
||||
if hasattr(self, '_pending_input'):
|
||||
self._pending_input.put(msg)
|
||||
else:
|
||||
ChatConsole().print(f"[bold red]Failed to load skill for {base_cmd}[/]")
|
||||
self.console.print(f"[bold red]Failed to load skill for {base_cmd}[/]")
|
||||
else:
|
||||
# Prefix matching: if input uniquely identifies one command, execute it.
|
||||
# Matches against both built-in COMMANDS and installed skill commands so
|
||||
@@ -4656,14 +4633,14 @@ class HermesCLI:
|
||||
)
|
||||
|
||||
if not msg:
|
||||
ChatConsole().print("[bold red]Failed to load the bundled /plan skill[/]")
|
||||
self.console.print("[bold red]Failed to load the bundled /plan skill[/]")
|
||||
return
|
||||
|
||||
_cprint(f" 📝 Plan mode queued via skill. Markdown plan target: {plan_path}")
|
||||
if hasattr(self, '_pending_input'):
|
||||
self._pending_input.put(msg)
|
||||
else:
|
||||
ChatConsole().print("[bold red]Plan mode unavailable: input queue not initialized[/]")
|
||||
self.console.print("[bold red]Plan mode unavailable: input queue not initialized[/]")
|
||||
|
||||
def _handle_background_command(self, cmd: str):
|
||||
"""Handle /background <prompt> — run a prompt in a separate background session.
|
||||
|
||||
@@ -21,8 +21,6 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from model_tools import handle_function_call
|
||||
from tools.terminal_tool import get_active_env
|
||||
from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
# (e.g., the Modal/Docker/Daytona terminal backends). Running them in a separate
|
||||
@@ -140,7 +138,6 @@ class HermesAgentLoop:
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
budget_config: Optional["BudgetConfig"] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the agent loop.
|
||||
@@ -157,11 +154,7 @@ class HermesAgentLoop:
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
e.g. {"provider": {"ignore": ["DeepInfra"]}}
|
||||
budget_config: Tool result persistence budget. Controls per-tool
|
||||
thresholds, per-turn aggregate budget, and preview size.
|
||||
If None, uses DEFAULT_BUDGET (current hardcoded values).
|
||||
"""
|
||||
from tools.budget_config import DEFAULT_BUDGET
|
||||
self.server = server
|
||||
self.tool_schemas = tool_schemas
|
||||
self.valid_tool_names = valid_tool_names
|
||||
@@ -170,7 +163,6 @@ class HermesAgentLoop:
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
self.budget_config = budget_config or DEFAULT_BUDGET
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
||||
"""
|
||||
@@ -454,15 +446,8 @@ class HermesAgentLoop:
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Add tool response to conversation
|
||||
tc_id = tc.get("id", "") if isinstance(tc, dict) else tc.id
|
||||
tool_result = maybe_persist_tool_result(
|
||||
content=tool_result,
|
||||
tool_name=tool_name,
|
||||
tool_use_id=tc_id,
|
||||
env=get_active_env(self.task_id),
|
||||
config=self.budget_config,
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -471,14 +456,6 @@ class HermesAgentLoop:
|
||||
}
|
||||
)
|
||||
|
||||
num_tcs = len(assistant_msg.tool_calls)
|
||||
if num_tcs > 0:
|
||||
enforce_turn_budget(
|
||||
messages[-num_tcs:],
|
||||
env=get_active_env(self.task_id),
|
||||
config=self.budget_config,
|
||||
)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
|
||||
|
||||
@@ -1048,7 +1048,6 @@ class AgenticOPDEnv(HermesAgentBaseEnv):
|
||||
temperature=0.0,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path, PurePosixPath, PureWindowsPath
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
@@ -148,62 +148,6 @@ MODAL_INCOMPATIBLE_TASKS = {
|
||||
# Tar extraction helper
|
||||
# =============================================================================
|
||||
|
||||
def _normalize_tar_member_parts(member_name: str) -> list:
|
||||
"""Return safe path components for a tar member or raise ValueError."""
|
||||
normalized_name = member_name.replace("\\", "/")
|
||||
posix_path = PurePosixPath(normalized_name)
|
||||
windows_path = PureWindowsPath(member_name)
|
||||
|
||||
if (
|
||||
not normalized_name
|
||||
or posix_path.is_absolute()
|
||||
or windows_path.is_absolute()
|
||||
or windows_path.drive
|
||||
):
|
||||
raise ValueError(f"Unsafe archive member path: {member_name}")
|
||||
|
||||
parts = [part for part in posix_path.parts if part not in ("", ".")]
|
||||
if not parts or any(part == ".." for part in parts):
|
||||
raise ValueError(f"Unsafe archive member path: {member_name}")
|
||||
return parts
|
||||
|
||||
|
||||
def _safe_extract_tar(tar: tarfile.TarFile, target_dir: Path) -> None:
|
||||
"""Extract a tar archive without allowing traversal or link entries."""
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
target_root = target_dir.resolve()
|
||||
|
||||
for member in tar.getmembers():
|
||||
parts = _normalize_tar_member_parts(member.name)
|
||||
target = target_dir.joinpath(*parts)
|
||||
target_real = target.resolve(strict=False)
|
||||
|
||||
try:
|
||||
target_real.relative_to(target_root)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Unsafe archive member path: {member.name}") from exc
|
||||
|
||||
if member.isdir():
|
||||
target_real.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
|
||||
if not member.isfile():
|
||||
raise ValueError(f"Unsupported archive member type: {member.name}")
|
||||
|
||||
target_real.parent.mkdir(parents=True, exist_ok=True)
|
||||
extracted = tar.extractfile(member)
|
||||
if extracted is None:
|
||||
raise ValueError(f"Cannot read archive member: {member.name}")
|
||||
|
||||
with extracted, open(target_real, "wb") as dst:
|
||||
shutil.copyfileobj(extracted, dst)
|
||||
|
||||
try:
|
||||
os.chmod(target_real, member.mode & 0o777)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
"""Extract a base64-encoded tar.gz archive into target_dir."""
|
||||
if not b64_data:
|
||||
@@ -211,7 +155,7 @@ def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
raw = base64.b64decode(b64_data)
|
||||
buf = io.BytesIO(raw)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
_safe_extract_tar(tar, target_dir)
|
||||
tar.extractall(path=str(target_dir))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -541,7 +485,6 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
@@ -554,7 +497,6 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -549,7 +549,6 @@ class YCBenchEvalEnv(HermesAgentBaseEnv):
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -62,11 +62,6 @@ from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.budget_config import (
|
||||
DEFAULT_RESULT_SIZE_CHARS,
|
||||
DEFAULT_TURN_BUDGET_CHARS,
|
||||
DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
)
|
||||
|
||||
# Import hermes-agent toolset infrastructure
|
||||
from model_tools import get_tool_definitions
|
||||
@@ -165,32 +160,6 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
|
||||
)
|
||||
|
||||
# --- Tool result budget ---
|
||||
# Defaults imported from tools.budget_config (single source of truth).
|
||||
default_result_size_chars: int = Field(
|
||||
default=DEFAULT_RESULT_SIZE_CHARS,
|
||||
description="Default per-tool threshold (chars) for persisting large results "
|
||||
"to sandbox. Results exceeding this are written to /tmp/hermes-results/ "
|
||||
"and replaced with a preview. Per-tool registry values take precedence "
|
||||
"unless overridden via tool_result_overrides.",
|
||||
)
|
||||
turn_budget_chars: int = Field(
|
||||
default=DEFAULT_TURN_BUDGET_CHARS,
|
||||
description="Aggregate char budget per assistant turn. If all tool results "
|
||||
"in a single turn exceed this, the largest are persisted to disk first.",
|
||||
)
|
||||
preview_size_chars: int = Field(
|
||||
default=DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
description="Size of the inline preview shown after a tool result is persisted.",
|
||||
)
|
||||
tool_result_overrides: Optional[Dict[str, int]] = Field(
|
||||
default=None,
|
||||
description="Per-tool threshold overrides (chars). Keys are tool names, "
|
||||
"values are char thresholds. Overrides both the default and registry "
|
||||
"per-tool values. Example: {'terminal': 10000, 'search_files': 5000}. "
|
||||
"Note: read_file is pinned to infinity and cannot be overridden.",
|
||||
)
|
||||
|
||||
# --- Provider-specific parameters ---
|
||||
# Passed as extra_body to the OpenAI client's chat.completions.create() call.
|
||||
# Useful for OpenRouter provider preferences, transforms, route settings, etc.
|
||||
@@ -207,16 +176,6 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"transforms, and other provider-specific settings.",
|
||||
)
|
||||
|
||||
def build_budget_config(self):
|
||||
"""Build a BudgetConfig from env config fields."""
|
||||
from tools.budget_config import BudgetConfig
|
||||
return BudgetConfig(
|
||||
default_result_size=self.default_result_size_chars,
|
||||
turn_budget=self.turn_budget_chars,
|
||||
preview_size=self.preview_size_chars,
|
||||
tool_overrides=dict(self.tool_result_overrides) if self.tool_result_overrides else {},
|
||||
)
|
||||
|
||||
|
||||
class HermesAgentBaseEnv(BaseEnv):
|
||||
"""
|
||||
@@ -531,7 +490,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
except NotImplementedError:
|
||||
@@ -549,7 +507,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
@@ -563,7 +520,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -472,7 +472,6 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
temperature=0.0, # Deterministic for eval
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -556,18 +556,6 @@ def load_gateway_config() -> GatewayConfig:
|
||||
os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower()
|
||||
if "reactions" in discord_cfg and not os.getenv("DISCORD_REACTIONS"):
|
||||
os.environ["DISCORD_REACTIONS"] = str(discord_cfg["reactions"]).lower()
|
||||
# ignored_channels: channels where bot never responds (even when mentioned)
|
||||
ic = discord_cfg.get("ignored_channels")
|
||||
if ic is not None and not os.getenv("DISCORD_IGNORED_CHANNELS"):
|
||||
if isinstance(ic, list):
|
||||
ic = ",".join(str(v) for v in ic)
|
||||
os.environ["DISCORD_IGNORED_CHANNELS"] = str(ic)
|
||||
# no_thread_channels: channels where bot responds directly without creating thread
|
||||
ntc = discord_cfg.get("no_thread_channels")
|
||||
if ntc is not None and not os.getenv("DISCORD_NO_THREAD_CHANNELS"):
|
||||
if isinstance(ntc, list):
|
||||
ntc = ",".join(str(v) for v in ntc)
|
||||
os.environ["DISCORD_NO_THREAD_CHANNELS"] = str(ntc)
|
||||
|
||||
# Telegram settings → env vars (env vars take precedence)
|
||||
telegram_cfg = yaml_cfg.get("telegram", {})
|
||||
@@ -582,8 +570,6 @@ def load_gateway_config() -> GatewayConfig:
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["TELEGRAM_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
if "reactions" in telegram_cfg and not os.getenv("TELEGRAM_REACTIONS"):
|
||||
os.environ["TELEGRAM_REACTIONS"] = str(telegram_cfg["reactions"]).lower()
|
||||
|
||||
whatsapp_cfg = yaml_cfg.get("whatsapp", {})
|
||||
if isinstance(whatsapp_cfg, dict):
|
||||
|
||||
@@ -20,7 +20,6 @@ Requires:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -371,7 +370,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:].strip()
|
||||
if hmac.compare_digest(token, self._api_key):
|
||||
if token == self._api_key:
|
||||
return None # Auth OK
|
||||
|
||||
return web.json_response(
|
||||
@@ -564,10 +563,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
if delta is not None:
|
||||
_stream_q.put(delta)
|
||||
|
||||
def _on_tool_progress(event_type, name, preview, args, **kwargs):
|
||||
def _on_tool_progress(name, preview, args):
|
||||
"""Inject tool progress into the SSE stream for Open WebUI."""
|
||||
if event_type != "tool.started":
|
||||
return # Only show tool start events in chat stream
|
||||
if name.startswith("_"):
|
||||
return # Skip internal events (_thinking)
|
||||
from agent.display import get_tool_emoji
|
||||
@@ -818,29 +815,9 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
return web.json_response(_openai_error("'input' must be a string or array"), status=400)
|
||||
|
||||
# Accept explicit conversation_history from the request body.
|
||||
# This lets stateless clients supply their own history instead of
|
||||
# relying on server-side response chaining via previous_response_id.
|
||||
# Precedence: explicit conversation_history > previous_response_id.
|
||||
# Reconstruct conversation history from previous_response_id
|
||||
conversation_history: List[Dict[str, str]] = []
|
||||
raw_history = body.get("conversation_history")
|
||||
if raw_history:
|
||||
if not isinstance(raw_history, list):
|
||||
return web.json_response(
|
||||
_openai_error("'conversation_history' must be an array of message objects"),
|
||||
status=400,
|
||||
)
|
||||
for i, entry in enumerate(raw_history):
|
||||
if not isinstance(entry, dict) or "role" not in entry or "content" not in entry:
|
||||
return web.json_response(
|
||||
_openai_error(f"conversation_history[{i}] must have 'role' and 'content' fields"),
|
||||
status=400,
|
||||
)
|
||||
conversation_history.append({"role": str(entry["role"]), "content": str(entry["content"])})
|
||||
if previous_response_id:
|
||||
logger.debug("Both conversation_history and previous_response_id provided; using conversation_history")
|
||||
|
||||
if not conversation_history and previous_response_id:
|
||||
if previous_response_id:
|
||||
stored = self._response_store.get(previous_response_id)
|
||||
if stored is None:
|
||||
return web.json_response(_openai_error(f"Previous response not found: {previous_response_id}"), status=404)
|
||||
@@ -1426,49 +1403,14 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
instructions = body.get("instructions")
|
||||
previous_response_id = body.get("previous_response_id")
|
||||
|
||||
# Accept explicit conversation_history from the request body.
|
||||
# Precedence: explicit conversation_history > previous_response_id.
|
||||
conversation_history: List[Dict[str, str]] = []
|
||||
raw_history = body.get("conversation_history")
|
||||
if raw_history:
|
||||
if not isinstance(raw_history, list):
|
||||
return web.json_response(
|
||||
_openai_error("'conversation_history' must be an array of message objects"),
|
||||
status=400,
|
||||
)
|
||||
for i, entry in enumerate(raw_history):
|
||||
if not isinstance(entry, dict) or "role" not in entry or "content" not in entry:
|
||||
return web.json_response(
|
||||
_openai_error(f"conversation_history[{i}] must have 'role' and 'content' fields"),
|
||||
status=400,
|
||||
)
|
||||
conversation_history.append({"role": str(entry["role"]), "content": str(entry["content"])})
|
||||
if previous_response_id:
|
||||
logger.debug("Both conversation_history and previous_response_id provided; using conversation_history")
|
||||
|
||||
if not conversation_history and previous_response_id:
|
||||
if previous_response_id:
|
||||
stored = self._response_store.get(previous_response_id)
|
||||
if stored:
|
||||
conversation_history = list(stored.get("conversation_history", []))
|
||||
if instructions is None:
|
||||
instructions = stored.get("instructions")
|
||||
|
||||
# When input is a multi-message array, extract all but the last
|
||||
# message as conversation history (the last becomes user_message).
|
||||
# Only fires when no explicit history was provided.
|
||||
if not conversation_history and isinstance(raw_input, list) and len(raw_input) > 1:
|
||||
for msg in raw_input[:-1]:
|
||||
if isinstance(msg, dict) and msg.get("role") and msg.get("content"):
|
||||
content = msg["content"]
|
||||
if isinstance(content, list):
|
||||
# Flatten multi-part content blocks to text
|
||||
content = " ".join(
|
||||
part.get("text", "") for part in content
|
||||
if isinstance(part, dict) and part.get("type") == "text"
|
||||
)
|
||||
conversation_history.append({"role": msg["role"], "content": str(content)})
|
||||
|
||||
session_id = body.get("session_id") or run_id
|
||||
ephemeral_system_prompt = instructions
|
||||
|
||||
|
||||
@@ -124,14 +124,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL targets a private/internal network (SSRF protection).
|
||||
"""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging as _logging
|
||||
@@ -239,14 +232,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached audio file as a string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL targets a private/internal network (SSRF protection).
|
||||
"""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging as _logging
|
||||
@@ -1119,22 +1105,6 @@ class BasePlatformAdapter(ABC):
|
||||
logger.error("[%s] Fallback send also failed: %s", self.name, fallback_result.error)
|
||||
return fallback_result
|
||||
|
||||
@staticmethod
|
||||
def _merge_caption(existing_text: Optional[str], new_text: str) -> str:
|
||||
"""Merge a new caption into existing text, avoiding duplicates.
|
||||
|
||||
Uses line-by-line exact match (not substring) to prevent false positives
|
||||
where a shorter caption is silently dropped because it appears as a
|
||||
substring of a longer one (e.g. "Meeting" inside "Meeting agenda").
|
||||
Whitespace is normalised for comparison.
|
||||
"""
|
||||
if not existing_text:
|
||||
return new_text
|
||||
existing_captions = [c.strip() for c in existing_text.split("\n\n")]
|
||||
if new_text.strip() not in existing_captions:
|
||||
return f"{existing_text}\n\n{new_text}".strip()
|
||||
return existing_text
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
@@ -1194,7 +1164,10 @@ class BasePlatformAdapter(ABC):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
else:
|
||||
self._pending_messages[session_key] = event
|
||||
return # Don't interrupt now - will run after current task completes
|
||||
|
||||
@@ -55,7 +55,6 @@ from gateway.platforms.base import (
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
from tools.url_safety import is_safe_url
|
||||
|
||||
|
||||
def _clean_discord_id(entry: str) -> str:
|
||||
@@ -1286,10 +1285,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not is_safe_url(image_url):
|
||||
logger.warning("[%s] Blocked unsafe image URL during Discord send_image", self.name)
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
@@ -2193,11 +2188,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# UNLESS the channel is in the free-response list or the message is
|
||||
# in a thread where the bot has already participated.
|
||||
#
|
||||
# Config (all settable via discord.* in config.yaml or DISCORD_* env vars):
|
||||
# Config (all settable via discord.* in config.yaml):
|
||||
# discord.require_mention: Require @mention in server channels (default: true)
|
||||
# discord.free_response_channels: Channel IDs where bot responds without mention
|
||||
# discord.ignored_channels: Channel IDs where bot NEVER responds (even when mentioned)
|
||||
# discord.no_thread_channels: Channel IDs where bot responds directly without creating thread
|
||||
# discord.auto_thread: Auto-create thread on @mention in channels (default: true)
|
||||
|
||||
thread_id = None
|
||||
@@ -2208,18 +2201,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
parent_channel_id = self._get_parent_channel_id(message.channel)
|
||||
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
# Check ignored channels first - never respond even when mentioned
|
||||
ignored_channels_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "")
|
||||
ignored_channels = {ch.strip() for ch in ignored_channels_raw.split(",") if ch.strip()}
|
||||
channel_ids = {str(message.channel.id)}
|
||||
if parent_channel_id:
|
||||
channel_ids.add(parent_channel_id)
|
||||
if channel_ids & ignored_channels:
|
||||
logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_ids)
|
||||
return
|
||||
|
||||
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
channel_ids = {str(message.channel.id)}
|
||||
if parent_channel_id:
|
||||
channel_ids.add(parent_channel_id)
|
||||
|
||||
@@ -2241,14 +2225,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Auto-thread: when enabled, automatically create a thread for every
|
||||
# @mention in a text channel so each conversation is isolated (like Slack).
|
||||
# Messages already inside threads or DMs are unaffected.
|
||||
# no_thread_channels: channels where bot responds directly without thread.
|
||||
auto_threaded_channel = None
|
||||
if not is_thread and not isinstance(message.channel, discord.DMChannel):
|
||||
no_thread_channels_raw = os.getenv("DISCORD_NO_THREAD_CHANNELS", "")
|
||||
no_thread_channels = {ch.strip() for ch in no_thread_channels_raw.split(",") if ch.strip()}
|
||||
skip_thread = bool(channel_ids & no_thread_channels)
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in ("true", "1", "yes")
|
||||
if auto_thread and not skip_thread:
|
||||
if auto_thread:
|
||||
thread = await self._auto_create_thread(message)
|
||||
if thread:
|
||||
is_thread = True
|
||||
|
||||
@@ -2065,7 +2065,10 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text.split("\n\n"):
|
||||
existing.text = f"{existing.text}\n\n{event.text}"
|
||||
existing.timestamp = event.timestamp
|
||||
if event.message_id:
|
||||
existing.message_id = event.message_id
|
||||
@@ -2109,10 +2112,6 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
default_ext: str,
|
||||
preferred_name: str,
|
||||
) -> tuple[str, str]:
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(file_url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {file_url[:80]}")
|
||||
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
|
||||
@@ -586,11 +586,6 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Download an image URL and upload it to Matrix."""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(image_url):
|
||||
logger.warning("Matrix: blocked unsafe image URL (SSRF protection)")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
try:
|
||||
# Try aiohttp first (always available), fall back to httpx
|
||||
try:
|
||||
|
||||
@@ -407,11 +407,6 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
kind: str = "file",
|
||||
) -> SendResult:
|
||||
"""Download a URL and upload it as a file attachment."""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
logger.warning("Mattermost: blocked unsafe URL (SSRF protection)")
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
|
||||
@@ -595,11 +595,6 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(image_url):
|
||||
logger.warning("[Slack] Blocked unsafe image URL (SSRF protection)")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
|
||||
@@ -1632,12 +1632,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(image_url):
|
||||
logger.warning("[%s] Blocked unsafe image URL (SSRF protection)", self.name)
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs (up to ~5MB)
|
||||
_photo_thread = metadata.get("thread_id") if metadata else None
|
||||
@@ -2227,7 +2222,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
|
||||
prior_task = self._pending_photo_batch_tasks.get(batch_key)
|
||||
if prior_task and not prior_task.done():
|
||||
@@ -2417,7 +2415,11 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
if existing.text:
|
||||
if event.text not in existing.text.split("\n\n"):
|
||||
existing.text = f"{existing.text}\n\n{event.text}"
|
||||
else:
|
||||
existing.text = event.text
|
||||
|
||||
prior_task = self._media_group_tasks.get(media_group_id)
|
||||
if prior_task:
|
||||
@@ -2673,46 +2675,3 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
auto_skill=topic_skill,
|
||||
timestamp=message.date,
|
||||
)
|
||||
|
||||
# ── Message reactions (processing lifecycle) ──────────────────────────
|
||||
|
||||
def _reactions_enabled(self) -> bool:
|
||||
"""Check if message reactions are enabled via config/env."""
|
||||
return os.getenv("TELEGRAM_REACTIONS", "false").lower() not in ("false", "0", "no")
|
||||
|
||||
async def _set_reaction(self, chat_id: str, message_id: str, emoji: str) -> bool:
|
||||
"""Set a single emoji reaction on a Telegram message."""
|
||||
if not self._bot:
|
||||
return False
|
||||
try:
|
||||
await self._bot.set_message_reaction(
|
||||
chat_id=int(chat_id),
|
||||
message_id=int(message_id),
|
||||
reaction=emoji,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug("[%s] set_message_reaction failed (%s): %s", self.name, emoji, e)
|
||||
return False
|
||||
|
||||
async def on_processing_start(self, event: MessageEvent) -> None:
|
||||
"""Add an in-progress reaction when message processing begins."""
|
||||
if not self._reactions_enabled():
|
||||
return
|
||||
chat_id = getattr(event.source, "chat_id", None)
|
||||
message_id = getattr(event, "message_id", None)
|
||||
if chat_id and message_id:
|
||||
await self._set_reaction(chat_id, message_id, "\U0001f440")
|
||||
|
||||
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
|
||||
"""Swap the in-progress reaction for a final success/failure reaction.
|
||||
|
||||
Unlike Discord (additive reactions), Telegram's set_message_reaction
|
||||
replaces all existing reactions in one call — no remove step needed.
|
||||
"""
|
||||
if not self._reactions_enabled():
|
||||
return
|
||||
chat_id = getattr(event.source, "chat_id", None)
|
||||
message_id = getattr(event, "message_id", None)
|
||||
if chat_id and message_id:
|
||||
await self._set_reaction(chat_id, message_id, "\u2705" if success else "\u274c")
|
||||
|
||||
@@ -76,17 +76,8 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
self._routes: Dict[str, dict] = dict(self._static_routes)
|
||||
self._runner = None
|
||||
|
||||
# Delivery info keyed by session chat_id.
|
||||
#
|
||||
# Read by every send() invocation for the chat_id (status messages
|
||||
# AND the final response). Cleaned up via TTL on each POST so the
|
||||
# dict stays bounded — see _prune_delivery_info(). Do NOT pop on
|
||||
# send(), or interim status messages (e.g. fallback notifications,
|
||||
# context-pressure warnings) will consume the entry before the
|
||||
# final response arrives, causing the response to silently fall
|
||||
# back to the "log" deliver type.
|
||||
# Delivery info keyed by session chat_id — consumed by send()
|
||||
self._delivery_info: Dict[str, dict] = {}
|
||||
self._delivery_info_created: Dict[str, float] = {}
|
||||
|
||||
# Reference to gateway runner for cross-platform delivery (set externally)
|
||||
self.gateway_runner = None
|
||||
@@ -169,14 +160,10 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
) -> SendResult:
|
||||
"""Deliver the agent's response to the configured destination.
|
||||
|
||||
chat_id is ``webhook:{route}:{delivery_id}``. The delivery info
|
||||
stored during webhook receipt is read with ``.get()`` (not popped)
|
||||
so that interim status messages emitted before the final response
|
||||
— fallback-model notifications, context-pressure warnings, etc. —
|
||||
do not consume the entry and silently downgrade the final response
|
||||
to the ``log`` deliver type. TTL cleanup happens on POST.
|
||||
chat_id is ``webhook:{route}:{delivery_id}`` — we pop the delivery
|
||||
info stored during webhook receipt so it doesn't leak memory.
|
||||
"""
|
||||
delivery = self._delivery_info.get(chat_id, {})
|
||||
delivery = self._delivery_info.pop(chat_id, {})
|
||||
deliver_type = delivery.get("deliver", "log")
|
||||
|
||||
if deliver_type == "log":
|
||||
@@ -203,23 +190,6 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
success=False, error=f"Unknown deliver type: {deliver_type}"
|
||||
)
|
||||
|
||||
def _prune_delivery_info(self, now: float) -> None:
|
||||
"""Drop delivery_info entries older than the idempotency TTL.
|
||||
|
||||
Mirrors the cleanup pattern used for ``_seen_deliveries``. Called
|
||||
on each POST so the dict size is bounded by ``rate_limit * TTL``
|
||||
even if many webhooks fire and never receive a final response.
|
||||
"""
|
||||
cutoff = now - self._idempotency_ttl
|
||||
stale = [
|
||||
k
|
||||
for k, t in self._delivery_info_created.items()
|
||||
if t < cutoff
|
||||
]
|
||||
for k in stale:
|
||||
self._delivery_info.pop(k, None)
|
||||
self._delivery_info_created.pop(k, None)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
return {"name": chat_id, "type": "webhook"}
|
||||
|
||||
@@ -412,9 +382,7 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
# same route get independent agent runs (not queued/interrupted).
|
||||
session_chat_id = f"webhook:{route_name}:{delivery_id}"
|
||||
|
||||
# Store delivery info for send(). Read by every send() invocation
|
||||
# for this chat_id (interim status messages and the final response),
|
||||
# so we do NOT pop on send. TTL-based cleanup keeps the dict bounded.
|
||||
# Store delivery info for send() — consumed (popped) on delivery
|
||||
deliver_config = {
|
||||
"deliver": route_config.get("deliver", "log"),
|
||||
"deliver_extra": self._render_delivery_extra(
|
||||
@@ -423,8 +391,6 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
"payload": payload,
|
||||
}
|
||||
self._delivery_info[session_chat_id] = deliver_config
|
||||
self._delivery_info_created[session_chat_id] = now
|
||||
self._prune_delivery_info(now)
|
||||
|
||||
# Build source and event
|
||||
source = self.build_source(
|
||||
|
||||
@@ -910,10 +910,6 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
url: str,
|
||||
max_bytes: int,
|
||||
) -> Tuple[bytes, Dict[str, str]]:
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {url[:80]}")
|
||||
|
||||
if not HTTPX_AVAILABLE:
|
||||
raise RuntimeError("httpx is required for WeCom media download")
|
||||
|
||||
|
||||
+15
-33
@@ -1987,7 +1987,10 @@ class GatewayRunner:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
else:
|
||||
@@ -3342,36 +3345,25 @@ class GatewayRunner:
|
||||
"""Handle /status command."""
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
|
||||
|
||||
connected_platforms = [p.value for p in self.adapters.keys()]
|
||||
|
||||
|
||||
# Check if there's an active agent
|
||||
session_key = session_entry.session_key
|
||||
is_running = session_key in self._running_agents
|
||||
|
||||
title = None
|
||||
if self._session_db:
|
||||
try:
|
||||
title = self._session_db.get_session_title(session_entry.session_id)
|
||||
except Exception:
|
||||
title = None
|
||||
|
||||
|
||||
lines = [
|
||||
"📊 **Hermes Gateway Status**",
|
||||
"",
|
||||
f"**Session ID:** `{session_entry.session_id}`",
|
||||
]
|
||||
if title:
|
||||
lines.append(f"**Title:** {title}")
|
||||
lines.extend([
|
||||
f"**Session ID:** `{session_entry.session_id[:12]}...`",
|
||||
f"**Created:** {session_entry.created_at.strftime('%Y-%m-%d %H:%M')}",
|
||||
f"**Last Activity:** {session_entry.updated_at.strftime('%Y-%m-%d %H:%M')}",
|
||||
f"**Tokens:** {session_entry.total_tokens:,}",
|
||||
f"**Agent Running:** {'Yes ⚡' if is_running else 'No'}",
|
||||
"",
|
||||
f"**Connected Platforms:** {', '.join(connected_platforms)}",
|
||||
])
|
||||
|
||||
]
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_stop_command(self, event: MessageEvent) -> str:
|
||||
@@ -4921,8 +4913,8 @@ class GatewayRunner:
|
||||
cycle = ["off", "new", "all", "verbose"]
|
||||
descriptions = {
|
||||
"off": "⚙️ Tool progress: **OFF** — no tool activity shown.",
|
||||
"new": "⚙️ Tool progress: **NEW** — shown when tool changes (preview length: `display.tool_preview_length`, default 40).",
|
||||
"all": "⚙️ Tool progress: **ALL** — every tool call shown (preview length: `display.tool_preview_length`, default 40).",
|
||||
"new": "⚙️ Tool progress: **NEW** — shown when tool changes (short previews).",
|
||||
"all": "⚙️ Tool progress: **ALL** — every tool call shown (short previews).",
|
||||
"verbose": "⚙️ Tool progress: **VERBOSE** — every tool call with full arguments.",
|
||||
}
|
||||
|
||||
@@ -6044,11 +6036,6 @@ class GatewayRunner:
|
||||
|
||||
if enriched_parts:
|
||||
prefix = "\n\n".join(enriched_parts)
|
||||
# Strip the empty-content placeholder from the Discord adapter
|
||||
# when we successfully transcribed the audio — it's redundant.
|
||||
_placeholder = "(The user sent a message with no text content)"
|
||||
if user_text and user_text.strip() == _placeholder:
|
||||
return prefix
|
||||
if user_text:
|
||||
return f"{prefix}\n\n{user_text}"
|
||||
return prefix
|
||||
@@ -6340,15 +6327,10 @@ class GatewayRunner:
|
||||
progress_queue.put(msg)
|
||||
return
|
||||
|
||||
# "all" / "new" modes: short preview, respects tool_preview_length
|
||||
# config (defaults to 40 chars when unset to keep gateway messages
|
||||
# compact — unlike CLI spinners, these persist as permanent messages).
|
||||
# "all" / "new" modes: short preview, always truncated (40 chars)
|
||||
if preview:
|
||||
from agent.display import get_tool_preview_max_len
|
||||
_pl = get_tool_preview_max_len()
|
||||
_cap = _pl if _pl > 0 else 40
|
||||
if len(preview) > _cap:
|
||||
preview = preview[:_cap - 3] + "..."
|
||||
if len(preview) > 40:
|
||||
preview = preview[:37] + "..."
|
||||
msg = f"{emoji} {tool_name}: \"{preview}\""
|
||||
else:
|
||||
msg = f"{emoji} {tool_name}..."
|
||||
|
||||
+15
-4
@@ -37,7 +37,7 @@ from typing import Any, Dict, List, Optional
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
from hermes_cli.config import get_hermes_home, get_config_path, read_raw_config
|
||||
from hermes_cli.config import get_hermes_home, get_config_path
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -2214,7 +2214,14 @@ def _update_config_for_provider(
|
||||
config_path = get_config_path()
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config = read_raw_config()
|
||||
config: Dict[str, Any] = {}
|
||||
if config_path.exists():
|
||||
try:
|
||||
loaded = yaml.safe_load(config_path.read_text()) or {}
|
||||
if isinstance(loaded, dict):
|
||||
config = loaded
|
||||
except Exception:
|
||||
config = {}
|
||||
|
||||
current_model = config.get("model")
|
||||
if isinstance(current_model, dict):
|
||||
@@ -2251,8 +2258,12 @@ def _reset_config_provider() -> Path:
|
||||
if not config_path.exists():
|
||||
return config_path
|
||||
|
||||
config = read_raw_config()
|
||||
if not config:
|
||||
try:
|
||||
config = yaml.safe_load(config_path.read_text()) or {}
|
||||
except Exception:
|
||||
return config_path
|
||||
|
||||
if not isinstance(config, dict):
|
||||
return config_path
|
||||
|
||||
model = config.get("model")
|
||||
|
||||
+1
-75
@@ -5,7 +5,6 @@ Pure display functions with no HermesCLI state dependency.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
@@ -190,79 +189,6 @@ def check_for_updates() -> Optional[int]:
|
||||
return behind
|
||||
|
||||
|
||||
def _resolve_repo_dir() -> Optional[Path]:
|
||||
"""Return the active Hermes git checkout, or None if this isn't a git install."""
|
||||
hermes_home = get_hermes_home()
|
||||
repo_dir = hermes_home / "hermes-agent"
|
||||
if not (repo_dir / ".git").exists():
|
||||
repo_dir = Path(__file__).parent.parent.resolve()
|
||||
return repo_dir if (repo_dir / ".git").exists() else None
|
||||
|
||||
|
||||
def _git_short_hash(repo_dir: Path, rev: str) -> Optional[str]:
|
||||
"""Resolve a git revision to an 8-character short hash."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--short=8", rev],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
cwd=str(repo_dir),
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
value = (result.stdout or "").strip()
|
||||
return value or None
|
||||
|
||||
|
||||
def get_git_banner_state(repo_dir: Optional[Path] = None) -> Optional[dict]:
|
||||
"""Return upstream/local git hashes for the startup banner."""
|
||||
repo_dir = repo_dir or _resolve_repo_dir()
|
||||
if repo_dir is None:
|
||||
return None
|
||||
|
||||
upstream = _git_short_hash(repo_dir, "origin/main")
|
||||
local = _git_short_hash(repo_dir, "HEAD")
|
||||
if not upstream or not local:
|
||||
return None
|
||||
|
||||
ahead = 0
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-list", "--count", "origin/main..HEAD"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
cwd=str(repo_dir),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
ahead = int((result.stdout or "0").strip() or "0")
|
||||
except Exception:
|
||||
ahead = 0
|
||||
|
||||
return {"upstream": upstream, "local": local, "ahead": max(ahead, 0)}
|
||||
|
||||
|
||||
def format_banner_version_label() -> str:
|
||||
"""Return the version label shown in the startup banner title."""
|
||||
base = f"Hermes Agent v{VERSION} ({RELEASE_DATE})"
|
||||
state = get_git_banner_state()
|
||||
if not state:
|
||||
return base
|
||||
|
||||
upstream = state["upstream"]
|
||||
local = state["local"]
|
||||
ahead = int(state.get("ahead") or 0)
|
||||
|
||||
if ahead <= 0 or upstream == local:
|
||||
return f"{base} · upstream {upstream}"
|
||||
|
||||
carried_word = "commit" if ahead == 1 else "commits"
|
||||
return f"{base} · upstream {upstream} · local {local} (+{ahead} carried {carried_word})"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Non-blocking update check
|
||||
# =========================================================================
|
||||
@@ -522,7 +448,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
border_color = _skin_color("banner_border", "#CD7F32")
|
||||
outer_panel = Panel(
|
||||
layout_table,
|
||||
title=f"[bold {title_color}]{format_banner_version_label()}[/]",
|
||||
title=f"[bold {title_color}]{agent_name} v{VERSION} ({RELEASE_DATE})[/]",
|
||||
border_style=border_color,
|
||||
padding=(0, 2),
|
||||
)
|
||||
|
||||
@@ -293,8 +293,14 @@ def _resolve_config_gates() -> set[str]:
|
||||
if not gated:
|
||||
return set()
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
cfg = read_raw_config()
|
||||
import yaml
|
||||
from hermes_constants import get_hermes_home
|
||||
config_path = str(get_hermes_home() / "config.yaml")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
else:
|
||||
cfg = {}
|
||||
except Exception:
|
||||
return set()
|
||||
result: set[str] = set()
|
||||
|
||||
@@ -416,7 +416,6 @@ DEFAULT_CONFIG = {
|
||||
"provider": "local", # "local" (free, faster-whisper) | "groq" | "openai" (Whisper API)
|
||||
"local": {
|
||||
"model": "base", # tiny, base, small, medium, large-v3
|
||||
"language": "", # auto-detect by default; set to "en", "es", "fr", etc. to force
|
||||
},
|
||||
"openai": {
|
||||
"model": "whisper-1", # whisper-1, gpt-4o-mini-transcribe, gpt-4o-transcribe
|
||||
|
||||
+9
-51
@@ -267,34 +267,6 @@ def _profile_suffix() -> str:
|
||||
return hashlib.sha256(str(home).encode()).hexdigest()[:8]
|
||||
|
||||
|
||||
def _profile_arg(hermes_home: str | None = None) -> str:
|
||||
"""Return ``--profile <name>`` only when HERMES_HOME is a named profile.
|
||||
|
||||
For ``~/.hermes/profiles/<name>``, returns ``"--profile <name>"``.
|
||||
For the default profile or hash-based custom paths, returns the empty string.
|
||||
|
||||
Args:
|
||||
hermes_home: Optional explicit HERMES_HOME path. Defaults to the current
|
||||
``get_hermes_home()`` value. Should be passed when generating a
|
||||
service definition for a different user (e.g. system service).
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path as _Path
|
||||
home = Path(hermes_home or str(get_hermes_home())).resolve()
|
||||
default = (_Path.home() / ".hermes").resolve()
|
||||
if home == default:
|
||||
return ""
|
||||
profiles_root = (default / "profiles").resolve()
|
||||
try:
|
||||
rel = home.relative_to(profiles_root)
|
||||
parts = rel.parts
|
||||
if len(parts) == 1 and re.match(r"^[a-z0-9][a-z0-9_-]{0,63}$", parts[0]):
|
||||
return f"--profile {parts[0]}"
|
||||
except ValueError:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def get_service_name() -> str:
|
||||
"""Derive a systemd service name scoped to this HERMES_HOME.
|
||||
|
||||
@@ -654,7 +626,6 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
|
||||
if system:
|
||||
username, group_name, home_dir = _system_service_identity(run_as_user)
|
||||
hermes_home = _hermes_home_for_target_user(home_dir)
|
||||
profile_arg = _profile_arg(hermes_home)
|
||||
path_entries.extend(_build_user_local_paths(Path(home_dir), path_entries))
|
||||
path_entries.extend(common_bin_paths)
|
||||
sane_path = ":".join(path_entries)
|
||||
@@ -669,7 +640,7 @@ StartLimitBurst=5
|
||||
Type=simple
|
||||
User={username}
|
||||
Group={group_name}
|
||||
ExecStart={python_path} -m hermes_cli.main{f" {profile_arg}" if profile_arg else ""} gateway run --replace
|
||||
ExecStart={python_path} -m hermes_cli.main gateway run --replace
|
||||
WorkingDirectory={working_dir}
|
||||
Environment="HOME={home_dir}"
|
||||
Environment="USER={username}"
|
||||
@@ -690,7 +661,6 @@ WantedBy=multi-user.target
|
||||
"""
|
||||
|
||||
hermes_home = str(get_hermes_home().resolve())
|
||||
profile_arg = _profile_arg(hermes_home)
|
||||
path_entries.extend(_build_user_local_paths(Path.home(), path_entries))
|
||||
path_entries.extend(common_bin_paths)
|
||||
sane_path = ":".join(path_entries)
|
||||
@@ -702,7 +672,7 @@ StartLimitBurst=5
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart={python_path} -m hermes_cli.main{f" {profile_arg}" if profile_arg else ""} gateway run --replace
|
||||
ExecStart={python_path} -m hermes_cli.main gateway run --replace
|
||||
WorkingDirectory={working_dir}
|
||||
Environment="PATH={sane_path}"
|
||||
Environment="VIRTUAL_ENV={venv_dir}"
|
||||
@@ -995,7 +965,6 @@ def generate_launchd_plist() -> str:
|
||||
log_dir = get_hermes_home() / "logs"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
label = get_launchd_label()
|
||||
profile_arg = _profile_arg(hermes_home)
|
||||
# Build a sane PATH for the launchd plist. launchd provides only a
|
||||
# minimal default (/usr/bin:/bin:/usr/sbin:/sbin) which misses Homebrew,
|
||||
# nvm, cargo, etc. We prepend venv/bin and node_modules/.bin (matching
|
||||
@@ -1017,32 +986,21 @@ def generate_launchd_plist() -> str:
|
||||
dict.fromkeys(priority_dirs + [p for p in os.environ.get("PATH", "").split(":") if p])
|
||||
)
|
||||
|
||||
# Build ProgramArguments array, including --profile when using a named profile
|
||||
prog_args = [
|
||||
f"<string>{python_path}</string>",
|
||||
"<string>-m</string>",
|
||||
"<string>hermes_cli.main</string>",
|
||||
]
|
||||
if profile_arg:
|
||||
for part in profile_arg.split():
|
||||
prog_args.append(f"<string>{part}</string>")
|
||||
prog_args.extend([
|
||||
"<string>gateway</string>",
|
||||
"<string>run</string>",
|
||||
"<string>--replace</string>",
|
||||
])
|
||||
prog_args_xml = "\n ".join(prog_args)
|
||||
|
||||
return f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>{label}</string>
|
||||
|
||||
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
{prog_args_xml}
|
||||
<string>{python_path}</string>
|
||||
<string>-m</string>
|
||||
<string>hermes_cli.main</string>
|
||||
<string>gateway</string>
|
||||
<string>run</string>
|
||||
<string>--replace</string>
|
||||
</array>
|
||||
|
||||
<key>WorkingDirectory</key>
|
||||
|
||||
+3
-15
@@ -421,22 +421,10 @@ def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = default
|
||||
scroll_offset = 0
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Rows available for list items: rows 2..(max_y-2) inclusive.
|
||||
visible = max(1, max_y - 3)
|
||||
|
||||
# Scroll the viewport so the cursor is always visible.
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible:
|
||||
scroll_offset = cursor - visible + 1
|
||||
scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible)))
|
||||
|
||||
try:
|
||||
stdscr.addnstr(
|
||||
0,
|
||||
@@ -448,12 +436,12 @@ def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))):
|
||||
y = row + 2
|
||||
for i, choice in enumerate(choices):
|
||||
y = i + 2
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {choices[i]}"
|
||||
line = f" {arrow} {choice}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
|
||||
@@ -554,7 +554,6 @@ def _get_platform_tools(
|
||||
# MCP servers are expected to be available on all platforms by default.
|
||||
# If the platform explicitly lists one or more MCP server names, treat that
|
||||
# as an allowlist. Otherwise include every globally enabled MCP server.
|
||||
# Special sentinel: "no_mcp" in the toolset list disables all MCP servers.
|
||||
mcp_servers = config.get("mcp_servers") or {}
|
||||
enabled_mcp_servers = {
|
||||
name
|
||||
@@ -562,15 +561,10 @@ def _get_platform_tools(
|
||||
if isinstance(server_cfg, dict)
|
||||
and _parse_enabled_flag(server_cfg.get("enabled", True), default=True)
|
||||
}
|
||||
# Allow "no_mcp" sentinel to opt out of all MCP servers for this platform
|
||||
if "no_mcp" in toolset_names:
|
||||
explicit_mcp_servers = set()
|
||||
enabled_toolsets.update(explicit_passthrough - enabled_mcp_servers - {"no_mcp"})
|
||||
else:
|
||||
explicit_mcp_servers = explicit_passthrough & enabled_mcp_servers
|
||||
enabled_toolsets.update(explicit_passthrough - enabled_mcp_servers)
|
||||
explicit_mcp_servers = explicit_passthrough & enabled_mcp_servers
|
||||
enabled_toolsets.update(explicit_passthrough - enabled_mcp_servers)
|
||||
if include_default_mcp_servers:
|
||||
if explicit_mcp_servers or "no_mcp" in toolset_names:
|
||||
if explicit_mcp_servers:
|
||||
enabled_toolsets.update(explicit_mcp_servers)
|
||||
else:
|
||||
enabled_toolsets.update(enabled_mcp_servers)
|
||||
|
||||
@@ -17,7 +17,7 @@ Or manually:
|
||||
|
||||
```bash
|
||||
hermes config set memory.provider supermemory
|
||||
echo 'SUPERMEMORY_API_KEY=***' >> ~/.hermes/.env
|
||||
echo 'SUPERMEMORY_API_KEY=your-key-here' >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
## Config
|
||||
@@ -26,23 +26,15 @@ Config file: `$HERMES_HOME/supermemory.json`
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `container_tag` | `hermes` | Container tag used for search and writes. Supports `{identity}` template for profile-scoped tags (e.g. `hermes-{identity}` → `hermes-coder`). |
|
||||
| `container_tag` | `hermes` | Container tag used for search and writes |
|
||||
| `auto_recall` | `true` | Inject relevant memory context before turns |
|
||||
| `auto_capture` | `true` | Store cleaned user-assistant turns after each response |
|
||||
| `max_recall_results` | `10` | Max recalled items to format into context |
|
||||
| `profile_frequency` | `50` | Include profile facts on first turn and every N turns |
|
||||
| `capture_mode` | `all` | Skip tiny or trivial turns by default |
|
||||
| `search_mode` | `hybrid` | Search mode: `hybrid` (profile + memories), `memories` (memories only), `documents` (documents only) |
|
||||
| `entity_context` | built-in default | Extraction guidance passed to Supermemory |
|
||||
| `api_timeout` | `5.0` | Timeout for SDK and ingest requests |
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `SUPERMEMORY_API_KEY` | API key (required) |
|
||||
| `SUPERMEMORY_CONTAINER_TAG` | Override container tag (takes priority over config file) |
|
||||
|
||||
## Tools
|
||||
|
||||
| Tool | Description |
|
||||
@@ -60,40 +52,3 @@ When enabled, Hermes can:
|
||||
- store cleaned conversation turns after each completed response
|
||||
- ingest the full session on session end for richer graph updates
|
||||
- expose explicit tools for search, store, forget, and profile access
|
||||
|
||||
## Profile-Scoped Containers
|
||||
|
||||
Use `{identity}` in the `container_tag` to scope memories per Hermes profile:
|
||||
|
||||
```json
|
||||
{
|
||||
"container_tag": "hermes-{identity}"
|
||||
}
|
||||
```
|
||||
|
||||
For a profile named `coder`, this resolves to `hermes-coder`. The default profile resolves to `hermes-default`. Without `{identity}`, all profiles share the same container.
|
||||
|
||||
## Multi-Container Mode
|
||||
|
||||
For advanced setups (e.g. OpenClaw-style multi-workspace), you can enable custom container tags so the agent can read/write across multiple named containers:
|
||||
|
||||
```json
|
||||
{
|
||||
"container_tag": "hermes",
|
||||
"enable_custom_container_tags": true,
|
||||
"custom_containers": ["project-alpha", "project-beta", "shared-knowledge"],
|
||||
"custom_container_instructions": "Use project-alpha for coding tasks, project-beta for research, and shared-knowledge for team-wide facts."
|
||||
}
|
||||
```
|
||||
|
||||
When enabled:
|
||||
- `supermemory_search`, `supermemory_store`, `supermemory_forget`, and `supermemory_profile` accept an optional `container_tag` parameter
|
||||
- The tag must be in the whitelist: primary container + `custom_containers`
|
||||
- Automatic operations (turn sync, prefetch, memory write mirroring, session ingest) always use the **primary** container only
|
||||
- Custom container instructions are injected into the system prompt
|
||||
|
||||
## Support
|
||||
|
||||
- [Supermemory Discord](https://supermemory.link/discord)
|
||||
- [support@supermemory.com](mailto:support@supermemory.com)
|
||||
- [supermemory.ai](https://supermemory.ai)
|
||||
|
||||
@@ -26,8 +26,6 @@ _DEFAULT_CONTAINER_TAG = "hermes"
|
||||
_DEFAULT_MAX_RECALL_RESULTS = 10
|
||||
_DEFAULT_PROFILE_FREQUENCY = 50
|
||||
_DEFAULT_CAPTURE_MODE = "all"
|
||||
_DEFAULT_SEARCH_MODE = "hybrid"
|
||||
_VALID_SEARCH_MODES = ("hybrid", "memories", "documents")
|
||||
_DEFAULT_API_TIMEOUT = 5.0
|
||||
_MIN_CAPTURE_LENGTH = 10
|
||||
_MAX_ENTITY_CONTEXT_LENGTH = 1500
|
||||
@@ -61,12 +59,8 @@ def _default_config() -> dict:
|
||||
"max_recall_results": _DEFAULT_MAX_RECALL_RESULTS,
|
||||
"profile_frequency": _DEFAULT_PROFILE_FREQUENCY,
|
||||
"capture_mode": _DEFAULT_CAPTURE_MODE,
|
||||
"search_mode": _DEFAULT_SEARCH_MODE,
|
||||
"entity_context": _DEFAULT_ENTITY_CONTEXT,
|
||||
"api_timeout": _DEFAULT_API_TIMEOUT,
|
||||
"enable_custom_container_tags": False,
|
||||
"custom_containers": [],
|
||||
"custom_container_instructions": "",
|
||||
}
|
||||
|
||||
|
||||
@@ -106,10 +100,7 @@ def _load_supermemory_config(hermes_home: str) -> dict:
|
||||
except Exception:
|
||||
logger.debug("Failed to parse %s", config_path, exc_info=True)
|
||||
|
||||
# Keep raw container_tag — template variables like {identity} are resolved
|
||||
# in initialize(), and _sanitize_tag runs AFTER resolution.
|
||||
raw_tag = str(config.get("container_tag", _DEFAULT_CONTAINER_TAG)).strip()
|
||||
config["container_tag"] = raw_tag if raw_tag else _DEFAULT_CONTAINER_TAG
|
||||
config["container_tag"] = _sanitize_tag(str(config.get("container_tag", _DEFAULT_CONTAINER_TAG)))
|
||||
config["auto_recall"] = _as_bool(config.get("auto_recall"), True)
|
||||
config["auto_capture"] = _as_bool(config.get("auto_capture"), True)
|
||||
try:
|
||||
@@ -121,23 +112,11 @@ def _load_supermemory_config(hermes_home: str) -> dict:
|
||||
except Exception:
|
||||
config["profile_frequency"] = _DEFAULT_PROFILE_FREQUENCY
|
||||
config["capture_mode"] = "everything" if config.get("capture_mode") == "everything" else "all"
|
||||
raw_search_mode = str(config.get("search_mode", _DEFAULT_SEARCH_MODE)).strip().lower()
|
||||
config["search_mode"] = raw_search_mode if raw_search_mode in _VALID_SEARCH_MODES else _DEFAULT_SEARCH_MODE
|
||||
config["entity_context"] = _clamp_entity_context(str(config.get("entity_context", _DEFAULT_ENTITY_CONTEXT)))
|
||||
try:
|
||||
config["api_timeout"] = max(0.5, min(15.0, float(config.get("api_timeout", _DEFAULT_API_TIMEOUT))))
|
||||
except Exception:
|
||||
config["api_timeout"] = _DEFAULT_API_TIMEOUT
|
||||
|
||||
# Multi-container support
|
||||
config["enable_custom_container_tags"] = _as_bool(config.get("enable_custom_container_tags"), False)
|
||||
raw_containers = config.get("custom_containers", [])
|
||||
if isinstance(raw_containers, list):
|
||||
config["custom_containers"] = [_sanitize_tag(str(t)) for t in raw_containers if t]
|
||||
else:
|
||||
config["custom_containers"] = []
|
||||
config["custom_container_instructions"] = str(config.get("custom_container_instructions", "")).strip()
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@@ -261,41 +240,28 @@ def _is_trivial_message(text: str) -> bool:
|
||||
|
||||
|
||||
class _SupermemoryClient:
|
||||
def __init__(self, api_key: str, timeout: float, container_tag: str, search_mode: str = "hybrid"):
|
||||
def __init__(self, api_key: str, timeout: float, container_tag: str):
|
||||
from supermemory import Supermemory
|
||||
|
||||
self._api_key = api_key
|
||||
self._container_tag = container_tag
|
||||
self._search_mode = search_mode if search_mode in _VALID_SEARCH_MODES else _DEFAULT_SEARCH_MODE
|
||||
self._timeout = timeout
|
||||
self._client = Supermemory(api_key=api_key, timeout=timeout, max_retries=0)
|
||||
|
||||
def add_memory(self, content: str, metadata: Optional[dict] = None, *,
|
||||
entity_context: str = "", container_tag: Optional[str] = None,
|
||||
custom_id: Optional[str] = None) -> dict:
|
||||
tag = container_tag or self._container_tag
|
||||
kwargs: dict[str, Any] = {
|
||||
def add_memory(self, content: str, metadata: Optional[dict] = None, *, entity_context: str = "") -> dict:
|
||||
kwargs = {
|
||||
"content": content.strip(),
|
||||
"container_tags": [tag],
|
||||
"container_tags": [self._container_tag],
|
||||
}
|
||||
if metadata:
|
||||
kwargs["metadata"] = metadata
|
||||
if entity_context:
|
||||
kwargs["entity_context"] = _clamp_entity_context(entity_context)
|
||||
if custom_id:
|
||||
kwargs["custom_id"] = custom_id
|
||||
result = self._client.documents.add(**kwargs)
|
||||
return {"id": getattr(result, "id", "")}
|
||||
|
||||
def search_memories(self, query: str, *, limit: int = 5,
|
||||
container_tag: Optional[str] = None,
|
||||
search_mode: Optional[str] = None) -> list[dict]:
|
||||
tag = container_tag or self._container_tag
|
||||
mode = search_mode or self._search_mode
|
||||
kwargs: dict[str, Any] = {"q": query, "container_tag": tag, "limit": limit}
|
||||
if mode in _VALID_SEARCH_MODES:
|
||||
kwargs["search_mode"] = mode
|
||||
response = self._client.search.memories(**kwargs)
|
||||
def search_memories(self, query: str, *, limit: int = 5) -> list[dict]:
|
||||
response = self._client.search.memories(q=query, container_tag=self._container_tag, limit=limit)
|
||||
results = []
|
||||
for item in (getattr(response, "results", None) or []):
|
||||
results.append({
|
||||
@@ -307,10 +273,8 @@ class _SupermemoryClient:
|
||||
})
|
||||
return results
|
||||
|
||||
def get_profile(self, query: Optional[str] = None, *,
|
||||
container_tag: Optional[str] = None) -> dict:
|
||||
tag = container_tag or self._container_tag
|
||||
kwargs: dict[str, Any] = {"container_tag": tag}
|
||||
def get_profile(self, query: Optional[str] = None) -> dict:
|
||||
kwargs = {"container_tag": self._container_tag}
|
||||
if query:
|
||||
kwargs["q"] = query
|
||||
response = self._client.profile(**kwargs)
|
||||
@@ -332,19 +296,18 @@ class _SupermemoryClient:
|
||||
})
|
||||
return {"static": static, "dynamic": dynamic, "search_results": search_results}
|
||||
|
||||
def forget_memory(self, memory_id: str, *, container_tag: Optional[str] = None) -> None:
|
||||
tag = container_tag or self._container_tag
|
||||
self._client.memories.forget(container_tag=tag, id=memory_id)
|
||||
def forget_memory(self, memory_id: str) -> None:
|
||||
self._client.memories.forget(container_tag=self._container_tag, id=memory_id)
|
||||
|
||||
def forget_by_query(self, query: str, *, container_tag: Optional[str] = None) -> dict:
|
||||
results = self.search_memories(query, limit=5, container_tag=container_tag)
|
||||
def forget_by_query(self, query: str) -> dict:
|
||||
results = self.search_memories(query, limit=5)
|
||||
if not results:
|
||||
return {"success": False, "message": "No matching memory found to forget."}
|
||||
target = results[0]
|
||||
memory_id = target.get("id", "")
|
||||
if not memory_id:
|
||||
return {"success": False, "message": "Best matching memory has no id."}
|
||||
self.forget_memory(memory_id, container_tag=container_tag)
|
||||
self.forget_memory(memory_id)
|
||||
preview = (target.get("memory") or "")[:100]
|
||||
return {"success": True, "message": f'Forgot: "{preview}"', "id": memory_id}
|
||||
|
||||
@@ -435,17 +398,11 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
self._max_recall_results = _DEFAULT_MAX_RECALL_RESULTS
|
||||
self._profile_frequency = _DEFAULT_PROFILE_FREQUENCY
|
||||
self._capture_mode = _DEFAULT_CAPTURE_MODE
|
||||
self._search_mode = _DEFAULT_SEARCH_MODE
|
||||
self._entity_context = _DEFAULT_ENTITY_CONTEXT
|
||||
self._api_timeout = _DEFAULT_API_TIMEOUT
|
||||
self._hermes_home = ""
|
||||
self._write_enabled = True
|
||||
self._active = False
|
||||
# Multi-container support
|
||||
self._enable_custom_containers = False
|
||||
self._custom_containers: List[str] = []
|
||||
self._custom_container_instructions = ""
|
||||
self._allowed_containers: List[str] = []
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -462,11 +419,16 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
return False
|
||||
|
||||
def get_config_schema(self):
|
||||
# Only prompt for the API key during `hermes memory setup`.
|
||||
# All other options are documented for $HERMES_HOME/supermemory.json
|
||||
# or the SUPERMEMORY_CONTAINER_TAG env var.
|
||||
return [
|
||||
{"key": "api_key", "description": "Supermemory API key", "secret": True, "required": True, "env_var": "SUPERMEMORY_API_KEY", "url": "https://supermemory.ai"},
|
||||
{"key": "container_tag", "description": "Container tag for reads and writes", "default": _DEFAULT_CONTAINER_TAG},
|
||||
{"key": "auto_recall", "description": "Enable automatic recall before each turn", "default": "true", "choices": ["true", "false"]},
|
||||
{"key": "auto_capture", "description": "Enable automatic capture after each completed turn", "default": "true", "choices": ["true", "false"]},
|
||||
{"key": "max_recall_results", "description": "Maximum recalled items to inject", "default": str(_DEFAULT_MAX_RECALL_RESULTS)},
|
||||
{"key": "profile_frequency", "description": "Include profile facts on first turn and every N turns", "default": str(_DEFAULT_PROFILE_FREQUENCY)},
|
||||
{"key": "capture_mode", "description": "Capture mode", "default": _DEFAULT_CAPTURE_MODE, "choices": ["all", "everything"]},
|
||||
{"key": "entity_context", "description": "Extraction guidance passed to Supermemory", "default": _DEFAULT_ENTITY_CONTEXT},
|
||||
{"key": "api_timeout", "description": "Timeout in seconds for SDK and ingest calls", "default": str(_DEFAULT_API_TIMEOUT)},
|
||||
]
|
||||
|
||||
def save_config(self, values, hermes_home):
|
||||
@@ -484,29 +446,14 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
self._turn_count = 0
|
||||
self._config = _load_supermemory_config(self._hermes_home)
|
||||
self._api_key = os.environ.get("SUPERMEMORY_API_KEY", "")
|
||||
|
||||
# Resolve container tag: env var > config > default.
|
||||
# Supports {identity} template for profile-scoped containers.
|
||||
env_tag = os.environ.get("SUPERMEMORY_CONTAINER_TAG", "").strip()
|
||||
raw_tag = env_tag or self._config["container_tag"]
|
||||
identity = kwargs.get("agent_identity", "default")
|
||||
self._container_tag = _sanitize_tag(raw_tag.replace("{identity}", identity))
|
||||
|
||||
self._container_tag = self._config["container_tag"]
|
||||
self._auto_recall = self._config["auto_recall"]
|
||||
self._auto_capture = self._config["auto_capture"]
|
||||
self._max_recall_results = self._config["max_recall_results"]
|
||||
self._profile_frequency = self._config["profile_frequency"]
|
||||
self._capture_mode = self._config["capture_mode"]
|
||||
self._search_mode = self._config["search_mode"]
|
||||
self._entity_context = self._config["entity_context"]
|
||||
self._api_timeout = self._config["api_timeout"]
|
||||
|
||||
# Multi-container setup
|
||||
self._enable_custom_containers = self._config["enable_custom_container_tags"]
|
||||
self._custom_containers = self._config["custom_containers"]
|
||||
self._custom_container_instructions = self._config["custom_container_instructions"]
|
||||
self._allowed_containers = [self._container_tag] + list(self._custom_containers)
|
||||
|
||||
agent_context = kwargs.get("agent_context", "")
|
||||
self._write_enabled = agent_context not in ("cron", "flush", "subagent")
|
||||
self._active = bool(self._api_key)
|
||||
@@ -517,7 +464,6 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
api_key=self._api_key,
|
||||
timeout=self._api_timeout,
|
||||
container_tag=self._container_tag,
|
||||
search_mode=self._search_mode,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Supermemory initialization failed", exc_info=True)
|
||||
@@ -530,18 +476,11 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
def system_prompt_block(self) -> str:
|
||||
if not self._active:
|
||||
return ""
|
||||
lines = [
|
||||
"# Supermemory",
|
||||
f"Active. Container: {self._container_tag}.",
|
||||
"Use supermemory_search, supermemory_store, supermemory_forget, and supermemory_profile for explicit memory operations.",
|
||||
]
|
||||
if self._enable_custom_containers and self._custom_containers:
|
||||
tags_str = ", ".join(self._allowed_containers)
|
||||
lines.append(f"\nMulti-container mode enabled. Available containers: {tags_str}.")
|
||||
lines.append("Pass an optional container_tag to supermemory_search, supermemory_store, supermemory_forget, and supermemory_profile to target a specific container.")
|
||||
if self._custom_container_instructions:
|
||||
lines.append(f"\n{self._custom_container_instructions}")
|
||||
return "\n".join(lines)
|
||||
return (
|
||||
"# Supermemory\n"
|
||||
f"Active. Container: {self._container_tag}.\n"
|
||||
"Use supermemory_search, supermemory_store, supermemory_forget, and supermemory_profile for explicit memory operations."
|
||||
)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
if not self._active or not self._auto_recall or not self._client or not query.strip():
|
||||
@@ -643,62 +582,22 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
thread.join(timeout=5.0)
|
||||
setattr(self, attr_name, None)
|
||||
|
||||
def _resolve_tool_container_tag(self, args: dict) -> Optional[str]:
|
||||
"""Validate and resolve container_tag from tool call args.
|
||||
|
||||
Returns None (use primary) if multi-container is disabled or no tag provided.
|
||||
Returns the validated tag if it's in the allowed list.
|
||||
Raises ValueError if the tag is not whitelisted.
|
||||
"""
|
||||
if not self._enable_custom_containers:
|
||||
return None
|
||||
tag = str(args.get("container_tag") or "").strip()
|
||||
if not tag:
|
||||
return None
|
||||
sanitized = _sanitize_tag(tag)
|
||||
if sanitized not in self._allowed_containers:
|
||||
raise ValueError(
|
||||
f"Container tag '{sanitized}' is not allowed. "
|
||||
f"Allowed: {', '.join(self._allowed_containers)}"
|
||||
)
|
||||
return sanitized
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
if not self._enable_custom_containers:
|
||||
return [STORE_SCHEMA, SEARCH_SCHEMA, FORGET_SCHEMA, PROFILE_SCHEMA]
|
||||
|
||||
# When multi-container is enabled, add optional container_tag to relevant tools
|
||||
container_param = {
|
||||
"type": "string",
|
||||
"description": f"Optional container tag. Allowed: {', '.join(self._allowed_containers)}. Defaults to primary ({self._container_tag}).",
|
||||
}
|
||||
schemas = []
|
||||
for base in [STORE_SCHEMA, SEARCH_SCHEMA, FORGET_SCHEMA, PROFILE_SCHEMA]:
|
||||
schema = json.loads(json.dumps(base)) # deep copy
|
||||
schema["parameters"]["properties"]["container_tag"] = container_param
|
||||
schemas.append(schema)
|
||||
return schemas
|
||||
return [STORE_SCHEMA, SEARCH_SCHEMA, FORGET_SCHEMA, PROFILE_SCHEMA]
|
||||
|
||||
def _tool_store(self, args: dict) -> str:
|
||||
content = str(args.get("content") or "").strip()
|
||||
if not content:
|
||||
return tool_error("content is required")
|
||||
try:
|
||||
tag = self._resolve_tool_container_tag(args)
|
||||
except ValueError as exc:
|
||||
return tool_error(str(exc))
|
||||
metadata = args.get("metadata") or {}
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
metadata.setdefault("type", _detect_category(content))
|
||||
metadata["source"] = "hermes_tool"
|
||||
try:
|
||||
result = self._client.add_memory(content, metadata=metadata, entity_context=self._entity_context, container_tag=tag)
|
||||
result = self._client.add_memory(content, metadata=metadata, entity_context=self._entity_context)
|
||||
preview = content[:80] + ("..." if len(content) > 80 else "")
|
||||
resp: dict[str, Any] = {"saved": True, "id": result.get("id", ""), "preview": preview}
|
||||
if tag:
|
||||
resp["container_tag"] = tag
|
||||
return json.dumps(resp)
|
||||
return json.dumps({"saved": True, "id": result.get("id", ""), "preview": preview})
|
||||
except Exception as exc:
|
||||
return tool_error(f"Failed to store memory: {exc}")
|
||||
|
||||
@@ -706,29 +605,22 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
query = str(args.get("query") or "").strip()
|
||||
if not query:
|
||||
return tool_error("query is required")
|
||||
try:
|
||||
tag = self._resolve_tool_container_tag(args)
|
||||
except ValueError as exc:
|
||||
return tool_error(str(exc))
|
||||
try:
|
||||
limit = max(1, min(20, int(args.get("limit", 5) or 5)))
|
||||
except Exception:
|
||||
limit = 5
|
||||
try:
|
||||
results = self._client.search_memories(query, limit=limit, container_tag=tag)
|
||||
results = self._client.search_memories(query, limit=limit)
|
||||
formatted = []
|
||||
for item in results:
|
||||
entry: dict[str, Any] = {"id": item.get("id", ""), "content": item.get("memory", "")}
|
||||
entry = {"id": item.get("id", ""), "content": item.get("memory", "")}
|
||||
if item.get("similarity") is not None:
|
||||
try:
|
||||
entry["similarity"] = round(float(item["similarity"]) * 100)
|
||||
except Exception:
|
||||
pass
|
||||
formatted.append(entry)
|
||||
resp: dict[str, Any] = {"results": formatted, "count": len(formatted)}
|
||||
if tag:
|
||||
resp["container_tag"] = tag
|
||||
return json.dumps(resp)
|
||||
return json.dumps({"results": formatted, "count": len(formatted)})
|
||||
except Exception as exc:
|
||||
return tool_error(f"Search failed: {exc}")
|
||||
|
||||
@@ -737,39 +629,28 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
query = str(args.get("query") or "").strip()
|
||||
if not memory_id and not query:
|
||||
return tool_error("Provide either id or query")
|
||||
try:
|
||||
tag = self._resolve_tool_container_tag(args)
|
||||
except ValueError as exc:
|
||||
return tool_error(str(exc))
|
||||
try:
|
||||
if memory_id:
|
||||
self._client.forget_memory(memory_id, container_tag=tag)
|
||||
self._client.forget_memory(memory_id)
|
||||
return json.dumps({"forgotten": True, "id": memory_id})
|
||||
return json.dumps(self._client.forget_by_query(query, container_tag=tag))
|
||||
return json.dumps(self._client.forget_by_query(query))
|
||||
except Exception as exc:
|
||||
return tool_error(f"Forget failed: {exc}")
|
||||
|
||||
def _tool_profile(self, args: dict) -> str:
|
||||
query = str(args.get("query") or "").strip() or None
|
||||
try:
|
||||
tag = self._resolve_tool_container_tag(args)
|
||||
except ValueError as exc:
|
||||
return tool_error(str(exc))
|
||||
try:
|
||||
profile = self._client.get_profile(query=query, container_tag=tag)
|
||||
profile = self._client.get_profile(query=query)
|
||||
sections = []
|
||||
if profile["static"]:
|
||||
sections.append("## User Profile (Persistent)\n" + "\n".join(f"- {item}" for item in profile["static"]))
|
||||
if profile["dynamic"]:
|
||||
sections.append("## Recent Context\n" + "\n".join(f"- {item}" for item in profile["dynamic"]))
|
||||
resp: dict[str, Any] = {
|
||||
return json.dumps({
|
||||
"profile": "\n\n".join(sections),
|
||||
"static_count": len(profile["static"]),
|
||||
"dynamic_count": len(profile["dynamic"]),
|
||||
}
|
||||
if tag:
|
||||
resp["container_tag"] = tag
|
||||
return json.dumps(resp)
|
||||
})
|
||||
except Exception as exc:
|
||||
return tool_error(f"Profile failed: {exc}")
|
||||
|
||||
|
||||
+78
-49
@@ -66,8 +66,7 @@ from model_tools import (
|
||||
handle_function_call,
|
||||
check_toolset_requirements,
|
||||
)
|
||||
from tools.terminal_tool import cleanup_vm, get_active_env
|
||||
from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
from tools.interrupt import set_interrupt as _set_interrupt
|
||||
from tools.browser_tool import cleanup_browser
|
||||
|
||||
@@ -410,6 +409,63 @@ def _strip_budget_warnings_from_history(messages: list) -> None:
|
||||
# Large tool result handler — save oversized output to temp file
|
||||
# =========================================================================
|
||||
|
||||
# Threshold at which tool results are saved to a file instead of kept inline.
|
||||
# 100K chars ≈ 25K tokens — generous for any reasonable output but prevents
|
||||
# catastrophic context explosions.
|
||||
_LARGE_RESULT_CHARS = 100_000
|
||||
|
||||
# How many characters of the original result to include as an inline preview
|
||||
# so the model has immediate context about what the tool returned.
|
||||
_LARGE_RESULT_PREVIEW_CHARS = 1_500
|
||||
|
||||
|
||||
def _save_oversized_tool_result(function_name: str, function_result: str) -> str:
|
||||
"""Replace oversized tool results with a file reference + preview.
|
||||
|
||||
When a tool returns more than ``_LARGE_RESULT_CHARS`` characters, the full
|
||||
content is written to a temporary file under ``HERMES_HOME/cache/tool_responses/``
|
||||
and the result sent to the model is replaced with:
|
||||
• a brief head preview (first ``_LARGE_RESULT_PREVIEW_CHARS`` chars)
|
||||
• the file path so the model can use ``read_file`` / ``search_files``
|
||||
|
||||
Falls back to destructive truncation if the file write fails.
|
||||
"""
|
||||
original_len = len(function_result)
|
||||
if original_len <= _LARGE_RESULT_CHARS:
|
||||
return function_result
|
||||
|
||||
# Build the target directory
|
||||
try:
|
||||
response_dir = os.path.join(get_hermes_home(), "cache", "tool_responses")
|
||||
os.makedirs(response_dir, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
# Sanitize tool name for use in filename
|
||||
safe_name = re.sub(r"[^\w\-]", "_", function_name)[:40]
|
||||
filename = f"{safe_name}_{timestamp}.txt"
|
||||
filepath = os.path.join(response_dir, filename)
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write(function_result)
|
||||
|
||||
preview = function_result[:_LARGE_RESULT_PREVIEW_CHARS]
|
||||
return (
|
||||
f"{preview}\n\n"
|
||||
f"[Large tool response: {original_len:,} characters total — "
|
||||
f"only the first {_LARGE_RESULT_PREVIEW_CHARS:,} shown above. "
|
||||
f"Full output saved to: {filepath}\n"
|
||||
f"Use read_file or search_files on that path to access the rest.]"
|
||||
)
|
||||
except Exception as exc:
|
||||
# Fall back to destructive truncation if file write fails
|
||||
logger.warning("Failed to save large tool result to file: %s", exc)
|
||||
return (
|
||||
function_result[:_LARGE_RESULT_CHARS]
|
||||
+ f"\n\n[Truncated: tool response was {original_len:,} chars, "
|
||||
f"exceeding the {_LARGE_RESULT_CHARS:,} char limit. "
|
||||
f"File save failed: {exc}]"
|
||||
)
|
||||
|
||||
|
||||
class AIAgent:
|
||||
"""
|
||||
@@ -6168,17 +6224,15 @@ class AIAgent:
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool complete callback error: {cb_err}")
|
||||
|
||||
function_result = maybe_persist_tool_result(
|
||||
content=function_result,
|
||||
tool_name=name,
|
||||
tool_use_id=tc.id,
|
||||
env=get_active_env(effective_task_id),
|
||||
)
|
||||
# Save oversized results to file instead of destructive truncation
|
||||
function_result = _save_oversized_tool_result(name, function_result)
|
||||
|
||||
# Discover subdirectory context files from tool arguments
|
||||
subdir_hints = self._subdirectory_hints.check_tool_call(name, args)
|
||||
if subdir_hints:
|
||||
function_result += subdir_hints
|
||||
|
||||
# Append tool result message in order
|
||||
tool_msg = {
|
||||
"role": "tool",
|
||||
"content": function_result,
|
||||
@@ -6186,12 +6240,6 @@ class AIAgent:
|
||||
}
|
||||
messages.append(tool_msg)
|
||||
|
||||
# ── Per-turn aggregate budget enforcement ─────────────────────────
|
||||
num_tools = len(parsed_calls)
|
||||
if num_tools > 0:
|
||||
turn_tool_msgs = messages[-num_tools:]
|
||||
enforce_turn_budget(turn_tool_msgs, env=get_active_env(effective_task_id))
|
||||
|
||||
# ── Budget pressure injection ────────────────────────────────────
|
||||
budget_warning = self._get_budget_warning(api_call_count)
|
||||
if budget_warning and messages and messages[-1].get("role") == "tool":
|
||||
@@ -6476,12 +6524,8 @@ class AIAgent:
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool complete callback error: {cb_err}")
|
||||
|
||||
function_result = maybe_persist_tool_result(
|
||||
content=function_result,
|
||||
tool_name=function_name,
|
||||
tool_use_id=tool_call.id,
|
||||
env=get_active_env(effective_task_id),
|
||||
)
|
||||
# Save oversized results to file instead of destructive truncation
|
||||
function_result = _save_oversized_tool_result(function_name, function_result)
|
||||
|
||||
# Discover subdirectory context files from tool arguments
|
||||
subdir_hints = self._subdirectory_hints.check_tool_call(function_name, function_args)
|
||||
@@ -6519,11 +6563,6 @@ class AIAgent:
|
||||
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
||||
time.sleep(self.tool_delay)
|
||||
|
||||
# ── Per-turn aggregate budget enforcement ─────────────────────────
|
||||
num_tools_seq = len(assistant_message.tool_calls)
|
||||
if num_tools_seq > 0:
|
||||
enforce_turn_budget(messages[-num_tools_seq:], env=get_active_env(effective_task_id))
|
||||
|
||||
# ── Budget pressure injection ─────────────────────────────────
|
||||
# After all tool calls in this turn are processed, check if we're
|
||||
# approaching max_iterations. If so, inject a warning into the LAST
|
||||
@@ -7352,30 +7391,20 @@ class AIAgent:
|
||||
response_invalid = True
|
||||
error_details.append("response.output is not a list")
|
||||
elif not output_items:
|
||||
# Stream backfill may have failed, but
|
||||
# _normalize_codex_response can still recover
|
||||
# from response.output_text. Only mark invalid
|
||||
# when that fallback is also absent.
|
||||
_out_text = getattr(response, "output_text", None)
|
||||
_out_text_stripped = _out_text.strip() if isinstance(_out_text, str) else ""
|
||||
if _out_text_stripped:
|
||||
logger.debug(
|
||||
"Codex response.output is empty but output_text is present "
|
||||
"(%d chars); deferring to normalization.",
|
||||
len(_out_text_stripped),
|
||||
)
|
||||
else:
|
||||
_resp_status = getattr(response, "status", None)
|
||||
_resp_incomplete = getattr(response, "incomplete_details", None)
|
||||
logger.warning(
|
||||
"Codex response.output is empty after stream backfill "
|
||||
"(status=%s, incomplete_details=%s, model=%s). %s",
|
||||
_resp_status, _resp_incomplete,
|
||||
getattr(response, "model", None),
|
||||
f"api_mode={self.api_mode} provider={self.provider}",
|
||||
)
|
||||
response_invalid = True
|
||||
error_details.append("response.output is empty")
|
||||
# If we reach here, _run_codex_stream's backfill
|
||||
# from output_item.done events and text-delta
|
||||
# synthesis both failed to populate output.
|
||||
_resp_status = getattr(response, "status", None)
|
||||
_resp_incomplete = getattr(response, "incomplete_details", None)
|
||||
logging.warning(
|
||||
"Codex response.output is empty after stream backfill "
|
||||
"(status=%s, incomplete_details=%s, model=%s). %s",
|
||||
_resp_status, _resp_incomplete,
|
||||
getattr(response, "model", None),
|
||||
f"api_mode={self.api_mode} provider={self.provider}",
|
||||
)
|
||||
response_invalid = True
|
||||
error_details.append("response.output is empty")
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
content_blocks = getattr(response, "content", None) if response is not None else None
|
||||
if response is None:
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
"""Tests for named custom provider and 'main' alias resolution in auxiliary_client."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate(tmp_path, monkeypatch):
|
||||
"""Redirect HERMES_HOME and clear module caches."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
# Write a minimal config so load_config doesn't fail
|
||||
(hermes_home / "config.yaml").write_text("model:\n default: test-model\n")
|
||||
|
||||
|
||||
def _write_config(tmp_path, config_dict):
|
||||
"""Write a config.yaml to the test HERMES_HOME."""
|
||||
import yaml
|
||||
config_path = tmp_path / ".hermes" / "config.yaml"
|
||||
config_path.write_text(yaml.dump(config_dict))
|
||||
|
||||
|
||||
class TestNormalizeVisionProvider:
|
||||
"""_normalize_vision_provider should resolve 'main' to actual main provider."""
|
||||
|
||||
def test_main_resolves_to_named_custom(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "my-model", "provider": "custom:beans"},
|
||||
"custom_providers": [{"name": "beans", "base_url": "http://localhost/v1"}],
|
||||
})
|
||||
from agent.auxiliary_client import _normalize_vision_provider
|
||||
assert _normalize_vision_provider("main") == "custom:beans"
|
||||
|
||||
def test_main_resolves_to_openrouter(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "anthropic/claude-sonnet-4", "provider": "openrouter"},
|
||||
})
|
||||
from agent.auxiliary_client import _normalize_vision_provider
|
||||
assert _normalize_vision_provider("main") == "openrouter"
|
||||
|
||||
def test_main_resolves_to_deepseek(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "deepseek-chat", "provider": "deepseek"},
|
||||
})
|
||||
from agent.auxiliary_client import _normalize_vision_provider
|
||||
assert _normalize_vision_provider("main") == "deepseek"
|
||||
|
||||
def test_main_falls_back_to_custom_when_no_provider(self, tmp_path):
|
||||
_write_config(tmp_path, {"model": {"default": "gpt-4o"}})
|
||||
from agent.auxiliary_client import _normalize_vision_provider
|
||||
assert _normalize_vision_provider("main") == "custom"
|
||||
|
||||
def test_bare_provider_name_unchanged(self):
|
||||
from agent.auxiliary_client import _normalize_vision_provider
|
||||
assert _normalize_vision_provider("beans") == "beans"
|
||||
assert _normalize_vision_provider("deepseek") == "deepseek"
|
||||
|
||||
def test_codex_alias_still_works(self):
|
||||
from agent.auxiliary_client import _normalize_vision_provider
|
||||
assert _normalize_vision_provider("codex") == "openai-codex"
|
||||
|
||||
def test_auto_unchanged(self):
|
||||
from agent.auxiliary_client import _normalize_vision_provider
|
||||
assert _normalize_vision_provider("auto") == "auto"
|
||||
assert _normalize_vision_provider(None) == "auto"
|
||||
|
||||
|
||||
class TestResolveProviderClientMainAlias:
|
||||
"""resolve_provider_client('main', ...) should resolve to actual main provider."""
|
||||
|
||||
def test_main_resolves_to_named_custom_provider(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "my-model", "provider": "beans"},
|
||||
"custom_providers": [
|
||||
{"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"},
|
||||
],
|
||||
})
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
client, model = resolve_provider_client("main", "override-model")
|
||||
assert client is not None
|
||||
assert model == "override-model"
|
||||
assert "beans.local" in str(client.base_url)
|
||||
|
||||
def test_main_with_custom_colon_prefix(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "my-model", "provider": "custom:beans"},
|
||||
"custom_providers": [
|
||||
{"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"},
|
||||
],
|
||||
})
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
client, model = resolve_provider_client("main", "test")
|
||||
assert client is not None
|
||||
assert "beans.local" in str(client.base_url)
|
||||
|
||||
|
||||
class TestResolveProviderClientNamedCustom:
    """resolve_provider_client should resolve named custom providers directly."""

    def test_named_custom_provider(self, tmp_path):
        """A provider name listed in custom_providers yields a client for its base_url."""
        _write_config(tmp_path, {
            "model": {"default": "test-model"},
            "custom_providers": [
                {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"},
            ],
        })
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("beans", "my-model")

        assert client is not None
        assert model == "my-model"
        assert "beans.local" in str(client.base_url)

    def test_named_custom_provider_default_model(self, tmp_path):
        """Without an explicit model the configured main model is used."""
        _write_config(tmp_path, {
            "model": {"default": "main-model"},
            "custom_providers": [
                {"name": "beans", "base_url": "http://beans.local/v1", "api_key": "k"},
            ],
        })
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("beans")

        assert client is not None
        # Should use _read_main_model() fallback
        assert model == "main-model"

    def test_named_custom_no_api_key_uses_fallback(self, tmp_path):
        """An entry without api_key still produces a client, using the placeholder key."""
        _write_config(tmp_path, {
            "model": {"default": "test"},
            "custom_providers": [
                {"name": "local", "base_url": "http://localhost:8080/v1"},
            ],
        })
        from agent.auxiliary_client import resolve_provider_client

        client, model = resolve_provider_client("local", "test")

        assert client is not None
        # The resolver substitutes "no-key-required" when the entry has no key;
        # assert it instead of only stating it in a comment.
        assert client.api_key == "no-key-required"

    def test_nonexistent_named_custom_falls_through(self, tmp_path):
        """A name absent from custom_providers does not resolve to a client."""
        _write_config(tmp_path, {
            "model": {"default": "test"},
            "custom_providers": [
                {"name": "beans", "base_url": "http://beans.local/v1"},
            ],
        })
        from agent.auxiliary_client import resolve_provider_client

        # "coffee" doesn't exist in custom_providers
        client, model = resolve_provider_client("coffee", "test")

        assert client is None
|
||||
@@ -1,164 +0,0 @@
|
||||
"""Security tests for Terminal-Bench 2 archive extraction."""
|
||||
|
||||
import base64
|
||||
import importlib
|
||||
import io
|
||||
import sys
|
||||
import tarfile
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = types.ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
def _load_terminalbench_module(monkeypatch):
    """Import the Terminal-Bench 2 environment module against stubbed dependencies.

    Placeholder modules for ``atroposlib``, the ``environments`` helpers and
    ``tools.terminal_tool`` are registered in ``sys.modules`` through
    *monkeypatch*. The environment module is then dropped from the module
    cache and imported again so that the import resolves against the stubs.

    Returns:
        The freshly imported ``terminalbench2_env`` module.
    """

    class _EvalHandlingEnum:
        STOP_TRAIN = "stop_train"

    class _APIServerConfig:
        # Records whatever it is constructed with.
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class _AgentResult:
        pass

    class _HermesAgentLoop:
        pass

    class _HermesAgentBaseEnv:
        pass

    class _HermesAgentEnvConfig:
        pass

    class _ToolContext:
        pass

    stub_modules = {
        "atroposlib": _stub_module("atroposlib"),
        "atroposlib.envs": _stub_module("atroposlib.envs"),
        "atroposlib.envs.base": _stub_module(
            "atroposlib.envs.base",
            EvalHandlingEnum=_EvalHandlingEnum,
        ),
        "atroposlib.envs.server_handling": _stub_module("atroposlib.envs.server_handling"),
        "atroposlib.envs.server_handling.server_manager": _stub_module(
            "atroposlib.envs.server_handling.server_manager",
            APIServerConfig=_APIServerConfig,
        ),
        "environments.agent_loop": _stub_module(
            "environments.agent_loop",
            AgentResult=_AgentResult,
            HermesAgentLoop=_HermesAgentLoop,
        ),
        "environments.hermes_base_env": _stub_module(
            "environments.hermes_base_env",
            HermesAgentBaseEnv=_HermesAgentBaseEnv,
            HermesAgentEnvConfig=_HermesAgentEnvConfig,
        ),
        "environments.tool_context": _stub_module(
            "environments.tool_context",
            ToolContext=_ToolContext,
        ),
        "tools.terminal_tool": _stub_module(
            "tools.terminal_tool",
            register_task_env_overrides=lambda *args, **kwargs: None,
            clear_task_env_overrides=lambda *args, **kwargs: None,
            cleanup_vm=lambda *args, **kwargs: None,
        ),
    }

    # Expose each stub sub-package as an attribute of its parent stub.
    stub_modules["atroposlib"].envs = stub_modules["atroposlib.envs"]
    stub_modules["atroposlib.envs"].base = stub_modules["atroposlib.envs.base"]
    stub_modules["atroposlib.envs"].server_handling = stub_modules["atroposlib.envs.server_handling"]
    stub_modules["atroposlib.envs.server_handling"].server_manager = stub_modules[
        "atroposlib.envs.server_handling.server_manager"
    ]

    for name, module in stub_modules.items():
        monkeypatch.setitem(sys.modules, name, module)

    module_name = "environments.benchmarks.terminalbench_2.terminalbench2_env"
    # Drop any cached copy so the import below runs against the stubs.
    sys.modules.pop(module_name, None)
    return importlib.import_module(module_name)
|
||||
|
||||
|
||||
def _build_tar_b64(entries):
|
||||
buf = io.BytesIO()
|
||||
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||
for entry in entries:
|
||||
kind = entry["kind"]
|
||||
info = tarfile.TarInfo(entry["name"])
|
||||
|
||||
if kind == "dir":
|
||||
info.type = tarfile.DIRTYPE
|
||||
tar.addfile(info)
|
||||
continue
|
||||
|
||||
if kind == "file":
|
||||
data = entry["data"].encode("utf-8")
|
||||
info.size = len(data)
|
||||
tar.addfile(info, io.BytesIO(data))
|
||||
continue
|
||||
|
||||
if kind == "symlink":
|
||||
info.type = tarfile.SYMTYPE
|
||||
info.linkname = entry["target"]
|
||||
tar.addfile(info)
|
||||
continue
|
||||
|
||||
raise ValueError(f"Unknown tar entry kind: {kind}")
|
||||
|
||||
return base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
|
||||
def test_extract_base64_tar_allows_safe_files(tmp_path, monkeypatch):
    """An archive with a directory and a file inside it extracts under the target."""
    module = _load_terminalbench_module(monkeypatch)
    archive = _build_tar_b64(
        [
            {"kind": "dir", "name": "nested"},
            {"kind": "file", "name": "nested/hello.txt", "data": "hello"},
        ]
    )

    target = tmp_path / "extract"
    module._extract_base64_tar(archive, target)

    assert (target / "nested" / "hello.txt").read_text(encoding="utf-8") == "hello"
|
||||
|
||||
|
||||
def test_extract_base64_tar_rejects_path_traversal(tmp_path, monkeypatch):
    """A member named '../escape.txt' is rejected and nothing is written outside the target."""
    module = _load_terminalbench_module(monkeypatch)
    archive = _build_tar_b64(
        [
            {"kind": "file", "name": "../escape.txt", "data": "owned"},
        ]
    )

    target = tmp_path / "extract"
    with pytest.raises(ValueError, match="Unsafe archive member path"):
        module._extract_base64_tar(archive, target)

    # The traversal target sits next to, not inside, the extraction directory.
    assert not (tmp_path / "escape.txt").exists()
|
||||
|
||||
|
||||
def test_extract_base64_tar_rejects_symlinks(tmp_path, monkeypatch):
    """A symlink member is rejected and no link is created in the target."""
    module = _load_terminalbench_module(monkeypatch)
    archive = _build_tar_b64(
        [
            {"kind": "symlink", "name": "link", "target": "../../escape.txt"},
        ]
    )

    target = tmp_path / "extract"
    with pytest.raises(ValueError, match="Unsupported archive member type"):
        module._extract_base64_tar(archive, target)

    assert not (target / "link").exists()
|
||||
@@ -439,7 +439,7 @@ class TestChatCompletionsEndpoint:
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
# Simulate tool progress before streaming content
|
||||
if tp_cb:
|
||||
tp_cb("tool.started", "terminal", "ls -la", {"command": "ls -la"})
|
||||
tp_cb("terminal", "ls -la", {"command": "ls -la"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Here are the files.")
|
||||
@@ -476,8 +476,8 @@ class TestChatCompletionsEndpoint:
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
if tp_cb:
|
||||
tp_cb("tool.started", "_thinking", "some internal state", {})
|
||||
tp_cb("tool.started", "web_search", "Python docs", {"query": "Python docs"})
|
||||
tp_cb("_thinking", "some internal state", {})
|
||||
tp_cb("web_search", "Python docs", {"query": "Python docs"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Found it.")
|
||||
|
||||
@@ -1,343 +0,0 @@
|
||||
"""Tests for Discord ignored_channels and no_thread_channels config."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class FakeDMChannel:
    """Minimal DM channel stand-in carrying only an id and a name."""

    def __init__(self, channel_id: int = 1, name: str = "dm"):
        self.id = channel_id
        self.name = name
|
||||
|
||||
|
||||
class FakeTextChannel:
    """Minimal guild text channel stand-in: id, name, a guild with a name, and no topic."""

    def __init__(self, channel_id: int = 1, name: str = "general", guild_name: str = "Hermes Server"):
        self.id = channel_id
        self.name = name
        self.guild = SimpleNamespace(name=guild_name)
        self.topic = None
|
||||
|
||||
|
||||
class FakeThread:
    """Minimal thread stand-in that records its parent channel.

    ``parent_id`` mirrors ``parent.id`` (``None`` without a parent). The guild
    is taken from the parent when it has one; otherwise a guild named
    *guild_name* is created.
    """

    def __init__(self, channel_id: int = 1, name: str = "thread", parent=None, guild_name: str = "Hermes Server"):
        self.id = channel_id
        self.name = name
        self.parent = parent
        self.parent_id = getattr(parent, "id", None)
        self.guild = getattr(parent, "guild", None) or SimpleNamespace(name=guild_name)
        self.topic = None
|
||||
|
||||
|
||||
@pytest.fixture
def adapter(monkeypatch):
    """Return a DiscordAdapter wired for tests.

    The discord channel types the platform module sees are replaced with the
    local fakes, the client is a stub whose user id is 999, and
    ``handle_message`` is an ``AsyncMock`` so tests can assert on dispatch.
    """
    monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
    monkeypatch.setattr(discord_platform.discord, "Thread", FakeThread, raising=False)

    config = PlatformConfig(enabled=True, token="fake-token")
    discord_adapter = DiscordAdapter(config)
    discord_adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
    discord_adapter.handle_message = AsyncMock()
    return discord_adapter
|
||||
|
||||
|
||||
def make_message(*, channel, content: str, mentions=None):
    """Build a message stand-in posted in *channel* by a fixed test author.

    Args:
        channel: Channel object stored on the message as-is.
        content: Message text.
        mentions: Optional iterable of mentioned users; copied into a list
            (empty when omitted).

    Returns:
        A ``SimpleNamespace`` with id 123, no attachments, no reference, and
        a timezone-aware UTC ``created_at``.
    """
    author = SimpleNamespace(id=42, display_name="TestUser", name="TestUser")
    return SimpleNamespace(
        id=123,
        content=content,
        mentions=list(mentions or []),
        attachments=[],
        reference=None,
        created_at=datetime.now(timezone.utc),
        channel=channel,
        author=author,
    )
|
||||
|
||||
|
||||
# ── ignored_channels ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channel_blocks_message(adapter, monkeypatch):
    """Messages in ignored channels are silently dropped."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    message = make_message(channel=FakeTextChannel(channel_id=500), content="hello")
    await adapter._handle_message(message)

    adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channel_blocks_even_with_mention(adapter, monkeypatch):
    """Ignored channels take priority — even @mentions are dropped."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")

    bot_user = adapter._client.user
    message = make_message(
        channel=FakeTextChannel(channel_id=500),
        content=f"<@{bot_user.id}> hello",
        mentions=[bot_user],
    )
    await adapter._handle_message(message)

    adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_ignored_channel_processes_normally(adapter, monkeypatch):
    """Channels not in the ignored list process normally."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500,600")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    message = make_message(channel=FakeTextChannel(channel_id=700), content="hello")
    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channels_csv_parsing(adapter, monkeypatch):
    """Multiple channel IDs are parsed correctly from CSV."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    # Irregular spacing around the commas is deliberate.
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500, 600 , 700")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    for ch_id in (500, 600, 700):
        adapter.handle_message.reset_mock()
        message = make_message(channel=FakeTextChannel(channel_id=ch_id), content="hello")
        await adapter._handle_message(message)
        adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channels_empty_string_ignores_nothing(adapter, monkeypatch):
    """Empty DISCORD_IGNORED_CHANNELS means nothing is ignored."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    message = make_message(channel=FakeTextChannel(channel_id=500), content="hello")
    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ignored_channel_thread_parent_match(adapter, monkeypatch):
    """Thread whose parent channel is ignored should also be ignored."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    # Only the parent's id (500) is in the ignored list; the thread's own id is not.
    parent = FakeTextChannel(channel_id=500, name="ignored-channel")
    thread = FakeThread(channel_id=501, name="thread-in-ignored", parent=parent)
    message = make_message(channel=thread, content="hello from thread")
    await adapter._handle_message(message)

    adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dms_unaffected_by_ignored_channels(adapter, monkeypatch):
    """DMs should never be affected by ignored_channels."""
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "500")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    # The DM channel id equals the ignored id on purpose.
    message = make_message(channel=FakeDMChannel(channel_id=500), content="dm hello")
    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
# ── no_thread_channels ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_thread_channel_skips_auto_thread(adapter, monkeypatch):
    """Channels in no_thread_channels should not auto-create threads."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")
    monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
    monkeypatch.delenv("DISCORD_IGNORED_CHANNELS", raising=False)
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    adapter._auto_create_thread = AsyncMock(return_value=FakeThread(channel_id=999))

    message = make_message(channel=FakeTextChannel(channel_id=800), content="hello")
    await adapter._handle_message(message)

    adapter._auto_create_thread.assert_not_awaited()
    adapter.handle_message.assert_awaited_once()
    # The message is still dispatched, as a group (non-thread) chat.
    event = adapter.handle_message.await_args.args[0]
    assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_normal_channel_still_auto_threads(adapter, monkeypatch):
    """Channels NOT in no_thread_channels still get auto-threading."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")
    monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
    monkeypatch.delenv("DISCORD_IGNORED_CHANNELS", raising=False)
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    fake_thread = FakeThread(channel_id=999, name="auto-thread")
    adapter._auto_create_thread = AsyncMock(return_value=fake_thread)

    message = make_message(channel=FakeTextChannel(channel_id=900), content="hello")
    await adapter._handle_message(message)

    adapter._auto_create_thread.assert_awaited_once()
    adapter.handle_message.assert_awaited_once()
    event = adapter.handle_message.await_args.args[0]
    assert event.source.chat_type == "thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_thread_channels_csv_parsing(adapter, monkeypatch):
    """Multiple no_thread channel IDs parsed from CSV."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800, 900")
    monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
    monkeypatch.delenv("DISCORD_IGNORED_CHANNELS", raising=False)
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    adapter._auto_create_thread = AsyncMock(return_value=FakeThread(channel_id=999))

    for ch_id in (800, 900):
        # Reset both mocks so each channel is checked independently.
        adapter._auto_create_thread.reset_mock()
        adapter.handle_message.reset_mock()
        message = make_message(channel=FakeTextChannel(channel_id=ch_id), content="hello")
        await adapter._handle_message(message)
        adapter._auto_create_thread.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_thread_with_auto_thread_disabled_is_noop(adapter, monkeypatch):
    """no_thread_channels is a no-op when auto_thread is globally disabled."""
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "800")
    monkeypatch.delenv("DISCORD_IGNORED_CHANNELS", raising=False)
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    adapter._auto_create_thread = AsyncMock()

    message = make_message(channel=FakeTextChannel(channel_id=800), content="hello")
    await adapter._handle_message(message)

    adapter._auto_create_thread.assert_not_awaited()
    adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
# ── config.py bridging ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_config_bridges_ignored_channels(monkeypatch, tmp_path):
    """gateway/config.py bridges discord.ignored_channels to env var."""
    import os

    import yaml

    config_file = tmp_path / "config.yaml"
    config_file.write_text(yaml.dump({
        "discord": {
            "ignored_channels": ["111", "222"],
        },
    }))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # Use setenv (not delenv) so monkeypatch registers cleanup even when
    # the var doesn't exist yet — load_gateway_config will overwrite it.
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "")

    from gateway.config import load_gateway_config
    load_gateway_config()

    assert os.getenv("DISCORD_IGNORED_CHANNELS") == "111,222"
|
||||
|
||||
|
||||
def test_config_bridges_no_thread_channels(monkeypatch, tmp_path):
    """gateway/config.py bridges discord.no_thread_channels to env var."""
    import os

    import yaml

    config_file = tmp_path / "config.yaml"
    config_file.write_text(yaml.dump({
        "discord": {
            "no_thread_channels": ["333"],
        },
    }))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # Set to an empty value through monkeypatch so the variable is restored after the test.
    monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "")

    from gateway.config import load_gateway_config
    load_gateway_config()

    assert os.getenv("DISCORD_NO_THREAD_CHANNELS") == "333"
|
||||
|
||||
|
||||
def test_config_env_var_takes_precedence(monkeypatch, tmp_path):
    """Env vars should take precedence over config.yaml values."""
    import os

    import yaml

    config_file = tmp_path / "config.yaml"
    config_file.write_text(yaml.dump({
        "discord": {
            "ignored_channels": ["111"],
        },
    }))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    monkeypatch.setenv("DISCORD_IGNORED_CHANNELS", "999")

    from gateway.config import load_gateway_config
    load_gateway_config()

    # Env var should NOT be overwritten
    assert os.getenv("DISCORD_IGNORED_CHANNELS") == "999"
|
||||
@@ -504,8 +504,7 @@ class TestMattermostFileUpload:
|
||||
self.adapter._session = MagicMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("tools.url_safety.is_safe_url", return_value=True)
|
||||
async def test_send_image_downloads_and_uploads(self, _mock_safe):
|
||||
async def test_send_image_downloads_and_uploads(self):
|
||||
"""send_image should download the URL, upload via /api/v4/files, then post."""
|
||||
# Mock the download (GET)
|
||||
mock_dl_resp = AsyncMock()
|
||||
|
||||
@@ -596,11 +596,10 @@ def _make_aiohttp_resp(status: int, content: bytes = b"file bytes",
|
||||
return resp
|
||||
|
||||
|
||||
@patch("tools.url_safety.is_safe_url", return_value=True)
|
||||
class TestMattermostSendUrlAsFile:
|
||||
"""Tests for MattermostAdapter._send_url_as_file"""
|
||||
|
||||
def test_success_on_first_attempt(self, _mock_safe):
|
||||
def test_success_on_first_attempt(self):
|
||||
"""200 on first attempt → file uploaded and post created."""
|
||||
adapter = _make_mm_adapter()
|
||||
resp = _make_aiohttp_resp(200)
|
||||
@@ -617,7 +616,7 @@ class TestMattermostSendUrlAsFile:
|
||||
adapter._upload_file.assert_called_once()
|
||||
adapter._api_post.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self, _mock_safe):
|
||||
def test_retries_on_429_then_succeeds(self):
|
||||
"""429 on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
@@ -638,7 +637,7 @@ class TestMattermostSendUrlAsFile:
|
||||
assert adapter._session.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_500_then_succeeds(self, _mock_safe):
|
||||
def test_retries_on_500_then_succeeds(self):
|
||||
"""5xx on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
@@ -656,7 +655,7 @@ class TestMattermostSendUrlAsFile:
|
||||
assert result.success
|
||||
assert adapter._session.get.call_count == 2
|
||||
|
||||
def test_falls_back_to_text_after_max_retries_on_5xx(self, _mock_safe):
|
||||
def test_falls_back_to_text_after_max_retries_on_5xx(self):
|
||||
"""Three consecutive 500s exhaust retries; falls back to send() with URL text."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
@@ -675,7 +674,7 @@ class TestMattermostSendUrlAsFile:
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_falls_back_on_client_error(self, _mock_safe):
|
||||
def test_falls_back_on_client_error(self):
|
||||
"""aiohttp.ClientError on every attempt falls back to send() with URL."""
|
||||
import aiohttp
|
||||
|
||||
@@ -700,7 +699,7 @@ class TestMattermostSendUrlAsFile:
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_non_retryable_404_falls_back_immediately(self, _mock_safe):
|
||||
def test_non_retryable_404_falls_back_immediately(self):
|
||||
"""404 is non-retryable (< 500, != 429); send() is called right away."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
|
||||
@@ -71,24 +71,6 @@ class FakeAgent:
|
||||
}
|
||||
|
||||
|
||||
class LongPreviewAgent:
    """Agent that emits a tool call with a very long preview string."""

    LONG_CMD = "cd /home/teknium/.hermes/hermes-agent/.worktrees/hermes-d8860339 && source .venv/bin/activate && python -m pytest tests/gateway/test_run_progress_topics.py -n0 -q"

    def __init__(self, **kwargs):
        # Only the progress callback is used; other keyword arguments are ignored.
        self.tool_progress_callback = kwargs.get("tool_progress_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Report one "terminal" tool start with LONG_CMD as preview, pause, then return a canned result."""
        self.tool_progress_callback("tool.started", "terminal", self.LONG_CMD, {})
        time.sleep(0.35)
        return {
            "final_response": "done",
            "messages": [],
            "api_calls": 1,
        }
|
||||
|
||||
|
||||
def _make_runner(adapter):
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
GatewayRunner = gateway_run.GatewayRunner
|
||||
@@ -235,102 +217,3 @@ async def test_run_agent_progress_uses_event_message_id_for_slack_dm(monkeypatch
|
||||
assert adapter.sent
|
||||
assert adapter.sent[0]["metadata"] == {"thread_id": "1234567890.000001"}
|
||||
assert all(call["metadata"] == {"thread_id": "1234567890.000001"} for call in adapter.typing)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview truncation tests (all/new mode respects tool_preview_length)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_long_preview_helper(monkeypatch, tmp_path, preview_length=0):
    """Shared setup for long-preview truncation tests.

    Returns (adapter, result) after running the agent with LongPreviewAgent.
    ``preview_length`` controls display.tool_preview_length in the config file
    that _run_agent reads — so the gateway picks it up the same way production does.
    """
    import asyncio

    import yaml

    monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")

    # Replace dotenv with a module whose load_dotenv does nothing.
    fake_dotenv = types.ModuleType("dotenv")
    fake_dotenv.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)

    # Make ``run_agent.AIAgent`` resolve to the long-preview fake.
    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = LongPreviewAgent
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)

    # Write config.yaml so _run_agent picks up tool_preview_length
    config = {"display": {"tool_preview_length": preview_length}}
    (tmp_path / "config.yaml").write_text(yaml.dump(config), encoding="utf-8")

    adapter = ProgressCaptureAdapter()
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})

    source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="12345",
        chat_type="dm",
        thread_id=None,
    )

    # Run on a dedicated loop and close it afterwards. Calling
    # asyncio.get_event_loop() with no current loop is deprecated.
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(
            runner._run_agent(
                message="hello",
                context_prompt="",
                history=[],
                source=source,
                session_id="sess-trunc",
                session_key="agent:main:telegram:dm:12345",
            )
        )
    finally:
        loop.close()
    return adapter, result
|
||||
|
||||
|
||||
def test_all_mode_default_truncation_40_chars(monkeypatch, tmp_path):
    """When tool_preview_length is 0 (default), all/new mode truncates to 40 chars."""
    import re

    adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=0)

    assert result["final_response"] == "done"
    assert adapter.sent
    content = adapter.sent[0]["content"]
    # The long command should be truncated — total preview <= 40 chars
    assert "..." in content
    # Extract the preview part between quotes
    match = re.search(r'"(.+)"', content)
    assert match, f"No quoted preview found in: {content}"
    preview_text = match.group(1)
    assert len(preview_text) <= 40, f"Preview too long ({len(preview_text)}): {preview_text}"
|
||||
|
||||
|
||||
def test_all_mode_respects_custom_preview_length(monkeypatch, tmp_path):
    """When tool_preview_length is explicitly set (e.g. 120), all/new mode uses that."""
    import re

    adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=120)

    assert result["final_response"] == "done"
    assert adapter.sent
    content = adapter.sent[0]["content"]
    # The command is longer than 120 chars, so it is still truncated, just later.
    match = re.search(r'"(.+)"', content)
    assert match, f"No quoted preview found in: {content}"
    preview_text = match.group(1)
    # Should be longer than the 40-char default
    assert len(preview_text) > 40, f"Preview suspiciously short ({len(preview_text)}): {preview_text}"
    # But still capped at 120
    assert len(preview_text) <= 120, f"Preview too long ({len(preview_text)}): {preview_text}"
|
||||
|
||||
|
||||
def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path):
    """Short previews (under the cap) are not truncated."""
    # Set a generous cap — LongPreviewAgent's command is shorter than 200 chars
    adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=200)

    assert result["final_response"] == "done"
    assert adapter.sent
    content = adapter.sent[0]["content"]
    # With a 200-char cap, the command should NOT be truncated
    assert "..." not in content, f"Preview was truncated when it shouldn't be: {content}"
|
||||
|
||||
@@ -51,8 +51,7 @@ def _make_runner(session_entry: SessionEntry):
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session_title.return_value = None
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
@@ -83,34 +82,12 @@ async def test_status_command_reports_running_agent_without_interrupt(monkeypatc
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
assert "**Session ID:** `sess-1`" in result
|
||||
assert "**Tokens:** 321" in result
|
||||
assert "**Agent Running:** Yes ⚡" in result
|
||||
assert "**Title:**" not in result
|
||||
running_agent.interrupt.assert_not_called()
|
||||
assert runner._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_status_command_includes_session_title_when_present():
    """/status output carries a Title line when the session DB returns a title."""
    key = build_session_key(_make_source())
    entry = SessionEntry(
        session_key=key,
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
        total_tokens=321,
    )
    runner = _make_runner(entry)
    # The runner's session DB is stubbed to report a stored title.
    runner._session_db.get_session_title.return_value = "My titled session"

    reply = await runner._handle_message(_make_event("/status"))

    for expected in (
        "**Session ID:** `sess-1`",
        "**Title:** My titled session",
    ):
        assert expected in reply
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
"""Tests for TelegramPlatform._merge_caption caption deduplication logic."""
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
# Short module-level alias for the helper under test; the tests below call it
# as merge(<existing text or None>, <new caption>).
merge = TelegramAdapter._merge_caption
|
||||
|
||||
|
||||
class TestMergeCaptionBasic:
    """Baseline cases: first caption, exact duplicate, two distinct captions."""

    def test_no_existing_text(self):
        """With None as the existing text, the new caption comes back as-is."""
        merged = merge(None, "Hello")
        assert merged == "Hello"

    def test_empty_existing_text(self):
        """An empty existing string behaves like having no existing text."""
        merged = merge("", "Hello")
        assert merged == "Hello"

    def test_exact_duplicate_dropped(self):
        """Repeating the identical caption does not append it a second time."""
        merged = merge("Revenue", "Revenue")
        assert merged == "Revenue"

    def test_different_captions_merged(self):
        """Distinct captions are joined with a blank line between them."""
        assert merge("Q3 Results", "Q4 Projections") == "Q3 Results\n\nQ4 Projections"
|
||||
|
||||
|
||||
class TestMergeCaptionSubstringBug:
    """Regression cases for the scenarios the old substring check got wrong."""

    def test_shorter_caption_not_dropped_when_substring(self):
        # Old bug: "Meeting" in "Meeting agenda" evaluated True, so the
        # shorter caption was silently lost.
        merged = merge("Meeting agenda", "Meeting")
        assert merged == "Meeting agenda\n\nMeeting"

    def test_longer_caption_not_dropped_when_contains_existing(self):
        # "Revenue and Profit" contains "Revenue", yet the two are
        # different captions and both must be kept.
        merged = merge("Revenue", "Revenue and Profit")
        assert merged == "Revenue\n\nRevenue and Profit"

    def test_prefix_caption_not_dropped(self):
        """A caption that is merely a prefix of the existing one is still appended."""
        assert (
            merge("Q3 Results - Revenue", "Q3 Results")
            == "Q3 Results - Revenue\n\nQ3 Results"
        )
|
||||
|
||||
|
||||
class TestMergeCaptionWhitespace:
    """Captions that differ only by surrounding whitespace."""

    def test_trailing_space_treated_as_duplicate(self):
        """A trailing space does not make the caption count as new."""
        merged = merge("Revenue", "Revenue ")
        assert merged == "Revenue"

    def test_leading_space_treated_as_duplicate(self):
        """A leading space does not make the caption count as new."""
        merged = merge("Revenue", " Revenue")
        assert merged == "Revenue"

    def test_whitespace_only_new_text_not_added(self):
        # Edge case: " ".strip() == "" is not among the existing captions
        # (["Revenue"]), so _merge_caption itself would merge it. Callers
        # already guard with `if event.text:`, which prevents this input
        # from reaching the helper in practice.
        merged = merge("Revenue", " ")
        # The assertion deliberately accepts either outcome: the text was
        # merged (separator present) or it was dropped (unchanged).
        assert merged == "Revenue" or "\n\n" in merged
|
||||
|
||||
|
||||
class TestMergeCaptionMultipleItems:
    """Captions accumulated over several successive merge calls."""

    def test_three_unique_captions_all_present(self):
        """Three distinct captions all survive, in arrival order."""
        combined = None
        for caption in ("A", "B", "C"):
            combined = merge(combined, caption)
        assert combined == "A\n\nB\n\nC"

    def test_duplicate_in_middle_dropped(self):
        """A caption repeated later in the sequence is not appended again."""
        combined = None
        for caption in ("A", "B", "A"):  # the final "A" is the duplicate
            combined = merge(combined, caption)
        assert combined == "A\n\nB"

    def test_album_scenario_revenue_profit(self):
        # Album item 1 is "Revenue and Profit", item 2 is "Revenue".
        # Old bug: "Revenue" in ["Revenue and Profit"] was True, so the
        # second caption was lost.
        combined = merge(None, "Revenue and Profit")
        combined = merge(combined, "Revenue")
        assert combined == "Revenue and Profit\n\nRevenue"
|
||||
@@ -1,260 +0,0 @@
|
||||
"""Tests for Telegram message reactions tied to processing lifecycle hooks."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_adapter(**extra_env):
    """Build a TelegramAdapter for tests without running its ``__init__``.

    The instance is allocated with ``object.__new__`` and given only a
    platform, a fake-token config, and an ``AsyncMock`` bot whose
    ``set_message_reaction`` is itself an ``AsyncMock``.

    ``extra_env`` is accepted for call-site compatibility but is not read
    by this helper.
    """
    from gateway.platforms.telegram import TelegramAdapter

    instance = object.__new__(TelegramAdapter)
    instance.platform = Platform.TELEGRAM
    instance.config = PlatformConfig(enabled=True, token="fake-token")

    bot = AsyncMock()
    bot.set_message_reaction = AsyncMock()
    instance._bot = bot
    return instance
|
||||
|
||||
|
||||
def _make_event(chat_id: str = "123", message_id: str = "456") -> MessageEvent:
    """Return a plain-text "hello" MessageEvent from a private Telegram chat.

    Args:
        chat_id: Chat identifier placed on the event's source.
        message_id: Identifier placed on the event itself.
    """
    origin = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id=chat_id,
        chat_type="private",
        user_id="42",
        user_name="TestUser",
    )
    return MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=origin,
        message_id=message_id,
    )
|
||||
|
||||
|
||||
# ── _reactions_enabled ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_reactions_disabled_by_default(monkeypatch):
    """With TELEGRAM_REACTIONS unset, reactions are reported as disabled."""
    monkeypatch.delenv("TELEGRAM_REACTIONS", raising=False)

    enabled = _make_adapter()._reactions_enabled()

    assert enabled is False
|
||||
|
||||
|
||||
def test_reactions_enabled_when_set_true(monkeypatch):
    """TELEGRAM_REACTIONS=true switches reactions on."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")

    enabled = _make_adapter()._reactions_enabled()

    assert enabled is True
|
||||
|
||||
|
||||
def test_reactions_enabled_with_1(monkeypatch):
    """The value "1" for TELEGRAM_REACTIONS also switches reactions on."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "1")

    enabled = _make_adapter()._reactions_enabled()

    assert enabled is True
|
||||
|
||||
|
||||
def test_reactions_disabled_with_false(monkeypatch):
    """TELEGRAM_REACTIONS=false leaves reactions off."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "false")

    enabled = _make_adapter()._reactions_enabled()

    assert enabled is False
|
||||
|
||||
|
||||
def test_reactions_disabled_with_0(monkeypatch):
    """The value "0" for TELEGRAM_REACTIONS leaves reactions off."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "0")

    enabled = _make_adapter()._reactions_enabled()

    assert enabled is False
|
||||
|
||||
|
||||
def test_reactions_disabled_with_no(monkeypatch):
    """The value "no" for TELEGRAM_REACTIONS leaves reactions off."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "no")

    enabled = _make_adapter()._reactions_enabled()

    assert enabled is False
|
||||
|
||||
|
||||
# ── _set_reaction ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_set_reaction_calls_bot_api(monkeypatch):
    """_set_reaction forwards to bot.set_message_reaction and returns True.

    The string ids passed in are expected to reach the bot as integers.
    """
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    eyes = "\U0001f440"

    outcome = await adapter._set_reaction("123", "456", eyes)

    assert outcome is True
    adapter._bot.set_message_reaction.assert_awaited_once_with(
        chat_id=123, message_id=456, reaction=eyes
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_set_reaction_returns_false_without_bot(monkeypatch):
    """With no bot attached (``_bot`` is None), _set_reaction returns False."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    adapter._bot = None

    outcome = await adapter._set_reaction("123", "456", "\U0001f440")

    assert outcome is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_set_reaction_handles_api_error_gracefully(monkeypatch):
    """An exception raised by the bot API is swallowed and reported as False."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    failing_call = AsyncMock(side_effect=RuntimeError("no perms"))
    adapter._bot.set_message_reaction = failing_call

    outcome = await adapter._set_reaction("123", "456", "\U0001f440")

    assert outcome is False
|
||||
|
||||
|
||||
# ── on_processing_start ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_start_adds_eyes_reaction(monkeypatch):
    """With reactions enabled, processing start sets the eyes reaction once."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    await adapter.on_processing_start(_make_event())

    reaction_call = adapter._bot.set_message_reaction
    reaction_call.assert_awaited_once_with(
        chat_id=123, message_id=456, reaction="\U0001f440"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_start_skipped_when_disabled(monkeypatch):
    """With TELEGRAM_REACTIONS unset, processing start never calls the bot."""
    monkeypatch.delenv("TELEGRAM_REACTIONS", raising=False)
    adapter = _make_adapter()

    await adapter.on_processing_start(_make_event())

    adapter._bot.set_message_reaction.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_start_handles_missing_ids(monkeypatch):
    """An event with neither chat_id nor message_id triggers no reaction call."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()
    # A bare namespace stands in for the source; its chat_id is None.
    anonymous_source = SimpleNamespace(chat_id=None)
    event = MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=anonymous_source,
        message_id=None,
    )

    await adapter.on_processing_start(event)

    adapter._bot.set_message_reaction.assert_not_awaited()
|
||||
|
||||
|
||||
# ── on_processing_complete ───────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_complete_success(monkeypatch):
    """A successful completion sets the check-mark reaction exactly once."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    await adapter.on_processing_complete(_make_event(), success=True)

    reaction_call = adapter._bot.set_message_reaction
    reaction_call.assert_awaited_once_with(
        chat_id=123, message_id=456, reaction="\u2705"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_complete_failure(monkeypatch):
    """A failed completion sets the cross-mark reaction exactly once."""
    monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
    adapter = _make_adapter()

    await adapter.on_processing_complete(_make_event(), success=False)

    reaction_call = adapter._bot.set_message_reaction
    reaction_call.assert_awaited_once_with(
        chat_id=123, message_id=456, reaction="\u274c"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_processing_complete_skipped_when_disabled(monkeypatch):
    """With TELEGRAM_REACTIONS unset, processing completion never calls the bot."""
    monkeypatch.delenv("TELEGRAM_REACTIONS", raising=False)
    adapter = _make_adapter()

    await adapter.on_processing_complete(_make_event(), success=True)

    adapter._bot.set_message_reaction.assert_not_awaited()
|
||||
|
||||
|
||||
# ── config.py bridging ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_config_bridges_telegram_reactions(monkeypatch, tmp_path):
    """Loading a config with telegram.reactions=True sets TELEGRAM_REACTIONS to "true"."""
    import yaml

    settings = {"telegram": {"reactions": True}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # setenv rather than delenv: monkeypatch then registers cleanup even if
    # the variable did not exist beforehand; load_gateway_config is expected
    # to overwrite the empty value.
    monkeypatch.setenv("TELEGRAM_REACTIONS", "")

    from gateway.config import load_gateway_config

    load_gateway_config()

    import os

    assert os.getenv("TELEGRAM_REACTIONS") == "true"
|
||||
|
||||
|
||||
def test_config_reactions_env_takes_precedence(monkeypatch, tmp_path):
    """A pre-set TELEGRAM_REACTIONS env value is not overridden by config.yaml."""
    import yaml

    settings = {"telegram": {"reactions": True}}
    (tmp_path / "config.yaml").write_text(yaml.dump(settings))
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # The environment disagrees with the file; the environment should win.
    monkeypatch.setenv("TELEGRAM_REACTIONS", "false")

    from gateway.config import load_gateway_config

    load_gateway_config()

    import os

    assert os.getenv("TELEGRAM_REACTIONS") == "false"
|
||||
@@ -590,15 +590,8 @@ class TestSessionIsolation:
|
||||
class TestDeliveryCleanup:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delivery_info_survives_multiple_sends(self):
|
||||
"""send() must NOT pop delivery_info.
|
||||
|
||||
Interim status messages (fallback notifications, context-pressure
|
||||
warnings, etc.) flow through the same send() path as the final
|
||||
response. If the entry were popped on the first send, the final
|
||||
response would silently downgrade to the ``log`` deliver type.
|
||||
Regression test for that bug.
|
||||
"""
|
||||
async def test_delivery_info_cleaned_after_send(self):
|
||||
"""send() pops delivery_info so the entry doesn't leak memory."""
|
||||
adapter = _make_adapter()
|
||||
chat_id = "webhook:test:d-xyz"
|
||||
adapter._delivery_info[chat_id] = {
|
||||
@@ -606,40 +599,10 @@ class TestDeliveryCleanup:
|
||||
"deliver_extra": {},
|
||||
"payload": {"x": 1},
|
||||
}
|
||||
adapter._delivery_info_created[chat_id] = time.time()
|
||||
|
||||
# First send (e.g. an interim status message)
|
||||
result1 = await adapter.send(chat_id, "Status: switching to fallback")
|
||||
assert result1.success is True
|
||||
# Entry must still be present so the final send can read it
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
# Second send (the final agent response)
|
||||
result2 = await adapter.send(chat_id, "Final agent response")
|
||||
assert result2.success is True
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delivery_info_pruned_via_ttl(self):
|
||||
"""Stale delivery_info entries are dropped on the next POST."""
|
||||
adapter = _make_adapter()
|
||||
adapter._idempotency_ttl = 60 # short TTL for the test
|
||||
now = time.time()
|
||||
|
||||
# Stale entry — older than TTL
|
||||
adapter._delivery_info["webhook:test:old"] = {"deliver": "log"}
|
||||
adapter._delivery_info_created["webhook:test:old"] = now - 120
|
||||
|
||||
# Fresh entry — should survive
|
||||
adapter._delivery_info["webhook:test:new"] = {"deliver": "log"}
|
||||
adapter._delivery_info_created["webhook:test:new"] = now - 5
|
||||
|
||||
adapter._prune_delivery_info(now)
|
||||
|
||||
assert "webhook:test:old" not in adapter._delivery_info
|
||||
assert "webhook:test:old" not in adapter._delivery_info_created
|
||||
assert "webhook:test:new" in adapter._delivery_info
|
||||
assert "webhook:test:new" in adapter._delivery_info_created
|
||||
result = await adapter.send(chat_id, "Agent response here")
|
||||
assert result.success is True
|
||||
assert chat_id not in adapter._delivery_info
|
||||
|
||||
|
||||
# ===================================================================
|
||||
|
||||
@@ -259,9 +259,8 @@ class TestCrossPlatformDelivery:
|
||||
mock_tg_adapter.send.assert_awaited_once_with(
|
||||
"12345", "I've acknowledged the alert.", metadata=None
|
||||
)
|
||||
# Delivery info is retained after send() so interim status messages
|
||||
# don't strand the final response (TTL-based cleanup happens on POST).
|
||||
assert chat_id in adapter._delivery_info
|
||||
# Delivery info should be cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
|
||||
|
||||
# ===================================================================
|
||||
@@ -334,6 +333,5 @@ class TestGitHubCommentDelivery:
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
# Delivery info is retained after send() so interim status messages
|
||||
# don't strand the final response (TTL-based cleanup happens on POST).
|
||||
assert chat_id in adapter._delivery_info
|
||||
# Delivery info cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_format_banner_version_label_without_git_state():
    """With no git state available, the label is just version and release date."""
    from hermes_cli import banner

    expected = f"Hermes Agent v{banner.VERSION} ({banner.RELEASE_DATE})"
    with patch.object(banner, "get_git_banner_state", return_value=None):
        label = banner.format_banner_version_label()

    assert label == expected
|
||||
|
||||
|
||||
def test_format_banner_version_label_on_upstream_main():
    """When local equals upstream with nothing ahead, only the upstream hash shows."""
    from hermes_cli import banner

    in_sync = {"upstream": "b2f477a3", "local": "b2f477a3", "ahead": 0}
    with patch.object(banner, "get_git_banner_state", return_value=in_sync):
        label = banner.format_banner_version_label()

    assert label.endswith("· upstream b2f477a3")
    assert "local" not in label
|
||||
|
||||
|
||||
def test_format_banner_version_label_with_carried_commits():
    """With commits ahead of upstream, the label shows both hashes and the count."""
    from hermes_cli import banner

    diverged = {"upstream": "b2f477a3", "local": "af8aad31", "ahead": 3}
    with patch.object(banner, "get_git_banner_state", return_value=diverged):
        label = banner.format_banner_version_label()

    for fragment in ("upstream b2f477a3", "local af8aad31", "+3 carried commits"):
        assert fragment in label
|
||||
|
||||
|
||||
def test_get_git_banner_state_reads_origin_and_head(tmp_path):
    """get_git_banner_state combines three git queries into one state dict.

    ``subprocess.run`` is replaced by a stub that answers only the three
    expected commands and fails the test on any other invocation.
    """
    from hermes_cli import banner

    repo_dir = tmp_path / "repo"
    (repo_dir / ".git").mkdir(parents=True)

    # Canned output keyed by the exact argv tuple.
    canned = {
        ("git", "rev-parse", "--short=8", "origin/main"): MagicMock(returncode=0, stdout="b2f477a3\n"),
        ("git", "rev-parse", "--short=8", "HEAD"): MagicMock(returncode=0, stdout="af8aad31\n"),
        ("git", "rev-list", "--count", "origin/main..HEAD"): MagicMock(returncode=0, stdout="3\n"),
    }

    def fake_run(cmd, **kwargs):
        argv = tuple(cmd)
        if argv in canned:
            return canned[argv]
        raise AssertionError(f"unexpected command: {cmd}")

    with patch("hermes_cli.banner.subprocess.run", side_effect=fake_run):
        state = banner.get_git_banner_state(repo_dir)

    assert state == {"upstream": "b2f477a3", "local": "af8aad31", "ahead": 3}
|
||||
@@ -641,69 +641,3 @@ class TestEnsureUserSystemdEnv:
|
||||
result = gateway_cli._systemctl_cmd(system=True)
|
||||
assert result == ["systemctl"]
|
||||
assert calls == []
|
||||
|
||||
|
||||
class TestProfileArg:
    """Tests for _profile_arg — returns '--profile <name>' for named profiles."""

    def test_default_hermes_home_returns_empty(self, tmp_path, monkeypatch):
        """The default ~/.hermes directory yields no --profile flag."""
        default_home = tmp_path / ".hermes"
        default_home.mkdir()
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(default_home)) == ""

    def test_named_profile_returns_flag(self, tmp_path, monkeypatch):
        """~/.hermes/profiles/mybot maps to '--profile mybot'."""
        named = tmp_path / ".hermes" / "profiles" / "mybot"
        named.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(named)) == "--profile mybot"

    def test_hash_path_returns_empty(self, tmp_path, monkeypatch):
        """An arbitrary HERMES_HOME outside the profiles tree yields an empty string."""
        elsewhere = tmp_path / "custom" / "hermes"
        elsewhere.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(elsewhere)) == ""

    def test_nested_profile_path_returns_empty(self, tmp_path, monkeypatch):
        """A directory below a profile (profiles/mybot/subdir) is too deep to match."""
        too_deep = tmp_path / ".hermes" / "profiles" / "mybot" / "subdir"
        too_deep.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(too_deep)) == ""

    def test_invalid_profile_name_returns_empty(self, tmp_path, monkeypatch):
        """A profile directory named with invalid characters is not accepted."""
        malformed = tmp_path / ".hermes" / "profiles" / "My Bot!"
        malformed.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        assert gateway_cli._profile_arg(str(malformed)) == ""

    def test_systemd_unit_includes_profile(self, tmp_path, monkeypatch):
        """The generated systemd unit carries --profile for a named profile."""
        named = tmp_path / ".hermes" / "profiles" / "mybot"
        named.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(named))
        monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: named)

        unit_text = gateway_cli.generate_systemd_unit(system=False)

        assert "--profile mybot" in unit_text
        assert "gateway run --replace" in unit_text

    def test_launchd_plist_includes_profile(self, tmp_path, monkeypatch):
        """The generated launchd plist lists --profile and the profile name as arguments."""
        named = tmp_path / ".hermes" / "profiles" / "mybot"
        named.mkdir(parents=True)
        monkeypatch.setattr(Path, "home", lambda: tmp_path)
        monkeypatch.setenv("HERMES_HOME", str(named))
        monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: named)

        plist_text = gateway_cli.generate_launchd_plist()

        assert "<string>--profile</string>" in plist_text
        assert "<string>mybot</string>" in plist_text
|
||||
|
||||
@@ -72,45 +72,6 @@ def test_get_platform_tools_keeps_enabled_mcp_servers_with_explicit_builtin_sele
|
||||
assert "web-search-prime" in enabled
|
||||
|
||||
|
||||
def test_get_platform_tools_no_mcp_sentinel_excludes_all_mcp_servers():
    """A 'no_mcp' entry in a platform's toolsets drops every MCP server.

    Built-in toolsets remain, and the sentinel itself is not reported as
    an enabled toolset.
    """
    mcp_servers = {
        "exa": {"url": "https://mcp.exa.ai/mcp"},
        "web-search-prime": {"url": "https://api.z.ai/api/mcp/web_search_prime/mcp"},
    }
    config = {
        "platform_toolsets": {"cli": ["web", "terminal", "no_mcp"]},
        "mcp_servers": mcp_servers,
    }

    enabled = _get_platform_tools(config, "cli")

    for builtin in ("web", "terminal"):
        assert builtin in enabled
    for excluded in ("exa", "web-search-prime", "no_mcp"):
        assert excluded not in enabled
|
||||
|
||||
|
||||
def test_get_platform_tools_no_mcp_sentinel_does_not_affect_other_platforms():
    """'no_mcp' applies only to the platform whose toolset list contains it."""
    config = {
        "platform_toolsets": {"api_server": ["web", "terminal", "no_mcp"]},
        "mcp_servers": {"exa": {"url": "https://mcp.exa.ai/mcp"}},
    }

    # api_server carries the sentinel, so the MCP server is excluded there.
    assert "exa" not in _get_platform_tools(config, "api_server")

    # cli has no toolset entry (and no sentinel), so the MCP server stays.
    assert "exa" in _get_platform_tools(config, "cli")
|
||||
|
||||
|
||||
def test_toolset_has_keys_for_vision_accepts_codex_auth(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "auth.json").write_text(
|
||||
|
||||
@@ -13,11 +13,10 @@ from plugins.memory.supermemory import (
|
||||
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, api_key: str, timeout: float, container_tag: str, search_mode: str = "hybrid"):
|
||||
def __init__(self, api_key: str, timeout: float, container_tag: str):
|
||||
self.api_key = api_key
|
||||
self.timeout = timeout
|
||||
self.container_tag = container_tag
|
||||
self.search_mode = search_mode
|
||||
self.add_calls = []
|
||||
self.search_results = []
|
||||
self.profile_response = {"static": [], "dynamic": [], "search_results": []}
|
||||
@@ -25,27 +24,24 @@ class FakeClient:
|
||||
self.forgotten_ids = []
|
||||
self.forget_by_query_response = {"success": True, "message": "Forgot"}
|
||||
|
||||
def add_memory(self, content, metadata=None, *, entity_context="",
|
||||
container_tag=None, custom_id=None):
|
||||
def add_memory(self, content, metadata=None, *, entity_context=""):
|
||||
self.add_calls.append({
|
||||
"content": content,
|
||||
"metadata": metadata,
|
||||
"entity_context": entity_context,
|
||||
"container_tag": container_tag,
|
||||
"custom_id": custom_id,
|
||||
})
|
||||
return {"id": "mem_123"}
|
||||
|
||||
def search_memories(self, query, *, limit=5, container_tag=None, search_mode=None):
|
||||
def search_memories(self, query, *, limit=5):
|
||||
return self.search_results
|
||||
|
||||
def get_profile(self, query=None, *, container_tag=None):
|
||||
def get_profile(self, query=None):
|
||||
return self.profile_response
|
||||
|
||||
def forget_memory(self, memory_id, *, container_tag=None):
|
||||
def forget_memory(self, memory_id):
|
||||
self.forgotten_ids.append(memory_id)
|
||||
|
||||
def forget_by_query(self, query, *, container_tag=None):
|
||||
def forget_by_query(self, query):
|
||||
return self.forget_by_query_response
|
||||
|
||||
def ingest_conversation(self, session_id, messages):
|
||||
@@ -86,8 +82,7 @@ def test_is_available_false_when_import_missing(monkeypatch):
|
||||
def test_load_and_save_config_round_trip(tmp_path):
|
||||
_save_supermemory_config({"container_tag": "demo-tag", "auto_capture": False}, str(tmp_path))
|
||||
cfg = _load_supermemory_config(str(tmp_path))
|
||||
# container_tag is kept raw — sanitization happens in initialize() after template resolution
|
||||
assert cfg["container_tag"] == "demo-tag"
|
||||
assert cfg["container_tag"] == "demo_tag"
|
||||
assert cfg["auto_capture"] is False
|
||||
assert cfg["auto_recall"] is True
|
||||
|
||||
@@ -181,8 +176,7 @@ def test_shutdown_joins_and_clears_threads(provider, monkeypatch):
|
||||
started = threading.Event()
|
||||
release = threading.Event()
|
||||
|
||||
def slow_add_memory(content, metadata=None, *, entity_context="",
|
||||
container_tag=None, custom_id=None):
|
||||
def slow_add_memory(content, metadata=None, *, entity_context=""):
|
||||
started.set()
|
||||
release.wait(timeout=1)
|
||||
provider._client.add_calls.append({
|
||||
@@ -261,151 +255,3 @@ def test_handle_tool_call_returns_error_when_unconfigured(monkeypatch):
|
||||
p = SupermemoryMemoryProvider()
|
||||
result = json.loads(p.handle_tool_call("supermemory_search", {"query": "x"}))
|
||||
assert "error" in result
|
||||
|
||||
|
||||
# -- Identity template tests --------------------------------------------------
|
||||
|
||||
|
||||
def test_identity_template_resolved_in_container_tag(monkeypatch, tmp_path):
    """"hermes-{identity}" with agent_identity="coder" yields the tag "hermes_coder"."""
    monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
    monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
    home = str(tmp_path)
    _save_supermemory_config({"container_tag": "hermes-{identity}"}, home)

    provider = SupermemoryMemoryProvider()
    provider.initialize("s1", hermes_home=home, platform="cli", agent_identity="coder")

    assert provider._container_tag == "hermes_coder"
|
||||
|
||||
|
||||
def test_identity_template_default_profile(monkeypatch, tmp_path):
    """When no agent_identity is passed, {identity} is filled with 'default'."""
    monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
    monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
    home = str(tmp_path)
    _save_supermemory_config({"container_tag": "hermes-{identity}"}, home)

    provider = SupermemoryMemoryProvider()
    provider.initialize("s1", hermes_home=home, platform="cli")

    assert provider._container_tag == "hermes_default"
|
||||
|
||||
|
||||
def test_container_tag_env_var_override(monkeypatch, tmp_path):
    """SUPERMEMORY_CONTAINER_TAG="env-override" produces the tag "env_override"."""
    for name, value in (
        ("SUPERMEMORY_API_KEY", "test-key"),
        ("SUPERMEMORY_CONTAINER_TAG", "env-override"),
    ):
        monkeypatch.setenv(name, value)
    monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)

    provider = SupermemoryMemoryProvider()
    provider.initialize("s1", hermes_home=str(tmp_path), platform="cli")

    assert provider._container_tag == "env_override"
|
||||
|
||||
|
||||
# -- Search mode tests --------------------------------------------------------
|
||||
|
||||
|
||||
def test_search_mode_config_passed_to_client(monkeypatch, tmp_path):
    """A configured search_mode is stored on the provider and handed to the client."""
    monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
    monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
    home = str(tmp_path)
    _save_supermemory_config({"search_mode": "memories"}, home)

    provider = SupermemoryMemoryProvider()
    provider.initialize("s1", hermes_home=home, platform="cli")

    assert provider._search_mode == "memories"
    assert provider._client.search_mode == "memories"
|
||||
|
||||
|
||||
def test_invalid_search_mode_falls_back_to_default(monkeypatch, tmp_path):
    """An unrecognised search_mode value is replaced by 'hybrid'."""
    monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
    monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
    home = str(tmp_path)
    _save_supermemory_config({"search_mode": "invalid_mode"}, home)

    provider = SupermemoryMemoryProvider()
    provider.initialize("s1", hermes_home=home, platform="cli")

    assert provider._search_mode == "hybrid"
|
||||
|
||||
|
||||
# -- Multi-container tests ----------------------------------------------------
|
||||
|
||||
|
||||
def test_multi_container_disabled_by_default(provider):
    """By default custom containers are off and no tool schema exposes container_tag."""
    assert provider._enable_custom_containers is False

    for schema in provider.get_tool_schemas():
        properties = schema["parameters"]["properties"]
        assert "container_tag" not in properties
|
||||
|
||||
|
||||
def test_multi_container_enabled_adds_schema_param(monkeypatch, tmp_path):
    """Enabling custom containers adds a container_tag parameter to every tool schema.

    The allowed-container list is expected to be
    ["hermes", "project_alpha", "shared"] for the configured
    ["project-alpha", "shared"].
    """
    monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
    monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
    home = str(tmp_path)
    settings = {
        "enable_custom_container_tags": True,
        "custom_containers": ["project-alpha", "shared"],
    }
    _save_supermemory_config(settings, home)

    provider = SupermemoryMemoryProvider()
    provider.initialize("s1", hermes_home=home, platform="cli")

    assert provider._enable_custom_containers is True
    assert provider._allowed_containers == ["hermes", "project_alpha", "shared"]
    for schema in provider.get_tool_schemas():
        properties = schema["parameters"]["properties"]
        assert "container_tag" in properties
|
||||
|
||||
|
||||
def test_multi_container_tool_store_with_custom_tag(monkeypatch, tmp_path):
    """``supermemory_store`` reports and forwards the resolved tag for a listed container."""
    home = str(tmp_path)
    config = {
        "enable_custom_container_tags": True,
        "custom_containers": ["project-alpha"],
    }
    monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
    monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
    _save_supermemory_config(config, home)

    provider = SupermemoryMemoryProvider()
    provider.initialize("s1", hermes_home=home, platform="cli")

    tool_args = {
        "content": "test memory",
        "container_tag": "project-alpha",
    }
    reply = json.loads(provider.handle_tool_call("supermemory_store", tool_args))

    assert reply["saved"] is True
    # The hyphenated request is expected back in its underscore form ...
    assert reply["container_tag"] == "project_alpha"
    # ... and the same form must be what the most recent client add call received.
    latest_add = provider._client.add_calls[-1]
    assert latest_add["container_tag"] == "project_alpha"
|
||||
|
||||
|
||||
def test_multi_container_rejects_unlisted_tag(monkeypatch, tmp_path):
    """A ``container_tag`` outside the configured list yields a "not allowed" error payload."""
    home = str(tmp_path)
    config = {
        "enable_custom_container_tags": True,
        "custom_containers": ["allowed-tag"],
    }
    monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
    monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
    _save_supermemory_config(config, home)

    provider = SupermemoryMemoryProvider()
    provider.initialize("s1", hermes_home=home, platform="cli")

    tool_args = {
        "content": "test",
        "container_tag": "forbidden-tag",
    }
    reply = json.loads(provider.handle_tool_call("supermemory_store", tool_args))

    assert "error" in reply
    assert "not allowed" in reply["error"]
|
||||
|
||||
|
||||
def test_multi_container_system_prompt_includes_instructions(monkeypatch, tmp_path):
    """In multi-container mode the system prompt block names the containers and the custom instructions."""
    home = str(tmp_path)
    config = {
        "enable_custom_container_tags": True,
        "custom_containers": ["docs"],
        "custom_container_instructions": "Use docs for documentation context.",
    }
    monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
    monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
    _save_supermemory_config(config, home)

    provider = SupermemoryMemoryProvider()
    provider.initialize("s1", hermes_home=home, platform="cli")
    prompt_block = provider.system_prompt_block()

    assert "Multi-container mode enabled" in prompt_block
    assert "docs" in prompt_block
    assert "Use docs for documentation context." in prompt_block
|
||||
|
||||
|
||||
def test_get_config_schema_minimal():
    """``get_config_schema`` returns exactly one entry: the secret ``api_key`` field."""
    schema = SupermemoryMemoryProvider().get_config_schema()

    assert len(schema) == 1
    only_field = schema[0]
    assert only_field["key"] == "api_key"
    assert only_field["secret"] is True
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Tests for _save_oversized_tool_result() — the large tool response handler.
|
||||
|
||||
When a tool returns more than _LARGE_RESULT_CHARS characters, the full content
|
||||
is saved to a file and the model receives a preview + file path instead.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import (
|
||||
_save_oversized_tool_result,
|
||||
_LARGE_RESULT_CHARS,
|
||||
_LARGE_RESULT_PREVIEW_CHARS,
|
||||
)
|
||||
|
||||
|
||||
class TestSaveOversizedToolResult:
    """Unit tests for the large tool result handler.

    ``_make_hermes_home`` and ``_saved_path`` are private helpers (their names
    do not start with ``test``) that hold the setup and the path extraction
    shared by several tests.
    """

    @staticmethod
    def _make_hermes_home(tmp_path, monkeypatch):
        """Create ``tmp_path/.hermes``, export it as HERMES_HOME and return it as ``str``."""
        hermes_home = str(tmp_path / ".hermes")
        monkeypatch.setenv("HERMES_HOME", hermes_home)
        os.makedirs(hermes_home, exist_ok=True)
        return hermes_home

    @staticmethod
    def _saved_path(result):
        """Return the path that follows "Full output saved to:" in *result*.

        Fails with an assertion message (instead of ``AttributeError`` on
        ``None``) when the marker line is missing from the replacement text.
        """
        match = re.search(r"Full output saved to: (.+?)\n", result)
        assert match, f"No file path found in result: {result[:300]}"
        return match.group(1)

    def test_small_result_returned_unchanged(self):
        """Results under the threshold pass through untouched."""
        small = "x" * 1000
        assert _save_oversized_tool_result("terminal", small) is small

    def test_exactly_at_threshold_returned_unchanged(self):
        """Results exactly at the threshold pass through."""
        exact = "y" * _LARGE_RESULT_CHARS
        assert _save_oversized_tool_result("terminal", exact) is exact

    def test_oversized_result_saved_to_file(self, tmp_path, monkeypatch):
        """Results over the threshold are written to a file."""
        self._make_hermes_home(tmp_path, monkeypatch)

        big = "A" * (_LARGE_RESULT_CHARS + 500)
        result = _save_oversized_tool_result("terminal", big)

        # Should contain the preview
        assert result.startswith("A" * _LARGE_RESULT_PREVIEW_CHARS)
        # Should mention the file path
        assert "Full output saved to:" in result
        # Should mention original size
        assert f"{len(big):,}" in result

        # Extract the file path and verify the file exists with full content
        filepath = self._saved_path(result)
        assert os.path.isfile(filepath)
        with open(filepath, "r", encoding="utf-8") as f:
            saved = f.read()
        assert saved == big
        assert len(saved) == _LARGE_RESULT_CHARS + 500

    def test_file_placed_in_cache_tool_responses(self, tmp_path, monkeypatch):
        """Saved file lives under HERMES_HOME/cache/tool_responses/."""
        hermes_home = self._make_hermes_home(tmp_path, monkeypatch)

        big = "B" * (_LARGE_RESULT_CHARS + 1)
        result = _save_oversized_tool_result("web_search", big)

        filepath = self._saved_path(result)
        expected_dir = os.path.join(hermes_home, "cache", "tool_responses")
        assert filepath.startswith(expected_dir)

    def test_filename_contains_tool_name(self, tmp_path, monkeypatch):
        """The saved filename includes a sanitized version of the tool name."""
        self._make_hermes_home(tmp_path, monkeypatch)

        big = "C" * (_LARGE_RESULT_CHARS + 1)
        result = _save_oversized_tool_result("browser_navigate", big)

        filename = os.path.basename(self._saved_path(result))
        assert filename.startswith("browser_navigate_")
        assert filename.endswith(".txt")

    def test_tool_name_sanitized(self, tmp_path, monkeypatch):
        """Special characters in tool names are replaced in the filename."""
        self._make_hermes_home(tmp_path, monkeypatch)

        big = "D" * (_LARGE_RESULT_CHARS + 1)
        result = _save_oversized_tool_result("mcp:some/weird tool", big)

        filename = os.path.basename(self._saved_path(result))
        # No slashes or colons in filename
        assert "/" not in filename
        assert ":" not in filename

    def test_fallback_on_write_failure(self, tmp_path, monkeypatch):
        """When file write fails, falls back to destructive truncation."""
        # Point HERMES_HOME to a path that will fail (file, not directory)
        bad_path = str(tmp_path / "not_a_dir.txt")
        with open(bad_path, "w", encoding="utf-8") as f:
            f.write("I'm a file, not a directory")
        monkeypatch.setenv("HERMES_HOME", bad_path)

        big = "E" * (_LARGE_RESULT_CHARS + 50_000)
        result = _save_oversized_tool_result("terminal", big)

        # Should still contain data (fallback truncation)
        assert len(result) > 0
        assert result.startswith("E" * 1000)
        # Should mention the failure
        assert "File save failed" in result
        # Should be shorter than the input, i.e. actually truncated
        assert len(result) < len(big)

    def test_preview_length_capped(self, tmp_path, monkeypatch):
        """The inline preview is capped at _LARGE_RESULT_PREVIEW_CHARS."""
        self._make_hermes_home(tmp_path, monkeypatch)

        # One repeated non-whitespace char: the rstrip() below can then only
        # remove whatever separates the preview from the marker, never preview
        # content itself.
        big = "Z" * (_LARGE_RESULT_CHARS + 5000)
        result = _save_oversized_tool_result("terminal", big)

        # The preview section is the content before the "[Large tool response:" marker
        marker_pos = result.index("[Large tool response:")
        preview_section = result[:marker_pos].rstrip()
        assert len(preview_section) == _LARGE_RESULT_PREVIEW_CHARS

    def test_guidance_message_mentions_tools(self, tmp_path, monkeypatch):
        """The replacement message tells the model how to access the file."""
        self._make_hermes_home(tmp_path, monkeypatch)

        big = "F" * (_LARGE_RESULT_CHARS + 1)
        result = _save_oversized_tool_result("terminal", big)

        assert "read_file" in result
        assert "search_files" in result

    def test_empty_result_passes_through(self):
        """Empty strings are not oversized."""
        assert _save_oversized_tool_result("terminal", "") == ""

    def test_unicode_content_preserved(self, tmp_path, monkeypatch):
        """Unicode content is fully preserved in the saved file."""
        self._make_hermes_home(tmp_path, monkeypatch)

        # Mix of ASCII and multi-byte unicode to exceed threshold
        unit = "Hello 世界! 🎉 " * 100  # 12 chars x 100 = 1,200 chars per unit
        big = unit * ((_LARGE_RESULT_CHARS // len(unit)) + 1)
        assert len(big) > _LARGE_RESULT_CHARS

        result = _save_oversized_tool_result("terminal", big)
        filepath = self._saved_path(result)

        with open(filepath, "r", encoding="utf-8") as f:
            saved = f.read()
        assert saved == big
|
||||
@@ -1011,9 +1011,10 @@ class TestExecuteToolCalls:
|
||||
big_result = "x" * 150_000
|
||||
with patch("run_agent.handle_function_call", return_value=big_result):
|
||||
agent._execute_tool_calls(mock_msg, messages, "task-1")
|
||||
# Content should be replaced with persisted-output or truncation
|
||||
# Content should be replaced with preview + file path
|
||||
assert len(messages[0]["content"]) < 150_000
|
||||
assert ("Truncated" in messages[0]["content"] or "<persisted-output>" in messages[0]["content"])
|
||||
assert "Large tool response" in messages[0]["content"]
|
||||
assert "Full output saved to:" in messages[0]["content"]
|
||||
|
||||
|
||||
class TestConcurrentToolExecution:
|
||||
@@ -1248,7 +1249,8 @@ class TestConcurrentToolExecution:
|
||||
assert len(messages) == 2
|
||||
for m in messages:
|
||||
assert len(m["content"]) < 150_000
|
||||
assert ("Truncated" in m["content"] or "<persisted-output>" in m["content"])
|
||||
assert "Large tool response" in m["content"]
|
||||
assert "Full output saved to:" in m["content"]
|
||||
|
||||
def test_invoke_tool_dispatches_to_handle_function_call(self, agent):
|
||||
"""_invoke_tool should route regular tools through handle_function_call."""
|
||||
|
||||
@@ -386,56 +386,6 @@ def test_run_conversation_codex_plain_text(monkeypatch):
|
||||
assert result["messages"][-1]["content"] == "OK"
|
||||
|
||||
|
||||
def test_run_conversation_codex_empty_output_with_output_text(monkeypatch):
    """Regression: an empty ``output`` list with a valid ``output_text`` must succeed.

    The response should be accepted rather than sent down the retry/fallback
    path; validation is expected to defer to _normalize_codex_response, which
    synthesizes output from output_text.
    """
    agent = _build_agent(monkeypatch)

    def _respond_with_text_only(api_kwargs):
        # A fresh response object per call, matching what a real API call would hand back.
        token_usage = SimpleNamespace(input_tokens=5, output_tokens=3, total_tokens=8)
        return SimpleNamespace(
            output=[],
            output_text="Hello from Codex",
            usage=token_usage,
            status="completed",
            model="gpt-5-codex",
        )

    monkeypatch.setattr(agent, "_interruptible_api_call", _respond_with_text_only)

    outcome = agent.run_conversation("Say hello")

    assert outcome["completed"] is True
    assert outcome["final_response"] == "Hello from Codex"
|
||||
|
||||
|
||||
def test_run_conversation_codex_empty_output_no_output_text_retries(monkeypatch):
    """An empty ``output`` with no ``output_text`` is invalid and must trigger a retry."""
    agent = _build_agent(monkeypatch)
    api_call_count = 0

    def _empty_then_recovered(api_kwargs):
        nonlocal api_call_count
        api_call_count += 1
        # Every call after the first returns a normal message response.
        if api_call_count != 1:
            return _codex_message_response("Recovered")
        token_usage = SimpleNamespace(input_tokens=5, output_tokens=3, total_tokens=8)
        return SimpleNamespace(
            output=[],
            output_text=None,
            usage=token_usage,
            status="completed",
            model="gpt-5-codex",
        )

    monkeypatch.setattr(agent, "_interruptible_api_call", _empty_then_recovered)

    outcome = agent.run_conversation("Say hello")

    assert api_call_count >= 2
    assert outcome["completed"] is True
    assert outcome["final_response"] == "Recovered"
|
||||
|
||||
|
||||
def test_run_conversation_codex_refreshes_after_401_and_retries(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
calls = {"api": 0, "refresh": 0}
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from cli import HermesCLI, _build_compact_banner, _rich_text_from_ansi
|
||||
from hermes_cli.skin_engine import get_active_skin, set_active_skin
|
||||
|
||||
|
||||
def _make_cli_stub():
    """Build a ``HermesCLI`` without running ``__init__`` and seed the attributes these tests rely on.

    Returns:
        The bare instance with its prompt-state attributes cleared, activity
        flags off, a fixed spinner frame, a small base style map, a stand-in
        ``_app`` whose ``style`` starts as ``None``, and a mocked ``_invalidate``.
    """
    stub = HermesCLI.__new__(HermesCLI)

    # Modal prompt states: none active.
    for state_attr in ("_sudo_state", "_secret_state", "_approval_state", "_clarify_state"):
        setattr(stub, state_attr, None)

    # Activity / mode flags: all off.
    for flag_attr in (
        "_clarify_freetext",
        "_command_running",
        "_agent_running",
        "_voice_recording",
        "_voice_processing",
        "_voice_mode",
    ):
        setattr(stub, flag_attr, False)

    stub._command_spinner_frame = lambda: "⟳"
    stub._tui_style_base = {
        "prompt": "#fff",
        "input-area": "#fff",
        "input-rule": "#aaa",
        "prompt-working": "#888 italic",
    }
    stub._app = SimpleNamespace(style=None)
    stub._invalidate = MagicMock()
    return stub
|
||||
|
||||
|
||||
class TestCliSkinPromptIntegration:
    """Prompt fragments and TUI style of a stubbed CLI under different active skins."""

    def test_default_prompt_fragments_use_default_symbol(self):
        """The default skin yields the plain prompt symbol."""
        stub = _make_cli_stub()
        set_active_skin("default")

        expected = [("class:prompt", "❯ ")]
        assert stub._get_tui_prompt_fragments() == expected

    def test_ares_prompt_fragments_use_skin_symbol(self):
        """The ares skin yields its own prompt symbol."""
        stub = _make_cli_stub()
        set_active_skin("ares")

        expected = [("class:prompt", "⚔ ❯ ")]
        assert stub._get_tui_prompt_fragments() == expected

    def test_secret_prompt_fragments_preserve_secret_state(self):
        """While a secret prompt is pending, the key prompt is shown even under ares."""
        stub = _make_cli_stub()
        stub._secret_state = {"response_queue": object()}
        set_active_skin("ares")

        expected = [("class:sudo-prompt", "🔑 ❯ ")]
        assert stub._get_tui_prompt_fragments() == expected

    def test_icon_only_skin_symbol_still_visible_in_special_states(self):
        """An icon-only skin symbol is kept after the key icon in the secret state."""
        stub = _make_cli_stub()
        stub._secret_state = {"response_queue": object()}

        symbol_patch = patch(
            "hermes_cli.skin_engine.get_active_prompt_symbol", return_value="⚔ "
        )
        with symbol_patch:
            fragments = stub._get_tui_prompt_fragments()
            assert fragments == [("class:sudo-prompt", "🔑 ⚔ ")]

    def test_build_tui_style_dict_uses_skin_overrides(self):
        """Style entries are derived from the active skin's colors."""
        stub = _make_cli_stub()
        set_active_skin("ares")
        active = get_active_skin()

        styles = stub._build_tui_style_dict()

        assert styles["prompt"] == active.get_color("prompt")
        assert styles["input-rule"] == active.get_color("input_rule")
        assert styles["prompt-working"] == f"{active.get_color('banner_dim')} italic"
        assert styles["approval-title"] == f"{active.get_color('ui_warn')} bold"

    def test_apply_tui_skin_style_updates_running_app(self):
        """Applying the skin sets a style on the app and requests one immediate redraw."""
        stub = _make_cli_stub()
        set_active_skin("ares")

        applied = stub._apply_tui_skin_style()

        assert applied is True
        assert stub._app.style is not None
        stub._invalidate.assert_called_once_with(min_interval=0.0)

    def test_handle_skin_command_refreshes_live_tui(self, capsys):
        """``/skin ares`` reports the save and the refresh, and styles the app."""
        stub = _make_cli_stub()

        with patch("cli.save_config_value", return_value=True):
            stub._handle_skin_command("/skin ares")

        printed = capsys.readouterr().out
        assert "Skin set to: ares (saved)" in printed
        assert "Prompt + TUI colors updated." in printed
        assert stub._app.style is not None
|
||||
|
||||
|
||||
class TestCompactBannerSkinIntegration:
    """Compact banner branding, colors and version label under different skins."""

    @staticmethod
    def _render_banner(version_label="Hermes Agent v0.1.0 (test)"):
        """Build the compact banner with a 90-column terminal and a fixed version label."""
        size_patch = patch(
            "cli.shutil.get_terminal_size", return_value=SimpleNamespace(columns=90)
        )
        label_patch = patch(
            "cli.format_banner_version_label", return_value=version_label
        )
        with size_patch, label_patch:
            return _build_compact_banner()

    def test_default_compact_banner_keeps_legacy_nous_hermes_branding(self):
        """The default skin keeps the "NOUS HERMES" branding."""
        set_active_skin("default")

        rendered = self._render_banner()

        assert "NOUS HERMES" in rendered

    def test_poseidon_compact_banner_uses_skin_branding_instead_of_nous_hermes(self):
        """The poseidon skin replaces the legacy branding with its own agent name."""
        set_active_skin("poseidon")

        rendered = self._render_banner()

        assert "Poseidon Agent" in rendered
        assert "NOUS HERMES" not in rendered

    def test_poseidon_compact_banner_uses_skin_colors(self):
        """Border, title and dim colors of the active skin appear in the banner."""
        set_active_skin("poseidon")
        active = get_active_skin()

        rendered = self._render_banner()

        for color_key in ("banner_border", "banner_title", "banner_dim"):
            assert active.get_color(color_key) in rendered

    def test_compact_banner_shows_version_label(self):
        """Text from the version label is carried into the banner."""
        set_active_skin("default")

        rendered = self._render_banner(
            "Hermes Agent v1.0 (test) · upstream abc12345"
        )

        assert "upstream abc12345" in rendered
|
||||
|
||||
|
||||
class TestAnsiRichTextHelper:
    """Checks on the ``plain`` text produced by ``_rich_text_from_ansi``."""

    def test_preserves_literal_brackets(self):
        """Square-bracketed words are kept verbatim in the plain text."""
        source = "[notatag] literal"
        assert _rich_text_from_ansi(source).plain == "[notatag] literal"

    def test_strips_ansi_but_keeps_plain_text(self):
        """ANSI color escapes are dropped; the wrapped word remains."""
        colored = "\x1b[31mred\x1b[0m"
        assert _rich_text_from_ansi(colored).plain == "red"
|
||||
@@ -1,118 +0,0 @@
|
||||
"""Tests for the scrolling viewport logic in _curses_prompt_choice (issue #5755).
|
||||
|
||||
The "More providers" submenu has 13 entries (11 extended + custom + cancel).
|
||||
Before the fix, _curses_prompt_choice rendered items starting unconditionally
|
||||
from index 0 with no scroll offset. On terminals shorter than ~16 rows, items
|
||||
near the bottom were never drawn. When the cursor wrapped from 0 to the last
|
||||
item (Cancel) via UP-arrow, the highlight rendered off-screen, leaving the menu
|
||||
looking like only "Cancel" existed.
|
||||
|
||||
The fix adds a scroll_offset that tracks the cursor so the highlighted item
|
||||
is always within the visible window. These tests exercise that logic in
|
||||
isolation without requiring a real TTY.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure scroll-offset logic extracted from _curses_menu for unit testing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _compute_scroll_offset(cursor: int, scroll_offset: int, visible: int, n_choices: int) -> int:
|
||||
"""Mirror of the scroll adjustment block inside _curses_menu."""
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible:
|
||||
scroll_offset = cursor - visible + 1
|
||||
scroll_offset = max(0, min(scroll_offset, max(0, n_choices - visible)))
|
||||
return scroll_offset
|
||||
|
||||
|
||||
def _visible_indices(cursor: int, scroll_offset: int, visible: int, n_choices: int):
    """Return the list indices that would be rendered for the given state.

    The offset is first adjusted with ``_compute_scroll_offset``; the result is
    the run of at most ``visible`` indices from there, cut off at ``n_choices``.
    """
    first = _compute_scroll_offset(cursor, scroll_offset, visible, n_choices)
    stop = min(first + visible, n_choices)
    return list(range(first, stop))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: scroll offset calculation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestScrollOffsetLogic:
    """Scroll-offset behaviour for a 13-entry menu across window sizes and cursor moves."""

    N = 13  # typical extended-providers list length

    def test_cursor_at_zero_no_scroll(self):
        """Start position: offset stays 0, first items visible."""
        assert _compute_scroll_offset(0, 0, 8, self.N) == 0

    def test_cursor_within_window_unchanged(self):
        """Cursor inside the current window: offset unchanged."""
        assert _compute_scroll_offset(5, 0, 8, self.N) == 0

    def test_cursor_at_last_item_scrolls_down(self):
        """Cursor on Cancel (index 12) with 8-row window: offset = 12 - 8 + 1 = 5."""
        last = 12
        assert _compute_scroll_offset(last, 0, 8, self.N) == 5
        assert last in _visible_indices(last, 0, 8, self.N)

    def test_cursor_wraps_to_cancel_via_up(self):
        """UP from index 0 wraps to last item; last item must be visible."""
        wrapped = (0 - 1) % self.N  # == 12
        shown = _visible_indices(wrapped, 0, 8, self.N)
        assert wrapped in shown

    def test_cursor_above_window_scrolls_up(self):
        """Cursor above current window: offset tracks cursor."""
        # Window currently starts at 5 (showing 5..12); cursor jumps to 3.
        assert _compute_scroll_offset(3, 5, 8, self.N) == 3
        assert 3 in _visible_indices(3, 5, 8, self.N)

    def test_visible_window_never_exceeds_list(self):
        """Offset is clamped so the window never starts past the list end."""
        # 20 rows is more than the 13 choices.
        assert _compute_scroll_offset(12, 0, 20, self.N) == 0

    def test_single_item_list(self):
        """Edge case: one choice, cursor 0."""
        assert _compute_scroll_offset(0, 0, 8, 1) == 0

    def test_list_fits_in_window_no_scroll_needed(self):
        """If all choices fit in the visible window, offset is always 0."""
        for cursor in range(self.N):
            result = _compute_scroll_offset(cursor, 0, 20, self.N)
            assert result == 0, f"cursor={cursor} should not scroll when window > list"

    def test_cursor_always_in_visible_range(self):
        """Invariant: cursor is always within the rendered window after adjustment."""
        rows = 5
        for cursor in range(self.N):
            shown = _visible_indices(cursor, 0, rows, self.N)
            assert cursor in shown, f"cursor={cursor} not in visible={shown}"

    def test_full_navigation_down_cursor_always_visible(self):
        """Simulate pressing DOWN through all items; cursor always in view."""
        rows = 6
        offset = 0
        cursor = 0
        # N + 2 steps: one full pass over the list plus two steps past the wrap.
        for _ in range(self.N + 2):
            offset = _compute_scroll_offset(cursor, offset, rows, self.N)
            shown = list(range(offset, min(offset + rows, self.N)))
            assert cursor in shown, f"cursor={cursor} not in rendered={shown}"
            cursor = (cursor + 1) % self.N

    def test_full_navigation_up_cursor_always_visible(self):
        """Simulate pressing UP through all items; cursor always in view."""
        rows = 6
        offset = 0
        cursor = 0
        # The first UP wraps from index 0 to the last item.
        for _ in range(self.N + 2):
            offset = _compute_scroll_offset(cursor, offset, rows, self.N)
            shown = list(range(offset, min(offset + rows, self.N)))
            assert cursor in shown, f"cursor={cursor} not in rendered={shown}"
            cursor = (cursor - 1) % self.N
|
||||
@@ -1,111 +0,0 @@
|
||||
"""Tests for MCP tool structuredContent preservation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools import mcp_tool
|
||||
|
||||
|
||||
class _FakeContentBlock:
|
||||
"""Minimal content block with .text and .type attributes."""
|
||||
|
||||
def __init__(self, text: str, block_type: str = "text"):
|
||||
self.text = text
|
||||
self.type = block_type
|
||||
|
||||
|
||||
class _FakeCallToolResult:
|
||||
"""Minimal CallToolResult stand-in.
|
||||
|
||||
Uses camelCase ``structuredContent`` / ``isError`` to match the real
|
||||
MCP SDK Pydantic model (``mcp.types.CallToolResult``).
|
||||
"""
|
||||
|
||||
def __init__(self, content, is_error=False, structuredContent=None):
|
||||
self.content = content
|
||||
self.isError = is_error
|
||||
self.structuredContent = structuredContent
|
||||
|
||||
|
||||
def _fake_run_on_mcp_loop(coro, timeout=30):
|
||||
"""Run an MCP coroutine directly in a fresh event loop."""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
def _patch_mcp_server():
    """Register a fake "test-server" and reroute the MCP loop runner; yield the fake session.

    While the fixture is active, ``mcp_tool._servers`` contains a
    ``test-server`` entry wrapping a ``MagicMock`` session, and
    ``tools.mcp_tool._run_on_mcp_loop`` delegates to ``_fake_run_on_mcp_loop``.
    """
    session = MagicMock()
    server_entry = SimpleNamespace(session=session)

    servers_patch = patch.dict(mcp_tool._servers, {"test-server": server_entry})
    loop_patch = patch(
        "tools.mcp_tool._run_on_mcp_loop", side_effect=_fake_run_on_mcp_loop
    )
    with servers_patch, loop_patch:
        yield session
|
||||
|
||||
|
||||
class TestStructuredContentPreservation:
    """Ensure structuredContent from CallToolResult is forwarded."""

    @staticmethod
    def _invoke(session, tool_result):
        """Make *session* return *tool_result*, run the handler with no args, decode its JSON."""
        session.call_tool = AsyncMock(return_value=tool_result)
        handler = mcp_tool._make_tool_handler("test-server", "my-tool", 30.0)
        return json.loads(handler({}))

    def test_text_only_result(self, _patch_mcp_server):
        """When no structuredContent, result is text-only (existing behaviour)."""
        tool_result = _FakeCallToolResult(content=[_FakeContentBlock("hello")])

        decoded = self._invoke(_patch_mcp_server, tool_result)

        assert decoded == {"result": "hello"}

    def test_structured_content_is_the_result(self, _patch_mcp_server):
        """When structuredContent is present, it becomes the result directly."""
        payload = {"value": "secret-123", "revealed": True}
        tool_result = _FakeCallToolResult(
            content=[_FakeContentBlock("OK")],
            structuredContent=payload,
        )

        decoded = self._invoke(_patch_mcp_server, tool_result)

        assert decoded["result"] == payload

    def test_structured_content_none_falls_back_to_text(self, _patch_mcp_server):
        """When structuredContent is explicitly None, fall back to text."""
        tool_result = _FakeCallToolResult(
            content=[_FakeContentBlock("done")],
            structuredContent=None,
        )

        decoded = self._invoke(_patch_mcp_server, tool_result)

        assert decoded == {"result": "done"}

    def test_empty_text_with_structured_content(self, _patch_mcp_server):
        """When content blocks are empty but structuredContent exists."""
        payload = {"status": "ok", "data": [1, 2, 3]}
        tool_result = _FakeCallToolResult(content=[], structuredContent=payload)

        decoded = self._invoke(_patch_mcp_server, tool_result)

        assert decoded["result"] == payload
|
||||
@@ -1,472 +0,0 @@
|
||||
"""Tests for tools/tool_result_storage.py -- 3-layer tool result persistence."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tools.budget_config import (
|
||||
DEFAULT_RESULT_SIZE_CHARS,
|
||||
DEFAULT_TURN_BUDGET_CHARS,
|
||||
DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
BudgetConfig,
|
||||
)
|
||||
from tools.tool_result_storage import (
|
||||
HEREDOC_MARKER,
|
||||
PERSISTED_OUTPUT_TAG,
|
||||
PERSISTED_OUTPUT_CLOSING_TAG,
|
||||
STORAGE_DIR,
|
||||
_build_persisted_message,
|
||||
_heredoc_marker,
|
||||
_write_to_sandbox,
|
||||
enforce_turn_budget,
|
||||
generate_preview,
|
||||
maybe_persist_tool_result,
|
||||
)
|
||||
|
||||
|
||||
# ── generate_preview ──────────────────────────────────────────────────
|
||||
|
||||
class TestGeneratePreview:
    """``generate_preview`` returns ``(preview, has_more)``; these tests pin both parts."""

    def test_short_content_unchanged(self):
        """Short text comes back whole, with nothing left over."""
        original = "short result"
        shown, truncated = generate_preview(original)
        assert shown == original
        assert truncated is False

    def test_long_content_truncated(self):
        """Text longer than ``max_chars`` is cut to at most that length."""
        shown, truncated = generate_preview("x" * 5000, max_chars=2000)
        assert len(shown) <= 2000
        assert truncated is True

    def test_truncates_at_newline_boundary(self):
        """A newline past the halfway point becomes the cut position."""
        # 1500 chars + newline + 600 chars (past halfway)
        head = "a" * 1500 + "\n"
        shown, truncated = generate_preview(head + "b" * 600, max_chars=2000)
        assert shown == head
        assert truncated is True

    def test_ignores_early_newline(self):
        """A newline well before the halfway point does not shorten the preview."""
        # Newline at position 100, well before halfway of 2000
        original = "a" * 100 + "\n" + "b" * 3000
        shown, truncated = generate_preview(original, max_chars=2000)
        assert len(shown) == 2000
        assert truncated is True

    def test_empty_content(self):
        """Empty input gives an empty preview and no remainder."""
        shown, truncated = generate_preview("")
        assert shown == ""
        assert truncated is False

    def test_exact_boundary(self):
        """Text of exactly the default preview size is returned whole."""
        original = "x" * DEFAULT_PREVIEW_SIZE_CHARS
        shown, truncated = generate_preview(original)
        assert shown == original
        assert truncated is False
|
||||
|
||||
|
||||
# ── _heredoc_marker ───────────────────────────────────────────────────
|
||||
|
||||
class TestHeredocMarker:
    """Delimiter selection by ``_heredoc_marker``."""

    def test_default_marker_when_no_collision(self):
        """Content without the default marker gets the default marker."""
        chosen = _heredoc_marker("normal content")
        assert chosen == HEREDOC_MARKER

    def test_uuid_marker_on_collision(self):
        """Content containing the default marker gets a different ``HERMES_PERSIST_`` marker."""
        clashing = f"some text with {HEREDOC_MARKER} embedded"
        chosen = _heredoc_marker(clashing)
        assert chosen != HEREDOC_MARKER
        assert chosen.startswith("HERMES_PERSIST_")
        # The replacement must not itself occur in the content.
        assert chosen not in clashing
|
||||
|
||||
|
||||
# ── _write_to_sandbox ─────────────────────────────────────────────────
|
||||
|
||||
class TestWriteToSandbox:
    """``_write_to_sandbox`` against a mocked environment's ``execute``."""

    TARGET = "/tmp/hermes-results/abc.txt"

    @staticmethod
    def _env_returning(returncode, output=""):
        """Return a ``MagicMock`` env whose ``execute`` yields the given result dict."""
        env = MagicMock()
        env.execute.return_value = {"output": output, "returncode": returncode}
        return env

    def test_success(self):
        """Return code 0 gives True; the single command creates the dir and embeds content and marker."""
        env = self._env_returning(0)

        written = _write_to_sandbox("hello world", self.TARGET, env)

        assert written is True
        env.execute.assert_called_once()
        command = env.execute.call_args.args[0]
        for fragment in ("mkdir -p", "hello world", HEREDOC_MARKER):
            assert fragment in command

    def test_failure_returns_false(self):
        """A non-zero return code gives False."""
        env = self._env_returning(1, output="error")

        written = _write_to_sandbox("content", self.TARGET, env)

        assert written is False

    def test_heredoc_collision_uses_uuid_marker(self):
        """Content containing the default marker must not be delimited by it."""
        env = self._env_returning(0)
        clashing = f"text with {HEREDOC_MARKER} inside"

        _write_to_sandbox(clashing, self.TARGET, env)

        command = env.execute.call_args.args[0]
        # The text after "<<" on the first command line is taken to be the
        # delimiter; it must not contain the default marker.
        first_line = command.split("\n")[0]
        assert HEREDOC_MARKER not in first_line.split("<<")[1]

    def test_timeout_passed(self):
        """``execute`` is called with ``timeout=30`` as a keyword argument."""
        env = self._env_returning(0)

        _write_to_sandbox("content", self.TARGET, env)

        assert env.execute.call_args.kwargs["timeout"] == 30
|
||||
|
||||
|
||||
# ── _build_persisted_message ──────────────────────────────────────────
|
||||
|
||||
class TestBuildPersistedMessage:
    """Checks the text produced by _build_persisted_message."""

    def test_structure(self):
        """The message is tag-wrapped and mentions size, path, tool and preview."""
        # The preview deliberately contains no "..." so that the ellipsis
        # assertion below can only be satisfied by the message itself
        # (previously the preview literal ended in "...", which made the
        # has_more check pass unconditionally).
        preview = "first 100 chars"
        msg = _build_persisted_message(
            preview=preview,
            has_more=True,
            original_size=50_000,
            file_path="/tmp/hermes-results/test123.txt",
        )
        assert msg.startswith(PERSISTED_OUTPUT_TAG)
        assert msg.endswith(PERSISTED_OUTPUT_CLOSING_TAG)
        assert "50,000 characters" in msg
        assert "/tmp/hermes-results/test123.txt" in msg
        assert "read_file" in msg
        assert preview in msg
        assert "..." in msg  # has_more indicator

    def test_no_ellipsis_when_complete(self):
        """With has_more=False the line before the closing tag is not '...'."""
        msg = _build_persisted_message(
            preview="complete content",
            has_more=False,
            original_size=16,
            file_path="/tmp/hermes-results/x.txt",
        )
        # Should not have the trailing "..." indicator before closing tag
        lines = msg.strip().split("\n")
        assert lines[-2] != "..."

    def test_large_size_shows_mb(self):
        """A 2,000,000-character original is reported using an MB unit."""
        msg = _build_persisted_message(
            preview="x",
            has_more=True,
            original_size=2_000_000,
            file_path="/tmp/hermes-results/big.txt",
        )
        assert "MB" in msg
|
||||
|
||||
|
||||
# ── maybe_persist_tool_result ─────────────────────────────────────────
|
||||
|
||||
class TestMaybePersistToolResult:
    """Threshold, fallback and preview checks for maybe_persist_tool_result."""

    @staticmethod
    def _sandbox(returncode=0, output=""):
        """Return a mock env whose execute() yields the given result dict."""
        return MagicMock(
            **{"execute.return_value": {"output": output, "returncode": returncode}}
        )

    @staticmethod
    def _persist(content, call_id, env, threshold, tool="terminal"):
        """Call maybe_persist_tool_result with keyword arguments only."""
        return maybe_persist_tool_result(
            content=content,
            tool_name=tool,
            tool_use_id=call_id,
            env=env,
            threshold=threshold,
        )

    def test_below_threshold_returns_unchanged(self):
        payload = "small result"
        assert self._persist(payload, "tc_123", None, 50_000) == payload

    def test_above_threshold_with_env_persists(self):
        sandbox = self._sandbox()
        payload = "x" * 60_000
        out = self._persist(payload, "tc_456", sandbox, 30_000)
        assert PERSISTED_OUTPUT_TAG in out
        assert "tc_456.txt" in out
        assert len(out) < len(payload)
        sandbox.execute.assert_called_once()

    def test_persists_full_content_as_is(self):
        """The content is written verbatim; no JSON field is extracted."""
        import json

        sandbox = self._sandbox()
        terminal_output = "line1\nline2\n" * 5_000
        payload = json.dumps(
            {"output": terminal_output, "exit_code": 0, "error": None}
        )
        out = self._persist(payload, "tc_json", sandbox, 30_000)
        assert PERSISTED_OUTPUT_TAG in out
        # The command sent to the sandbox should carry the full JSON blob.
        command = sandbox.execute.call_args.args[0]
        assert '"exit_code"' in command

    def test_above_threshold_no_env_truncates_inline(self):
        payload = "x" * 60_000
        out = self._persist(payload, "tc_789", None, 30_000)
        assert PERSISTED_OUTPUT_TAG not in out
        assert "Truncated" in out
        assert len(out) < len(payload)

    def test_env_write_failure_falls_back_to_truncation(self):
        sandbox = self._sandbox(returncode=1, output="disk full")
        out = self._persist("x" * 60_000, "tc_fail", sandbox, 30_000)
        assert PERSISTED_OUTPUT_TAG not in out
        assert "Truncated" in out

    def test_env_execute_exception_falls_back(self):
        sandbox = MagicMock()
        sandbox.execute.side_effect = RuntimeError("connection lost")
        out = self._persist("x" * 60_000, "tc_exc", sandbox, 30_000)
        assert "Truncated" in out

    def test_read_file_never_persisted(self):
        """With threshold=inf the read_file result is returned untouched."""
        sandbox = MagicMock()
        payload = "x" * 200_000
        out = self._persist(
            payload, "tc_rf", sandbox, float("inf"), tool="read_file"
        )
        assert out == payload
        sandbox.execute.assert_not_called()

    def test_uses_registry_threshold_when_not_provided(self):
        """threshold=None is exercised with a patched registry in place."""
        sandbox = self._sandbox()
        payload = "x" * 60_000

        fake_registry = MagicMock()
        fake_registry.get_max_result_size.return_value = 30_000

        with patch("tools.registry.registry", fake_registry):
            out = self._persist(payload, "tc_reg", sandbox, None)
        # Should have persisted since 60K > 30K
        assert PERSISTED_OUTPUT_TAG in out or "Truncated" in out

    def test_unicode_content_survives(self):
        sandbox = self._sandbox()
        payload = "日本語テスト " * 10_000  # ~60K chars of unicode
        out = self._persist(payload, "tc_uni", sandbox, 30_000)
        assert PERSISTED_OUTPUT_TAG in out
        # Preview should contain unicode
        assert "日本語テスト" in out

    def test_empty_content_returns_unchanged(self):
        assert self._persist("", "tc_empty", None, 30_000) == ""

    def test_whitespace_only_below_threshold(self):
        payload = " " * 100
        assert self._persist(payload, "tc_ws", None, 30_000) == payload

    def test_file_path_uses_tool_use_id(self):
        sandbox = self._sandbox()
        out = self._persist("x" * 60_000, "unique_id_abc", sandbox, 30_000)
        assert "unique_id_abc.txt" in out

    def test_preview_included_in_persisted_output(self):
        sandbox = self._sandbox()
        # Content opens with a distinctive prefix that must show up in the result.
        payload = "DISTINCTIVE_START_MARKER" + "x" * 60_000
        out = self._persist(payload, "tc_prev", sandbox, 30_000)
        assert "DISTINCTIVE_START_MARKER" in out

    def test_threshold_zero_forces_persist(self):
        sandbox = self._sandbox()
        out = self._persist("even short content", "tc_zero", sandbox, 0)
        # Any non-empty content with threshold=0 should be persisted
        assert PERSISTED_OUTPUT_TAG in out
|
||||
|
||||
|
||||
# ── enforce_turn_budget ───────────────────────────────────────────────
|
||||
|
||||
class TestEnforceTurnBudget:
    """Checks enforce_turn_budget against a 200K-character turn budget."""

    @staticmethod
    def _sandbox():
        """Return a mock env whose execute() reports success."""
        return MagicMock(
            **{"execute.return_value": {"output": "", "returncode": 0}}
        )

    @staticmethod
    def _tool_msg(call_id, content):
        """Build one tool-role message dict."""
        return {"role": "tool", "tool_call_id": call_id, "content": content}

    @staticmethod
    def _budget():
        """Return a fresh BudgetConfig with turn_budget=200_000."""
        return BudgetConfig(turn_budget=200_000)

    def test_under_budget_no_changes(self):
        batch = [
            self._tool_msg("t1", "small"),
            self._tool_msg("t2", "also small"),
        ]
        out = enforce_turn_budget(batch, env=None, config=self._budget())
        assert out[0]["content"] == "small"
        assert out[1]["content"] == "also small"

    def test_over_budget_largest_persisted_first(self):
        # 80K + 130K = 210K, which is over the 200K budget.
        batch = [
            self._tool_msg("t1", "a" * 80_000),
            self._tool_msg("t2", "b" * 130_000),
        ]
        enforce_turn_budget(batch, env=self._sandbox(), config=self._budget())
        # The larger one (130K) should be persisted first
        assert PERSISTED_OUTPUT_TAG in batch[1]["content"]

    def test_already_persisted_results_skipped(self):
        wrapped = (
            f"{PERSISTED_OUTPUT_TAG}\nalready persisted\n{PERSISTED_OUTPUT_CLOSING_TAG}"
        )
        batch = [
            self._tool_msg("t1", wrapped),
            self._tool_msg("t2", "x" * 250_000),
        ]
        enforce_turn_budget(batch, env=self._sandbox(), config=self._budget())
        # t1 should be untouched (already persisted)
        assert batch[0]["content"].startswith(PERSISTED_OUTPUT_TAG)
        # t2 should be persisted
        assert PERSISTED_OUTPUT_TAG in batch[1]["content"]

    def test_medium_result_regression(self):
        """Six 42K results (252K total) exceed the 200K budget in aggregate.

        Each result is below the 50K default per-result threshold, so only
        the per-turn budget can trigger persistence here.
        """
        batch = [self._tool_msg(f"t{i}", "x" * 42_000) for i in range(6)]
        enforce_turn_budget(batch, env=self._sandbox(), config=self._budget())
        # At least some results should be persisted to get under 200K
        persisted_count = sum(
            PERSISTED_OUTPUT_TAG in message["content"] for message in batch
        )
        assert persisted_count >= 2  # Need to shed at least ~52K

    def test_no_env_falls_back_to_truncation(self):
        batch = [self._tool_msg("t1", "x" * 250_000)]
        enforce_turn_budget(batch, env=None, config=self._budget())
        # Should be truncated (no sandbox available)
        shrunk = batch[0]["content"]
        assert "Truncated" in shrunk or PERSISTED_OUTPUT_TAG in shrunk

    def test_returns_same_list(self):
        batch = [self._tool_msg("t1", "ok")]
        out = enforce_turn_budget(batch, env=None, config=self._budget())
        assert out is batch

    def test_empty_messages(self):
        out = enforce_turn_budget([], env=None, config=self._budget())
        assert out == []
|
||||
|
||||
|
||||
# ── Per-tool threshold integration ────────────────────────────────────
|
||||
|
||||
class TestPerToolThresholds:
    """Registry wiring checks for per-tool result-size thresholds."""

    def test_registry_has_get_max_result_size(self):
        from tools.registry import registry

        assert hasattr(registry, "get_max_result_size")

    def test_default_threshold(self):
        from tools.registry import registry

        # A tool name that is not registered should get the default.
        fallback = registry.get_max_result_size("nonexistent_tool_xyz")
        assert fallback == DEFAULT_RESULT_SIZE_CHARS

    def test_terminal_threshold(self):
        from tools.registry import registry

        try:
            # Importing the module is what registers the "terminal" tool.
            import tools.terminal_tool  # noqa: F401

            limit = registry.get_max_result_size("terminal")
            assert limit == 30_000
        except ImportError:
            pytest.skip("terminal_tool not importable in test env")

    def test_read_file_never_persisted(self):
        from tools.registry import registry

        try:
            import tools.file_tools  # noqa: F401

            limit = registry.get_max_result_size("read_file")
            assert limit == float("inf")
        except ImportError:
            pytest.skip("file_tools not importable in test env")

    def test_search_files_threshold(self):
        from tools.registry import registry

        try:
            import tools.file_tools  # noqa: F401

            limit = registry.get_max_result_size("search_files")
            assert limit == 20_000
        except ImportError:
            pytest.skip("file_tools not importable in test env")
|
||||
@@ -1,42 +0,0 @@
|
||||
"""Binary file extensions to skip for text-based operations.
|
||||
|
||||
These files can't be meaningfully compared as text and are often large.
|
||||
Ported from free-code src/constants/files.ts.
|
||||
"""
|
||||
|
||||
# Lower-case, dot-prefixed suffixes treated as binary, grouped by category.
BINARY_EXTENSIONS = frozenset(
    {
        # Images
        ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".ico", ".webp", ".tiff",
        ".tif",
        # Videos
        ".mp4", ".mov", ".avi", ".mkv", ".webm", ".wmv", ".flv", ".m4v",
        ".mpeg", ".mpg",
        # Audio
        ".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a", ".wma", ".aiff",
        ".opus",
        # Archives
        ".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz", ".z", ".tgz",
        ".iso",
        # Executables/binaries
        ".exe", ".dll", ".so", ".dylib", ".bin", ".o", ".a", ".obj", ".lib",
        ".app", ".msi", ".deb", ".rpm",
        # Documents (exclude .pdf — text-based, agents may want to inspect)
        ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
        ".odt", ".ods", ".odp",
        # Fonts
        ".ttf", ".otf", ".woff", ".woff2", ".eot",
        # Bytecode / VM artifacts
        ".pyc", ".pyo", ".class", ".jar", ".war", ".ear", ".node", ".wasm",
        ".rlib",
        # Database files
        ".sqlite", ".sqlite3", ".db", ".mdb", ".idx",
        # Design / 3D
        ".psd", ".ai", ".eps", ".sketch", ".fig", ".xd", ".blend", ".3ds",
        ".max",
        # Flash
        ".swf", ".fla",
        # Lock/profiling data
        ".lockb", ".dat", ".data",
    }
)


def has_binary_extension(path: str) -> bool:
    """Return True when the text from the last "." of *path* is a binary suffix.

    The comparison is case-insensitive and purely string based: the file
    system is never touched. A path without any "." returns False.
    """
    _, dot, tail = path.rpartition(".")
    if not dot:
        # rpartition yields an empty separator when no "." is present.
        return False
    suffix = (dot + tail).lower()
    return suffix in BINARY_EXTENSIONS
|
||||
+40
-24
@@ -146,11 +146,15 @@ def _get_command_timeout() -> int:
|
||||
``DEFAULT_COMMAND_TIMEOUT`` (30s) if unset or unreadable.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
cfg = read_raw_config()
|
||||
val = cfg.get("browser", {}).get("command_timeout")
|
||||
if val is not None:
|
||||
return max(int(val), 5) # Floor at 5s to avoid instant kills
|
||||
hermes_home = get_hermes_home()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
if config_path.exists():
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
val = cfg.get("browser", {}).get("command_timeout")
|
||||
if val is not None:
|
||||
return max(int(val), 5) # Floor at 5s to avoid instant kills
|
||||
except Exception as e:
|
||||
logger.debug("Could not read command_timeout from config: %s", e)
|
||||
return DEFAULT_COMMAND_TIMEOUT
|
||||
@@ -255,19 +259,23 @@ def _get_cloud_provider() -> Optional[CloudBrowserProvider]:
|
||||
|
||||
_cloud_provider_resolved = True
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
cfg = read_raw_config()
|
||||
browser_cfg = cfg.get("browser", {})
|
||||
provider_key = None
|
||||
if isinstance(browser_cfg, dict) and "cloud_provider" in browser_cfg:
|
||||
provider_key = normalize_browser_cloud_provider(
|
||||
browser_cfg.get("cloud_provider")
|
||||
)
|
||||
if provider_key == "local":
|
||||
_cached_cloud_provider = None
|
||||
return None
|
||||
if provider_key and provider_key in _PROVIDER_REGISTRY:
|
||||
_cached_cloud_provider = _PROVIDER_REGISTRY[provider_key]()
|
||||
hermes_home = get_hermes_home()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
if config_path.exists():
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
browser_cfg = cfg.get("browser", {})
|
||||
provider_key = None
|
||||
if isinstance(browser_cfg, dict) and "cloud_provider" in browser_cfg:
|
||||
provider_key = normalize_browser_cloud_provider(
|
||||
browser_cfg.get("cloud_provider")
|
||||
)
|
||||
if provider_key == "local":
|
||||
_cached_cloud_provider = None
|
||||
return None
|
||||
if provider_key and provider_key in _PROVIDER_REGISTRY:
|
||||
_cached_cloud_provider = _PROVIDER_REGISTRY[provider_key]()
|
||||
except Exception as e:
|
||||
logger.debug("Could not read cloud_provider from config: %s", e)
|
||||
|
||||
@@ -318,9 +326,13 @@ def _allow_private_urls() -> bool:
|
||||
_allow_private_urls_resolved = True
|
||||
_cached_allow_private_urls = False # safe default
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
cfg = read_raw_config()
|
||||
_cached_allow_private_urls = bool(cfg.get("browser", {}).get("allow_private_urls"))
|
||||
hermes_home = get_hermes_home()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
if config_path.exists():
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
_cached_allow_private_urls = bool(cfg.get("browser", {}).get("allow_private_urls"))
|
||||
except Exception as e:
|
||||
logger.debug("Could not read allow_private_urls from config: %s", e)
|
||||
return _cached_allow_private_urls
|
||||
@@ -1614,10 +1626,14 @@ def _maybe_start_recording(task_id: str):
|
||||
if task_id in _recording_sessions:
|
||||
return
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
hermes_home = get_hermes_home()
|
||||
cfg = read_raw_config()
|
||||
record_enabled = cfg.get("browser", {}).get("record_sessions", False)
|
||||
config_path = hermes_home / "config.yaml"
|
||||
record_enabled = False
|
||||
if config_path.exists():
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
record_enabled = cfg.get("browser", {}).get("record_sessions", False)
|
||||
|
||||
if not record_enabled:
|
||||
return
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Configurable budget constants for tool result persistence.
|
||||
|
||||
Overridable at the RL environment level via HermesAgentEnvConfig fields.
|
||||
Per-tool resolution: pinned > config overrides > registry > default.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
# Tools whose thresholds must never be overridden.
|
||||
# read_file=inf prevents infinite persist->read->persist loops.
|
||||
PINNED_THRESHOLDS: Dict[str, float] = {
|
||||
"read_file": float("inf"),
|
||||
}
|
||||
|
||||
# Defaults matching the current hardcoded values in tool_result_storage.py.
|
||||
# Kept here as the single source of truth; tool_result_storage.py imports these.
|
||||
DEFAULT_RESULT_SIZE_CHARS: int = 50_000
|
||||
DEFAULT_TURN_BUDGET_CHARS: int = 200_000
|
||||
DEFAULT_PREVIEW_SIZE_CHARS: int = 2_000
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BudgetConfig:
|
||||
"""Immutable budget constants for the 3-layer tool result persistence system.
|
||||
|
||||
Layer 2 (per-result): resolve_threshold(tool_name) -> threshold in chars.
|
||||
Layer 3 (per-turn): turn_budget -> aggregate char budget across all tool
|
||||
results in a single assistant turn.
|
||||
Preview: preview_size -> inline snippet size after persistence.
|
||||
"""
|
||||
|
||||
default_result_size: int = DEFAULT_RESULT_SIZE_CHARS
|
||||
turn_budget: int = DEFAULT_TURN_BUDGET_CHARS
|
||||
preview_size: int = DEFAULT_PREVIEW_SIZE_CHARS
|
||||
tool_overrides: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def resolve_threshold(self, tool_name: str) -> int | float:
|
||||
"""Resolve the persistence threshold for a tool.
|
||||
|
||||
Priority: pinned -> tool_overrides -> registry per-tool -> default.
|
||||
"""
|
||||
if tool_name in PINNED_THRESHOLDS:
|
||||
return PINNED_THRESHOLDS[tool_name]
|
||||
if tool_name in self.tool_overrides:
|
||||
return self.tool_overrides[tool_name]
|
||||
from tools.registry import registry
|
||||
return registry.get_max_result_size(tool_name, default=self.default_result_size)
|
||||
|
||||
|
||||
# Default config -- matches current hardcoded behavior exactly.
|
||||
DEFAULT_BUDGET = BudgetConfig()
|
||||
@@ -1343,5 +1343,4 @@ registry.register(
|
||||
enabled_tools=kw.get("enabled_tools")),
|
||||
check_fn=check_sandbox_requirements,
|
||||
emoji="🐍",
|
||||
max_result_size_chars=30_000,
|
||||
)
|
||||
|
||||
+33
-29
@@ -137,36 +137,40 @@ def _load_config_files() -> List[Dict[str, str]]:
|
||||
|
||||
result: List[Dict[str, str]] = []
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
hermes_home = _resolve_hermes_home()
|
||||
cfg = read_raw_config()
|
||||
cred_files = cfg.get("terminal", {}).get("credential_files")
|
||||
if isinstance(cred_files, list):
|
||||
hermes_home_resolved = hermes_home.resolve()
|
||||
for item in cred_files:
|
||||
if isinstance(item, str) and item.strip():
|
||||
rel = item.strip()
|
||||
if os.path.isabs(rel):
|
||||
logger.warning(
|
||||
"credential_files: rejected absolute config path %r", rel,
|
||||
)
|
||||
continue
|
||||
host_path = (hermes_home / rel).resolve()
|
||||
try:
|
||||
host_path.relative_to(hermes_home_resolved)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"credential_files: rejected config path traversal %r "
|
||||
"(resolves to %s, outside HERMES_HOME %s)",
|
||||
rel, host_path, hermes_home_resolved,
|
||||
)
|
||||
continue
|
||||
if host_path.is_file():
|
||||
container_path = f"/root/.hermes/{rel}"
|
||||
result.append({
|
||||
"host_path": str(host_path),
|
||||
"container_path": container_path,
|
||||
})
|
||||
config_path = hermes_home / "config.yaml"
|
||||
if config_path.exists():
|
||||
import yaml
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
cred_files = cfg.get("terminal", {}).get("credential_files")
|
||||
if isinstance(cred_files, list):
|
||||
hermes_home_resolved = hermes_home.resolve()
|
||||
for item in cred_files:
|
||||
if isinstance(item, str) and item.strip():
|
||||
rel = item.strip()
|
||||
if os.path.isabs(rel):
|
||||
logger.warning(
|
||||
"credential_files: rejected absolute config path %r", rel,
|
||||
)
|
||||
continue
|
||||
host_path = (hermes_home / rel).resolve()
|
||||
try:
|
||||
host_path.relative_to(hermes_home_resolved)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"credential_files: rejected config path traversal %r "
|
||||
"(resolves to %s, outside HERMES_HOME %s)",
|
||||
rel, host_path, hermes_home_resolved,
|
||||
)
|
||||
continue
|
||||
if host_path.is_file():
|
||||
container_path = f"/root/.hermes/{rel}"
|
||||
result.append({
|
||||
"host_path": str(host_path),
|
||||
"container_path": container_path,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug("Could not read terminal.credential_files from config: %s", e)
|
||||
|
||||
|
||||
@@ -66,13 +66,18 @@ def _load_config_passthrough() -> frozenset[str]:
|
||||
|
||||
result: set[str] = set()
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
cfg = read_raw_config()
|
||||
passthrough = cfg.get("terminal", {}).get("env_passthrough")
|
||||
if isinstance(passthrough, list):
|
||||
for item in passthrough:
|
||||
if isinstance(item, str) and item.strip():
|
||||
result.add(item.strip())
|
||||
from hermes_constants import get_hermes_home
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
if config_path.exists():
|
||||
import yaml
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
passthrough = cfg.get("terminal", {}).get("env_passthrough")
|
||||
if isinstance(passthrough, list):
|
||||
for item in passthrough:
|
||||
if isinstance(item, str) and item.strip():
|
||||
result.add(item.strip())
|
||||
except Exception as e:
|
||||
logger.debug("Could not read tools.env_passthrough from config: %s", e)
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -511,8 +510,6 @@ class DockerEnvironment(BaseEnvironment):
|
||||
forward_keys |= get_all_passthrough()
|
||||
except Exception:
|
||||
pass
|
||||
# Strip Hermes-managed secrets so they never leak into the container.
|
||||
forward_keys -= _HERMES_PROVIDER_ENV_BLOCKLIST
|
||||
hermes_env = _load_hermes_env_vars() if forward_keys else {}
|
||||
for key in sorted(forward_keys):
|
||||
value = os.getenv(key)
|
||||
|
||||
@@ -33,7 +33,6 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.binary_extensions import BINARY_EXTENSIONS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -281,6 +280,26 @@ class FileOperations(ABC):
|
||||
# Shell-based Implementation
|
||||
# =============================================================================
|
||||
|
||||
# Suffixes treated as binary; used as a fast-path check on the file name.
BINARY_EXTENSIONS = {
    # Images
    ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".ico", ".tiff", ".tif",
    ".svg",  # SVG is text but often treated as binary
    # Audio/Video
    ".mp3", ".mp4", ".wav", ".avi", ".mov", ".mkv", ".flac", ".ogg", ".webm",
    # Archives
    ".zip", ".tar", ".gz", ".bz2", ".xz", ".7z", ".rar",
    # Documents
    ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
    # Compiled/Binary
    ".exe", ".dll", ".so", ".dylib", ".o", ".a", ".pyc", ".pyo", ".class",
    ".wasm", ".bin",
    # Fonts
    ".ttf", ".otf", ".woff", ".woff2", ".eot",
    # Other
    ".db", ".sqlite", ".sqlite3",
}

# Image extensions (subset of binary that we can return as base64)
IMAGE_EXTENSIONS = {
    ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".ico",
}
|
||||
|
||||
|
||||
+9
-45
@@ -7,7 +7,6 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from tools.binary_extensions import has_binary_extension
|
||||
from tools.file_operations import ShellFileOperations
|
||||
from agent.redact import redact_sensitive_text
|
||||
|
||||
@@ -26,8 +25,6 @@ _EXPECTED_WRITE_ERRNOS = {errno.EACCES, errno.EPERM, errno.EROFS}
|
||||
# Configurable via config.yaml: file_read_max_chars: 200000
|
||||
# ---------------------------------------------------------------------------
|
||||
_DEFAULT_MAX_READ_CHARS = 100_000
|
||||
_PRE_READ_MAX_BYTES = 256_000 # reject full-file reads on files larger than this
|
||||
_DEFAULT_READ_LIMIT = 500
|
||||
_max_read_chars_cached: int | None = None
|
||||
|
||||
|
||||
@@ -279,7 +276,7 @@ def clear_file_ops_cache(task_id: str = None):
|
||||
_file_ops_cache.clear()
|
||||
|
||||
|
||||
def read_file_tool(path: str, offset: int = 1, limit: int | None = None, task_id: str = "default") -> str:
|
||||
def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = "default") -> str:
|
||||
"""Read a file with pagination and line numbers."""
|
||||
try:
|
||||
# ── Device path guard ─────────────────────────────────────────
|
||||
@@ -293,22 +290,11 @@ def read_file_tool(path: str, offset: int = 1, limit: int | None = None, task_id
|
||||
),
|
||||
})
|
||||
|
||||
_resolved = Path(path).expanduser().resolve()
|
||||
|
||||
# ── Binary file guard ─────────────────────────────────────────
|
||||
# Block binary files by extension (no I/O).
|
||||
if has_binary_extension(str(_resolved)):
|
||||
_ext = _resolved.suffix.lower()
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"Cannot read binary file '{path}' ({_ext}). "
|
||||
"Use vision_analyze for images, or terminal to inspect binary files."
|
||||
),
|
||||
})
|
||||
|
||||
# ── Hermes internal path guard ────────────────────────────────
|
||||
# Prevent prompt injection via catalog or hub metadata files.
|
||||
import pathlib as _pathlib
|
||||
from hermes_constants import get_hermes_home as _get_hh
|
||||
_resolved = _pathlib.Path(path).expanduser().resolve()
|
||||
_hermes_home = _get_hh().resolve()
|
||||
_blocked_dirs = [
|
||||
_hermes_home / "skills" / ".hub" / "index-cache",
|
||||
@@ -327,28 +313,6 @@ def read_file_tool(path: str, offset: int = 1, limit: int | None = None, task_id
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# ── Pre-read file size guard ──────────────────────────────────
|
||||
# Guard only when the caller omits limit; an explicit limit means
|
||||
# the caller knows what slice it wants.
|
||||
if limit is None:
|
||||
try:
|
||||
_fsize = os.path.getsize(str(_resolved))
|
||||
except OSError:
|
||||
_fsize = 0
|
||||
if _fsize > _PRE_READ_MAX_BYTES:
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"File is too large to read in full ({_fsize:,} bytes). "
|
||||
f"Use offset and limit parameters to read specific sections "
|
||||
f"(e.g. offset=1, limit=100 for the first 100 lines)."
|
||||
),
|
||||
"path": path,
|
||||
"file_size": _fsize,
|
||||
}, ensure_ascii=False)
|
||||
|
||||
if limit is None:
|
||||
limit = _DEFAULT_READ_LIMIT
|
||||
|
||||
# ── Dedup check ───────────────────────────────────────────────
|
||||
# If we already read this exact (path, offset, limit) and the
|
||||
# file hasn't been modified since, return a lightweight stub
|
||||
@@ -762,7 +726,7 @@ def _check_file_reqs():
|
||||
|
||||
READ_FILE_SCHEMA = {
|
||||
"name": "read_file",
|
||||
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. When you already know which part of the file you need, only read that part using offset and limit — this is important for larger files. Files over 256KB will be rejected unless you provide a limit parameter. NOTE: Cannot read images or binary files — use vision_analyze for images.",
|
||||
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. Reads exceeding ~100K characters are rejected; use offset and limit to read specific sections of large files. NOTE: Cannot read images or binary files — use vision_analyze for images.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -826,7 +790,7 @@ SEARCH_FILES_SCHEMA = {
|
||||
|
||||
def _handle_read_file(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit"), task_id=tid)
|
||||
return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid)
|
||||
|
||||
|
||||
def _handle_write_file(args, **kw):
|
||||
@@ -853,7 +817,7 @@ def _handle_search_files(args, **kw):
|
||||
output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid)
|
||||
|
||||
|
||||
registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs, emoji="📖", max_result_size_chars=float('inf'))
|
||||
registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs, emoji="✍️", max_result_size_chars=100_000)
|
||||
registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs, emoji="🔧", max_result_size_chars=100_000)
|
||||
registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs, emoji="🔎", max_result_size_chars=20_000)
|
||||
registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs, emoji="📖")
|
||||
registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs, emoji="✍️")
|
||||
registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs, emoji="🔧")
|
||||
registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs, emoji="🔎")
|
||||
|
||||
+1
-7
@@ -1253,13 +1253,7 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
for block in (result.content or []):
|
||||
if hasattr(block, "text"):
|
||||
parts.append(block.text)
|
||||
text_result = "\n".join(parts) if parts else ""
|
||||
|
||||
# Prefer structuredContent (machine-readable JSON) over plain text
|
||||
structured = getattr(result, "structuredContent", None)
|
||||
if structured is not None:
|
||||
return json.dumps({"result": structured})
|
||||
return json.dumps({"result": text_result})
|
||||
return json.dumps({"result": "\n".join(parts) if parts else ""})
|
||||
|
||||
try:
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
|
||||
+1
-16
@@ -27,12 +27,10 @@ class ToolEntry:
|
||||
__slots__ = (
|
||||
"name", "toolset", "schema", "handler", "check_fn",
|
||||
"requires_env", "is_async", "description", "emoji",
|
||||
"max_result_size_chars",
|
||||
)
|
||||
|
||||
def __init__(self, name, toolset, schema, handler, check_fn,
|
||||
requires_env, is_async, description, emoji,
|
||||
max_result_size_chars=None):
|
||||
requires_env, is_async, description, emoji):
|
||||
self.name = name
|
||||
self.toolset = toolset
|
||||
self.schema = schema
|
||||
@@ -42,7 +40,6 @@ class ToolEntry:
|
||||
self.is_async = is_async
|
||||
self.description = description
|
||||
self.emoji = emoji
|
||||
self.max_result_size_chars = max_result_size_chars
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
@@ -67,7 +64,6 @@ class ToolRegistry:
|
||||
is_async: bool = False,
|
||||
description: str = "",
|
||||
emoji: str = "",
|
||||
max_result_size_chars: int | float | None = None,
|
||||
):
|
||||
"""Register a tool. Called at module-import time by each tool file."""
|
||||
existing = self._tools.get(name)
|
||||
@@ -87,7 +83,6 @@ class ToolRegistry:
|
||||
is_async=is_async,
|
||||
description=description or schema.get("description", ""),
|
||||
emoji=emoji,
|
||||
max_result_size_chars=max_result_size_chars,
|
||||
)
|
||||
if check_fn and toolset not in self._toolset_checks:
|
||||
self._toolset_checks[toolset] = check_fn
|
||||
@@ -169,16 +164,6 @@ class ToolRegistry:
|
||||
# Query helpers (replace redundant dicts in model_tools.py)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float:
|
||||
"""Return per-tool max result size, or *default* (or global default)."""
|
||||
entry = self._tools.get(name)
|
||||
if entry and entry.max_result_size_chars is not None:
|
||||
return entry.max_result_size_chars
|
||||
if default is not None:
|
||||
return default
|
||||
from tools.budget_config import DEFAULT_RESULT_SIZE_CHARS
|
||||
return DEFAULT_RESULT_SIZE_CHARS
|
||||
|
||||
def get_all_tool_names(self) -> List[str]:
    """Return the names of every registered tool, sorted ascending."""
    names = list(self._tools.keys())
    names.sort()
    return names
|
||||
|
||||
@@ -811,12 +811,6 @@ def _stop_cleanup_thread():
|
||||
pass
|
||||
|
||||
|
||||
def get_active_env(task_id: str):
    """Look up the active environment registered for *task_id*.

    The lookup is performed while holding ``_env_lock``. Returns ``None``
    when no environment is registered under that task id.
    """
    with _env_lock:
        active = _active_environments.get(task_id)
    return active
|
||||
|
||||
|
||||
def get_active_environments_info() -> Dict[str, Any]:
|
||||
"""Get information about currently active environments."""
|
||||
info = {
|
||||
@@ -1623,5 +1617,4 @@ registry.register(
|
||||
handler=_handle_terminal,
|
||||
check_fn=check_terminal_requirements,
|
||||
emoji="💻",
|
||||
max_result_size_chars=30_000,
|
||||
)
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
"""Tool result persistence -- preserves large outputs instead of truncating.
|
||||
|
||||
Defense against context-window overflow operates at three levels:
|
||||
|
||||
1. **Per-tool output cap** (inside each tool): Tools like search_files
|
||||
pre-truncate their own output before returning. This is the first line
|
||||
of defense and the only one the tool author controls.
|
||||
|
||||
2. **Per-result persistence** (maybe_persist_tool_result): After a tool
|
||||
returns, if its output exceeds the tool's registered threshold
|
||||
(registry.get_max_result_size), the full output is written INTO THE
|
||||
SANDBOX at /tmp/hermes-results/{tool_use_id}.txt via env.execute().
|
||||
The in-context content is replaced with a preview + file path reference.
|
||||
The model can read_file to access the full output on any backend.
|
||||
|
||||
3. **Per-turn aggregate budget** (enforce_turn_budget): After all tool
|
||||
results in a single assistant turn are collected, if the total exceeds
|
||||
MAX_TURN_BUDGET_CHARS (200K), the largest non-persisted results are
|
||||
spilled to disk until the aggregate is under budget. This catches cases
|
||||
where many medium-sized results combine to overflow context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from tools.budget_config import (
|
||||
DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
BudgetConfig,
|
||||
DEFAULT_BUDGET,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
PERSISTED_OUTPUT_TAG = "<persisted-output>"
|
||||
PERSISTED_OUTPUT_CLOSING_TAG = "</persisted-output>"
|
||||
STORAGE_DIR = "/tmp/hermes-results"
|
||||
HEREDOC_MARKER = "HERMES_PERSIST_EOF"
|
||||
_BUDGET_TOOL_NAME = "__budget_enforcement__"
|
||||
|
||||
|
||||
def generate_preview(content: str, max_chars: int = DEFAULT_PREVIEW_SIZE_CHARS) -> tuple[str, bool]:
    """Return a size-capped preview of *content*.

    The preview never exceeds ``max_chars`` characters. When the text has to
    be cut, the cut is pulled back to the last newline inside the window so
    the preview ends on a complete line -- but only if that newline lies in
    the second half of the window; otherwise the hard cut is kept so the
    preview does not shrink to a sliver.

    Args:
        content: Full text to preview.
        max_chars: Maximum preview length in characters. Non-positive values
            are treated as 0 and yield an empty preview.

    Returns:
        ``(preview, has_more)`` where ``has_more`` is True when *content*
        was truncated.
    """
    # A negative cap would turn ``content[:max_chars]`` into "everything
    # except the tail", i.e. the opposite of a cap. Clamp it.
    max_chars = max(max_chars, 0)
    if len(content) <= max_chars:
        return content, False
    truncated = content[:max_chars]
    last_nl = truncated.rfind("\n")
    if last_nl > max_chars // 2:
        truncated = truncated[:last_nl + 1]
    return truncated, True
|
||||
|
||||
|
||||
def _heredoc_marker(content: str) -> str:
    """Return a heredoc delimiter that does not occur anywhere in *content*.

    The fixed ``HEREDOC_MARKER`` is used when possible. Otherwise random
    markers are generated until one is found that is absent from the
    content, so the heredoc cannot be terminated early by the payload.
    """
    marker = HEREDOC_MARKER
    # Re-check every candidate, including the random ones: a collision is
    # extremely unlikely but would silently truncate the persisted file.
    while marker in content:
        marker = f"HERMES_PERSIST_{uuid.uuid4().hex[:8]}"
    return marker
|
||||
|
||||
|
||||
def _write_to_sandbox(content: str, remote_path: str, env) -> bool:
    """Write *content* to *remote_path* inside the sandbox via ``env.execute()``.

    The payload is sent as a quoted heredoc (``<< 'MARKER'``), so the shell
    performs no expansion on it. The storage directory is created first.

    Args:
        content: Text to store. A trailing newline is appended by the heredoc.
        remote_path: Destination path inside the sandbox.
        env: Environment object exposing ``execute(cmd, timeout=...)`` that
            returns a mapping with a ``"returncode"`` entry.

    Returns:
        True when the command reports return code 0, False otherwise
        (a missing ``"returncode"`` is treated as failure).
    """
    import shlex

    marker = _heredoc_marker(content)
    # Quote both paths: remote_path is built from a tool-call id, which must
    # not be able to inject shell syntax or split into several words.
    cmd = (
        f"mkdir -p {shlex.quote(STORAGE_DIR)} && "
        f"cat > {shlex.quote(remote_path)} << '{marker}'\n"
        f"{content}\n"
        f"{marker}"
    )
    result = env.execute(cmd, timeout=30)
    return result.get("returncode", 1) == 0
|
||||
|
||||
|
||||
def _build_persisted_message(
    preview: str,
    has_more: bool,
    original_size: int,
    file_path: str,
) -> str:
    """Assemble the ``<persisted-output>`` block that replaces a large result.

    The block states the original size (character count plus a KB/MB
    figure derived from it), where the full output was saved, how to read
    it back, and the preview text. An ellipsis line follows the preview
    when *has_more* is True.
    """
    kilobytes = original_size / 1024
    size_str = f"{kilobytes / 1024:.1f} MB" if kilobytes >= 1024 else f"{kilobytes:.1f} KB"

    pieces = [
        f"{PERSISTED_OUTPUT_TAG}\n",
        f"This tool result was too large ({original_size:,} characters, {size_str}).\n",
        f"Full output saved to: {file_path}\n",
        "Use the read_file tool with offset and limit to access specific sections of this output.\n\n",
        f"Preview (first {len(preview)} chars):\n",
        preview,
    ]
    if has_more:
        pieces.append("\n...")
    pieces.append(f"\n{PERSISTED_OUTPUT_CLOSING_TAG}")
    return "".join(pieces)
|
||||
|
||||
|
||||
def maybe_persist_tool_result(
    content: str,
    tool_name: str,
    tool_use_id: str,
    env=None,
    config: BudgetConfig = DEFAULT_BUDGET,
    threshold: int | float | None = None,
) -> str:
    """Layer 2: spill an oversized tool result into the sandbox.

    Results at or below the effective threshold are returned untouched.
    Larger results are written to ``{STORAGE_DIR}/{tool_use_id}.txt`` through
    ``env.execute()`` and replaced by a ``<persisted-output>`` block holding
    a preview and the file path. When there is no env, or the write fails
    or raises, the result is truncated inline to the preview instead.

    Args:
        content: Raw tool result string.
        tool_name: Tool name, used to resolve the threshold from *config*.
        tool_use_id: Unique id of the tool call; becomes the file name.
        env: Active environment instance, or None.
        config: BudgetConfig supplying thresholds and the preview size.
        threshold: Explicit threshold; when given it overrides *config*.

    Returns:
        The original content, a ``<persisted-output>`` replacement, or an
        inline-truncated preview.
    """
    if threshold is None:
        limit = config.resolve_threshold(tool_name)
    else:
        limit = threshold

    # An infinite limit opts the tool out before the length is even looked at.
    if limit == float("inf") or len(content) <= limit:
        return content

    remote_path = f"{STORAGE_DIR}/{tool_use_id}.txt"
    preview, has_more = generate_preview(content, max_chars=config.preview_size)

    if env is not None:
        try:
            written = _write_to_sandbox(content, remote_path, env)
            if written:
                logger.info(
                    "Persisted large tool result: %s (%s, %d chars -> %s)",
                    tool_name, tool_use_id, len(content), remote_path,
                )
                return _build_persisted_message(preview, has_more, len(content), remote_path)
        except Exception as exc:
            logger.warning("Sandbox write failed for %s: %s", tool_use_id, exc)

    # No env, a failed write, or an exception: fall back to the preview only.
    logger.info(
        "Inline-truncating large tool result: %s (%d chars, no sandbox write)",
        tool_name, len(content),
    )
    notice = (
        f"[Truncated: tool response was {len(content):,} chars. "
        f"Full output could not be saved to sandbox.]"
    )
    return f"{preview}\n\n" + notice
|
||||
|
||||
|
||||
def enforce_turn_budget(
    tool_messages: list[dict],
    env=None,
    config: BudgetConfig = DEFAULT_BUDGET,
) -> list[dict]:
    """Layer 3: enforce the aggregate size budget across one turn's tool results.

    When the combined length of all string contents exceeds
    ``config.turn_budget``, results are persisted largest-first (through
    ``maybe_persist_tool_result`` with a threshold of 0) until the total is
    within budget. Results that already contain the persisted-output tag
    are counted but never persisted again.

    Messages whose ``content`` is missing, ``None``, empty, or not a string
    are left untouched and do not count towards the total.

    Args:
        tool_messages: Tool result messages for the turn; mutated in place.
        env: Active environment instance used for sandbox writes, or None.
        config: BudgetConfig supplying ``turn_budget`` and preview settings.

    Returns:
        The same *tool_messages* list.
    """
    candidates = []
    total_size = 0
    for idx, msg in enumerate(tool_messages):
        content = msg.get("content")
        # None / missing / non-text content cannot be measured or spilled.
        if not isinstance(content, str) or not content:
            continue
        size = len(content)
        total_size += size
        if PERSISTED_OUTPUT_TAG not in content:
            candidates.append((idx, size))

    if total_size <= config.turn_budget:
        return tool_messages

    # Largest first, so the fewest results have to be replaced.
    candidates.sort(key=lambda item: item[1], reverse=True)

    for idx, size in candidates:
        if total_size <= config.turn_budget:
            break
        msg = tool_messages[idx]
        content = msg["content"]
        tool_use_id = msg.get("tool_call_id", f"budget_{idx}")

        replacement = maybe_persist_tool_result(
            content=content,
            tool_name=_BUDGET_TOOL_NAME,
            tool_use_id=tool_use_id,
            env=env,
            config=config,
            threshold=0,
        )
        # Only adopt the replacement when it actually shrinks the result:
        # for small results the wrapper text can be longer than the content,
        # which would push the total further over budget.
        if len(replacement) < size:
            total_size += len(replacement) - size
            msg["content"] = replacement
            logger.info(
                "Budget enforcement: persisted tool result %s (%d chars)",
                tool_use_id, size,
            )

    return tool_messages
|
||||
@@ -98,8 +98,12 @@ def get_stt_model_from_config() -> Optional[str]:
|
||||
Silently returns ``None`` on any error (missing file, bad YAML, etc.).
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import read_raw_config
|
||||
return read_raw_config().get("stt", {}).get("model")
|
||||
import yaml
|
||||
cfg_path = get_hermes_home() / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
return data.get("stt", {}).get("model")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
@@ -295,17 +299,7 @@ def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
_local_model = WhisperModel(model_name, device="auto", compute_type="auto")
|
||||
_local_model_name = model_name
|
||||
|
||||
# Language: config.yaml (stt.local.language) > env var > auto-detect.
|
||||
_forced_lang = (
|
||||
_load_stt_config().get("local", {}).get("language")
|
||||
or os.getenv(LOCAL_STT_LANGUAGE_ENV)
|
||||
or None
|
||||
)
|
||||
transcribe_kwargs = {"beam_size": 5}
|
||||
if _forced_lang:
|
||||
transcribe_kwargs["language"] = _forced_lang
|
||||
|
||||
segments, info = _local_model.transcribe(file_path, **transcribe_kwargs)
|
||||
segments, info = _local_model.transcribe(file_path, beam_size=5)
|
||||
transcript = " ".join(segment.text.strip() for segment in segments)
|
||||
|
||||
logger.info(
|
||||
@@ -354,12 +348,7 @@ def _transcribe_local_command(file_path: str, model_name: str) -> Dict[str, Any]
|
||||
),
|
||||
}
|
||||
|
||||
# Language: config.yaml (stt.local.language) > env var > "en" default.
|
||||
language = (
|
||||
_load_stt_config().get("local", {}).get("language")
|
||||
or os.getenv(LOCAL_STT_LANGUAGE_ENV)
|
||||
or DEFAULT_LOCAL_STT_LANGUAGE
|
||||
)
|
||||
language = os.getenv(LOCAL_STT_LANGUAGE_ENV, DEFAULT_LOCAL_STT_LANGUAGE)
|
||||
normalized_model = _normalize_local_command_model(model_name)
|
||||
|
||||
try:
|
||||
|
||||
@@ -2085,7 +2085,6 @@ registry.register(
|
||||
check_fn=check_web_api_key,
|
||||
requires_env=_web_requires_env(),
|
||||
emoji="🔍",
|
||||
max_result_size_chars=100_000,
|
||||
)
|
||||
registry.register(
|
||||
name="web_extract",
|
||||
@@ -2097,5 +2096,4 @@ registry.register(
|
||||
requires_env=_web_requires_env(),
|
||||
is_async=True,
|
||||
emoji="📄",
|
||||
max_result_size_chars=100_000,
|
||||
)
|
||||
|
||||
@@ -110,10 +110,6 @@ def get_config_schema(self):
|
||||
|
||||
Fields with `secret: True` and `env_var` go to `.env`. Non-secret fields are passed to `save_config()`.
|
||||
|
||||
:::tip Minimal vs Full Schema
|
||||
Every field in `get_config_schema()` is prompted during `hermes memory setup`. Providers with many options should keep the schema minimal — only include fields the user **must** configure (API key, required credentials). Document optional settings in a config file reference (e.g. `$HERMES_HOME/myprovider.json`) rather than prompting for them all during setup. This keeps the setup wizard fast while still supporting advanced configuration. See the Supermemory provider for an example — it only prompts for the API key; all other options live in `supermemory.json`.
|
||||
:::
|
||||
|
||||
## Save Config
|
||||
|
||||
```python
|
||||
|
||||
@@ -164,7 +164,6 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI
|
||||
| `TELEGRAM_WEBHOOK_URL` | Public HTTPS URL for webhook mode (enables webhook instead of polling) |
|
||||
| `TELEGRAM_WEBHOOK_PORT` | Local listen port for webhook server (default: `8443`) |
|
||||
| `TELEGRAM_WEBHOOK_SECRET` | Secret token for verifying updates come from Telegram |
|
||||
| `TELEGRAM_REACTIONS` | Enable emoji reactions on messages during processing (default: `false`) |
|
||||
| `DISCORD_BOT_TOKEN` | Discord bot token |
|
||||
| `DISCORD_ALLOWED_USERS` | Comma-separated Discord user IDs allowed to use the bot |
|
||||
| `DISCORD_HOME_CHANNEL` | Default Discord channel for cron delivery |
|
||||
@@ -172,9 +171,6 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI
|
||||
| `DISCORD_REQUIRE_MENTION` | Require an @mention before responding in server channels |
|
||||
| `DISCORD_FREE_RESPONSE_CHANNELS` | Comma-separated channel IDs where mention is not required |
|
||||
| `DISCORD_AUTO_THREAD` | Auto-thread long replies when supported |
|
||||
| `DISCORD_REACTIONS` | Enable emoji reactions on messages during processing (default: `true`) |
|
||||
| `DISCORD_IGNORED_CHANNELS` | Comma-separated channel IDs where the bot never responds |
|
||||
| `DISCORD_NO_THREAD_CHANNELS` | Comma-separated channel IDs where bot responds without auto-threading |
|
||||
| `SLACK_BOT_TOKEN` | Slack bot token (`xoxb-...`) |
|
||||
| `SLACK_APP_TOKEN` | Slack app-level token (`xapp-...`, required for Socket Mode) |
|
||||
| `SLACK_ALLOWED_USERS` | Comma-separated Slack user IDs |
|
||||
|
||||
@@ -553,7 +553,7 @@ Every model slot in Hermes — auxiliary tasks, compression, fallback — uses t
|
||||
|
||||
When `base_url` is set, Hermes ignores the provider and calls that endpoint directly (using `api_key` or `OPENAI_API_KEY` for auth). When only `provider` is set, Hermes uses that provider's built-in auth and base URL.
|
||||
|
||||
Available providers: `auto`, `openrouter`, `nous`, `codex`, `copilot`, `anthropic`, `main`, `zai`, `kimi-coding`, `minimax`, any provider registered in the [provider registry](/docs/reference/environment-variables), or any named custom provider from your `custom_providers` list (e.g. `provider: "beans"`).
|
||||
Available providers: `auto`, `openrouter`, `nous`, `codex`, `copilot`, `anthropic`, `main`, `zai`, `kimi-coding`, `minimax`, and any provider registered in the [provider registry](/docs/reference/environment-variables).
|
||||
|
||||
### Full auxiliary config reference
|
||||
|
||||
@@ -704,7 +704,7 @@ auxiliary:
|
||||
model: "my-local-model"
|
||||
```
|
||||
|
||||
`provider: "main"` uses whatever provider Hermes uses for normal chat — whether that's a named custom provider (e.g. `beans`), a built-in provider like `openrouter`, or a legacy `OPENAI_BASE_URL` endpoint.
|
||||
`provider: "main"` follows the same custom endpoint Hermes uses for normal chat. That endpoint can be set directly with `OPENAI_BASE_URL`, or saved once through `hermes model` and persisted in `config.yaml`.
|
||||
|
||||
:::tip
|
||||
If you use Codex OAuth as your main model provider, vision works automatically — no extra configuration needed. Codex is included in the auto-detection chain for vision.
|
||||
|
||||
@@ -400,47 +400,26 @@ Semantic long-term memory with profile recall, semantic search, explicit memory
|
||||
hermes memory setup # select "supermemory"
|
||||
# Or manually:
|
||||
hermes config set memory.provider supermemory
|
||||
echo 'SUPERMEMORY_API_KEY=***' >> ~/.hermes/.env
|
||||
echo 'SUPERMEMORY_API_KEY=your-key-here' >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
**Config:** `$HERMES_HOME/supermemory.json`
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `container_tag` | `hermes` | Container tag used for search and writes. Supports `{identity}` template for profile-scoped tags. |
|
||||
| `container_tag` | `hermes` | Container tag used for search and writes |
|
||||
| `auto_recall` | `true` | Inject relevant memory context before turns |
|
||||
| `auto_capture` | `true` | Store cleaned user-assistant turns after each response |
|
||||
| `max_recall_results` | `10` | Max recalled items to format into context |
|
||||
| `profile_frequency` | `50` | Include profile facts on first turn and every N turns |
|
||||
| `capture_mode` | `all` | Skip tiny or trivial turns by default |
|
||||
| `search_mode` | `hybrid` | Search mode: `hybrid`, `memories`, or `documents` |
|
||||
| `api_timeout` | `5.0` | Timeout for SDK and ingest requests |
|
||||
|
||||
**Environment variables:** `SUPERMEMORY_API_KEY` (required), `SUPERMEMORY_CONTAINER_TAG` (overrides config).
|
||||
|
||||
**Key features:**
|
||||
- Automatic context fencing — strips recalled memories from captured turns to prevent recursive memory pollution
|
||||
- Session-end conversation ingest for richer graph-level knowledge building
|
||||
- Profile facts injected on first turn and at configurable intervals
|
||||
- Trivial message filtering (skips "ok", "thanks", etc.)
|
||||
- **Profile-scoped containers** — use `{identity}` in `container_tag` (e.g. `hermes-{identity}` → `hermes-coder`) to isolate memories per Hermes profile
|
||||
- **Multi-container mode** — enable `enable_custom_container_tags` with a `custom_containers` list to let the agent read/write across named containers. Automatic operations (sync, prefetch) stay on the primary container.
|
||||
|
||||
<details>
|
||||
<summary>Multi-container example</summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"container_tag": "hermes",
|
||||
"enable_custom_container_tags": true,
|
||||
"custom_containers": ["project-alpha", "shared-knowledge"],
|
||||
"custom_container_instructions": "Use project-alpha for coding context."
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
**Support:** [Discord](https://supermemory.link/discord) · [support@supermemory.com](mailto:support@supermemory.com)
|
||||
|
||||
---
|
||||
|
||||
@@ -455,7 +434,7 @@ echo 'SUPERMEMORY_API_KEY=***' >> ~/.hermes/.env
|
||||
| **Holographic** | Local | Free | 2 | None | HRR algebra + trust scoring |
|
||||
| **RetainDB** | Cloud | $20/mo | 5 | `requests` | Delta compression |
|
||||
| **ByteRover** | Local/Cloud | Free/Paid | 3 | `brv` CLI | Pre-compression extraction |
|
||||
| **Supermemory** | Cloud | Paid | 4 | `supermemory` | Context fencing + session graph ingest + multi-container |
|
||||
| **Supermemory** | Cloud | Paid | 4 | `supermemory` | Context fencing + session graph ingest |
|
||||
|
||||
## Profile Isolation
|
||||
|
||||
|
||||
@@ -280,8 +280,6 @@ Discord behavior is controlled through two files: **`~/.hermes/.env`** for crede
|
||||
| `DISCORD_AUTO_THREAD` | No | `true` | When `true`, automatically creates a new thread for every `@mention` in a text channel, so each conversation is isolated (similar to Slack behavior). Messages already inside threads or DMs are unaffected. |
|
||||
| `DISCORD_ALLOW_BOTS` | No | `"none"` | Controls how the bot handles messages from other Discord bots. `"none"` — ignore all other bots. `"mentions"` — only accept bot messages that `@mention` Hermes. `"all"` — accept all bot messages. |
|
||||
| `DISCORD_REACTIONS` | No | `true` | When `true`, the bot adds emoji reactions to messages during processing (👀 when starting, ✅ on success, ❌ on error). Set to `false` to disable reactions entirely. |
|
||||
| `DISCORD_IGNORED_CHANNELS` | No | — | Comma-separated channel IDs where the bot **never** responds, even when `@mentioned`. Takes priority over all other channel settings. |
|
||||
| `DISCORD_NO_THREAD_CHANNELS` | No | — | Comma-separated channel IDs where the bot responds directly in the channel instead of creating a thread. Only relevant when `DISCORD_AUTO_THREAD` is `true`. |
|
||||
|
||||
### Config File (`config.yaml`)
|
||||
|
||||
@@ -294,8 +292,6 @@ discord:
|
||||
free_response_channels: "" # Comma-separated channel IDs (or YAML list)
|
||||
auto_thread: true # Auto-create threads on @mention
|
||||
reactions: true # Add emoji reactions during processing
|
||||
ignored_channels: [] # Channel IDs where bot never responds
|
||||
no_thread_channels: [] # Channel IDs where bot responds without threading
|
||||
|
||||
# Session isolation (applies to all gateway platforms, not just Discord)
|
||||
group_sessions_per_user: true # Isolate sessions per user in shared channels
|
||||
@@ -346,40 +342,6 @@ Controls whether the bot adds emoji reactions to messages as visual feedback:
|
||||
|
||||
Disable this if you find the reactions distracting or if the bot's role doesn't have the **Add Reactions** permission.
|
||||
|
||||
#### `discord.ignored_channels`
|
||||
|
||||
**Type:** string or list — **Default:** `[]`
|
||||
|
||||
Channel IDs where the bot **never** responds, even when directly `@mentioned`. This takes the highest priority — if a channel is in this list, the bot silently ignores all messages there, regardless of `require_mention`, `free_response_channels`, or any other setting.
|
||||
|
||||
```yaml
|
||||
# String format
|
||||
discord:
|
||||
ignored_channels: "1234567890,9876543210"
|
||||
|
||||
# List format
|
||||
discord:
|
||||
ignored_channels:
|
||||
- 1234567890
|
||||
- 9876543210
|
||||
```
|
||||
|
||||
If a thread's parent channel is in this list, messages in that thread are also ignored.
|
||||
|
||||
#### `discord.no_thread_channels`
|
||||
|
||||
**Type:** string or list — **Default:** `[]`
|
||||
|
||||
Channel IDs where the bot responds directly in the channel instead of auto-creating a thread. This only has an effect when `auto_thread` is `true` (the default). In these channels, the bot responds inline like a normal message rather than spawning a new thread.
|
||||
|
||||
```yaml
|
||||
discord:
|
||||
no_thread_channels:
|
||||
- 1234567890 # Bot responds inline here
|
||||
```
|
||||
|
||||
Useful for channels dedicated to bot interaction where threads would add unnecessary noise.
|
||||
|
||||
#### `group_sessions_per_user`
|
||||
|
||||
**Type:** boolean — **Default:** `true`
|
||||
|
||||
@@ -463,35 +463,6 @@ platforms:
|
||||
You usually don't need to configure this manually. The auto-discovery via DoH handles most restricted-network scenarios. The `TELEGRAM_FALLBACK_IPS` env var is only needed if DoH is also blocked on your network.
|
||||
:::
|
||||
|
||||
## Message Reactions
|
||||
|
||||
The bot can add emoji reactions to messages as visual processing feedback:
|
||||
|
||||
- 👀 when the bot starts processing your message
|
||||
- ✅ when the response is delivered successfully
|
||||
- ❌ if an error occurs during processing
|
||||
|
||||
Reactions are **disabled by default**. Enable them in `config.yaml`:
|
||||
|
||||
```yaml
|
||||
telegram:
|
||||
reactions: true
|
||||
```
|
||||
|
||||
Or via environment variable:
|
||||
|
||||
```bash
|
||||
TELEGRAM_REACTIONS=true
|
||||
```
|
||||
|
||||
:::note
|
||||
Unlike Discord (where reactions are additive), Telegram's Bot API replaces all bot reactions in a single call. The transition from 👀 to ✅/❌ happens atomically — you won't see both at once.
|
||||
:::
|
||||
|
||||
:::tip
|
||||
If the bot doesn't have permission to add reactions in a group, the reaction calls fail silently and message processing continues normally.
|
||||
:::
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Solution |
|
||||
|
||||
Reference in New Issue
Block a user