Compare commits
10 Commits
fix/toolse
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9381272cfc | ||
|
|
48b5bc6038 | ||
|
|
4ff73fb32c | ||
|
|
73a88a02fe | ||
|
|
f9c2565ab4 | ||
|
|
ad5f973a8d | ||
|
|
0791efe2c3 | ||
|
|
934fbe3c06 | ||
|
|
6302e56e7c | ||
|
|
868b3c07e3 |
6
cli.py
6
cli.py
@@ -301,7 +301,11 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
defaults["agent"]["max_turns"] = file_config["max_turns"]
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load cli-config.yaml: %s", e)
|
||||
|
||||
|
||||
# Expand ${ENV_VAR} references in config values before bridging to env vars.
|
||||
from hermes_cli.config import _expand_env_vars
|
||||
defaults = _expand_env_vars(defaults)
|
||||
|
||||
# Apply terminal config to environment variables (so terminal_tool picks them up)
|
||||
terminal_config = defaults.get("terminal", {})
|
||||
|
||||
|
||||
@@ -523,8 +523,13 @@ def load_gateway_config() -> GatewayConfig:
|
||||
os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc)
|
||||
if "auto_thread" in discord_cfg and not os.getenv("DISCORD_AUTO_THREAD"):
|
||||
os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to process config.yaml — falling back to .env / gateway.json values. "
|
||||
"Check %s for syntax errors. Error: %s",
|
||||
_home / "config.yaml",
|
||||
e,
|
||||
)
|
||||
|
||||
config = GatewayConfig.from_dict(gw_data)
|
||||
|
||||
|
||||
@@ -93,6 +93,9 @@ if _config_path.exists():
|
||||
import yaml as _yaml
|
||||
with open(_config_path, encoding="utf-8") as _f:
|
||||
_cfg = _yaml.safe_load(_f) or {}
|
||||
# Expand ${ENV_VAR} references before bridging to env vars.
|
||||
from hermes_cli.config import _expand_env_vars
|
||||
_cfg = _expand_env_vars(_cfg)
|
||||
# Top-level simple values (fallback only — don't override .env)
|
||||
for _key, _val in _cfg.items():
|
||||
if isinstance(_val, (str, int, float, bool)) and _key not in os.environ:
|
||||
@@ -525,6 +528,12 @@ class GatewayRunner:
|
||||
Synchronous worker — meant to be called via run_in_executor from
|
||||
an async context so it doesn't block the event loop.
|
||||
"""
|
||||
# Skip cron sessions — they run headless with no meaningful user
|
||||
# conversation to extract memories from.
|
||||
if old_session_id and old_session_id.startswith("cron_"):
|
||||
logger.debug("Skipping memory flush for cron session: %s", old_session_id)
|
||||
return
|
||||
|
||||
try:
|
||||
history = self.session_store.load_transcript(old_session_id)
|
||||
if not history or len(history) < 4:
|
||||
@@ -557,6 +566,23 @@ class GatewayRunner:
|
||||
if m.get("role") in ("user", "assistant") and m.get("content")
|
||||
]
|
||||
|
||||
# Read live memory state from disk so the flush agent can see
|
||||
# what's already saved and avoid overwriting newer entries.
|
||||
_current_memory = ""
|
||||
try:
|
||||
from tools.memory_tool import MEMORY_DIR
|
||||
for fname, label in [
|
||||
("MEMORY.md", "MEMORY (your personal notes)"),
|
||||
("USER.md", "USER PROFILE (who the user is)"),
|
||||
]:
|
||||
fpath = MEMORY_DIR / fname
|
||||
if fpath.exists():
|
||||
content = fpath.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
_current_memory += f"\n\n## Current {label}:\n{content}"
|
||||
except Exception:
|
||||
pass # Non-fatal — flush still works, just without the guard
|
||||
|
||||
# Give the agent a real turn to think about what to save
|
||||
flush_prompt = (
|
||||
"[System: This session is about to be automatically reset due to "
|
||||
@@ -568,6 +594,20 @@ class GatewayRunner:
|
||||
"2. If you discovered a reusable workflow or solved a non-trivial "
|
||||
"problem, consider saving it as a skill.\n"
|
||||
"3. If nothing is worth saving, that's fine — just skip.\n\n"
|
||||
)
|
||||
|
||||
if _current_memory:
|
||||
flush_prompt += (
|
||||
"IMPORTANT — here is the current live state of memory. Other "
|
||||
"sessions, cron jobs, or the user may have updated it since this "
|
||||
"conversation ended. Do NOT overwrite or remove entries unless "
|
||||
"the conversation above reveals something that genuinely "
|
||||
"supersedes them. Only add new information that is not already "
|
||||
"captured below."
|
||||
f"{_current_memory}\n\n"
|
||||
)
|
||||
|
||||
flush_prompt += (
|
||||
"Do NOT respond to the user. Just use the memory and skill_manage "
|
||||
"tools if needed, then stop.]"
|
||||
)
|
||||
@@ -904,7 +944,9 @@ class GatewayRunner:
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
|
||||
|
||||
@@ -1172,6 +1172,26 @@ def _deep_merge(base: dict, override: dict) -> dict:
|
||||
return result
|
||||
|
||||
|
||||
def _expand_env_vars(obj):
|
||||
"""Recursively expand ``${VAR}`` references in config values.
|
||||
|
||||
Only string values are processed; dict keys, numbers, booleans, and
|
||||
None are left untouched. Unresolved references (variable not in
|
||||
``os.environ``) are kept verbatim so callers can detect them.
|
||||
"""
|
||||
if isinstance(obj, str):
|
||||
return re.sub(
|
||||
r"\${([^}]+)}",
|
||||
lambda m: os.environ.get(m.group(1), m.group(0)),
|
||||
obj,
|
||||
)
|
||||
if isinstance(obj, dict):
|
||||
return {k: _expand_env_vars(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_expand_env_vars(item) for item in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def _normalize_max_turns_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Normalize legacy root-level max_turns into agent.max_turns."""
|
||||
config = dict(config)
|
||||
@@ -1213,7 +1233,7 @@ def load_config() -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load config: {e}")
|
||||
|
||||
return _normalize_max_turns_config(config)
|
||||
return _expand_env_vars(_normalize_max_turns_config(config))
|
||||
|
||||
|
||||
_SECURITY_COMMENT = """
|
||||
|
||||
@@ -391,18 +391,29 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
default_ts = PLATFORMS[platform]["default_toolset"]
|
||||
toolset_names = [default_ts]
|
||||
|
||||
# Resolve to individual tool names, then map back to which
|
||||
# configurable toolsets are covered
|
||||
all_tool_names = set()
|
||||
for ts_name in toolset_names:
|
||||
all_tool_names.update(resolve_toolset(ts_name))
|
||||
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
|
||||
# Map individual tool names back to configurable toolset keys
|
||||
enabled_toolsets = set()
|
||||
for ts_key, _, _ in CONFIGURABLE_TOOLSETS:
|
||||
ts_tools = set(resolve_toolset(ts_key))
|
||||
if ts_tools and ts_tools.issubset(all_tool_names):
|
||||
enabled_toolsets.add(ts_key)
|
||||
# If the saved list contains any configurable keys directly, the user
|
||||
# has explicitly configured this platform — use direct membership.
|
||||
# This avoids the subset-inference bug where composite toolsets like
|
||||
# "hermes-cli" (which include all _HERMES_CORE_TOOLS) cause disabled
|
||||
# toolsets to re-appear as enabled.
|
||||
has_explicit_config = any(ts in configurable_keys for ts in toolset_names)
|
||||
|
||||
if has_explicit_config:
|
||||
enabled_toolsets = {ts for ts in toolset_names if ts in configurable_keys}
|
||||
else:
|
||||
# No explicit config — fall back to resolving composite toolset names
|
||||
# (e.g. "hermes-cli") to individual tool names and reverse-mapping.
|
||||
all_tool_names = set()
|
||||
for ts_name in toolset_names:
|
||||
all_tool_names.update(resolve_toolset(ts_name))
|
||||
|
||||
enabled_toolsets = set()
|
||||
for ts_key, _, _ in CONFIGURABLE_TOOLSETS:
|
||||
ts_tools = set(resolve_toolset(ts_key))
|
||||
if ts_tools and ts_tools.issubset(all_tool_names):
|
||||
enabled_toolsets.add(ts_key)
|
||||
|
||||
# Plugin toolsets: enabled by default unless explicitly disabled.
|
||||
# A plugin toolset is "known" for a platform once `hermes tools`
|
||||
@@ -437,15 +448,21 @@ def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[
|
||||
plugin_keys = _get_plugin_toolset_keys()
|
||||
configurable_keys |= plugin_keys
|
||||
|
||||
# Also exclude platform default toolsets (hermes-cli, hermes-telegram, etc.)
|
||||
# These are "super" toolsets that resolve to ALL tools, so preserving them
|
||||
# would silently override the user's unchecked selections on the next read.
|
||||
platform_default_keys = {p["default_toolset"] for p in PLATFORMS.values()}
|
||||
|
||||
# Get existing toolsets for this platform
|
||||
existing_toolsets = config.get("platform_toolsets", {}).get(platform, [])
|
||||
if not isinstance(existing_toolsets, list):
|
||||
existing_toolsets = []
|
||||
|
||||
# Preserve any entries that are NOT configurable toolsets (i.e. MCP server names)
|
||||
# Preserve any entries that are NOT configurable toolsets and NOT platform
|
||||
# defaults (i.e. only MCP server names should be preserved)
|
||||
preserved_entries = {
|
||||
entry for entry in existing_toolsets
|
||||
if entry not in configurable_keys
|
||||
if entry not in configurable_keys and entry not in platform_default_keys
|
||||
}
|
||||
|
||||
# Merge preserved entries with new enabled toolsets
|
||||
|
||||
167
tests/gateway/test_flush_memory_stale_guard.py
Normal file
167
tests/gateway/test_flush_memory_stale_guard.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Tests for memory flush stale-overwrite prevention (#2670).
|
||||
|
||||
Verifies that:
|
||||
1. Cron sessions are skipped (no flush for headless cron runs)
|
||||
2. Current memory state is injected into the flush prompt so the
|
||||
flush agent can see what's already saved and avoid overwrites
|
||||
3. The flush still works normally when memory files don't exist
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._honcho_managers = {}
|
||||
runner._honcho_configs = {}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner.adapters = {}
|
||||
runner.hooks = MagicMock()
|
||||
runner.session_store = MagicMock()
|
||||
return runner
|
||||
|
||||
|
||||
_TRANSCRIPT_4_MSGS = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
{"role": "user", "content": "remember my name is Alice"},
|
||||
{"role": "assistant", "content": "Got it, Alice!"},
|
||||
]
|
||||
|
||||
|
||||
class TestCronSessionBypass:
|
||||
"""Cron sessions should never trigger a memory flush."""
|
||||
|
||||
def test_cron_session_skipped(self):
|
||||
runner = _make_runner()
|
||||
runner._flush_memories_for_session("cron_job123_20260323_120000")
|
||||
# session_store.load_transcript should never be called
|
||||
runner.session_store.load_transcript.assert_not_called()
|
||||
|
||||
def test_cron_session_with_honcho_key_skipped(self):
|
||||
runner = _make_runner()
|
||||
runner._flush_memories_for_session("cron_daily_20260323", "some-honcho-key")
|
||||
runner.session_store.load_transcript.assert_not_called()
|
||||
|
||||
def test_non_cron_session_proceeds(self):
|
||||
"""Non-cron sessions should still attempt the flush."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner._flush_memories_for_session("session_abc123")
|
||||
runner.session_store.load_transcript.assert_called_once_with("session_abc123")
|
||||
|
||||
|
||||
class TestMemoryInjection:
|
||||
"""The flush prompt should include current memory state from disk."""
|
||||
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path):
|
||||
"""When memory files exist, their content appears in the flush prompt."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
|
||||
(memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Intercept `from tools.memory_tool import MEMORY_DIR` inside the function
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_123")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
call_kwargs = tmp_agent.run_conversation.call_args.kwargs
|
||||
flush_prompt = call_kwargs.get("user_message", "")
|
||||
|
||||
# Verify both memory sections appear in the prompt
|
||||
assert "Agent knows Python" in flush_prompt
|
||||
assert "User prefers dark mode" in flush_prompt
|
||||
assert "Name: Alice" in flush_prompt
|
||||
assert "Timezone: PST" in flush_prompt
|
||||
# Verify the stale-overwrite warning is present
|
||||
assert "Do NOT overwrite or remove entries" in flush_prompt
|
||||
assert "current live state of memory" in flush_prompt
|
||||
|
||||
def test_flush_works_without_memory_files(self, tmp_path):
|
||||
"""When no memory files exist, flush still runs without the guard."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
empty_dir = tmp_path / "no_memories"
|
||||
empty_dir.mkdir()
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=empty_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_456")
|
||||
|
||||
# Should still run, just without the memory guard section
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "Do NOT overwrite or remove entries" not in flush_prompt
|
||||
assert "Review the conversation above" in flush_prompt
|
||||
|
||||
def test_empty_memory_files_no_injection(self, tmp_path):
|
||||
"""Empty memory files should not trigger the guard section."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("")
|
||||
(memory_dir / "USER.md").write_text(" \n ") # whitespace only
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_789")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
# No memory content → no guard section
|
||||
assert "current live state of memory" not in flush_prompt
|
||||
|
||||
|
||||
class TestFlushPromptStructure:
|
||||
"""Verify the flush prompt retains its core instructions."""
|
||||
|
||||
def test_core_instructions_present(self):
|
||||
"""The flush prompt should still contain the original guidance."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Make the import fail gracefully so we test without memory files
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_struct")
|
||||
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "automatically reset" in flush_prompt
|
||||
assert "Save any important facts" in flush_prompt
|
||||
assert "consider saving it as a skill" in flush_prompt
|
||||
assert "Do NOT respond to the user" in flush_prompt
|
||||
@@ -100,3 +100,107 @@ def test_save_platform_tools_handles_invalid_existing_config():
|
||||
|
||||
saved_toolsets = config["platform_toolsets"]["cli"]
|
||||
assert "web" in saved_toolsets
|
||||
|
||||
|
||||
def test_save_platform_tools_does_not_preserve_platform_default_toolsets():
|
||||
"""Platform default toolsets (hermes-cli, hermes-telegram, etc.) must NOT
|
||||
be preserved across saves.
|
||||
|
||||
These "super" toolsets resolve to ALL tools, so if they survive in the
|
||||
config, they silently override any tools the user unchecked. Previously,
|
||||
the preserve filter only excluded configurable toolset keys (web, browser,
|
||||
terminal, etc.) and treated platform defaults as unknown custom entries
|
||||
(like MCP server names), causing them to be kept unconditionally.
|
||||
|
||||
Regression test: user unchecks image_gen and homeassistant via
|
||||
``hermes tools``, but hermes-cli stays in the config and re-enables
|
||||
everything on the next read.
|
||||
"""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"cli": [
|
||||
"browser", "clarify", "code_execution", "cronjob",
|
||||
"delegation", "file", "hermes-cli", # <-- the culprit
|
||||
"memory", "session_search", "skills", "terminal",
|
||||
"todo", "tts", "vision", "web",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# User unchecks image_gen, homeassistant, moa — keeps the rest
|
||||
new_selection = {
|
||||
"browser", "clarify", "code_execution", "cronjob",
|
||||
"delegation", "file", "memory", "session_search",
|
||||
"skills", "terminal", "todo", "tts", "vision", "web",
|
||||
}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "cli", new_selection)
|
||||
|
||||
saved = config["platform_toolsets"]["cli"]
|
||||
|
||||
# hermes-cli must NOT survive — it's a platform default, not an MCP server
|
||||
assert "hermes-cli" not in saved
|
||||
|
||||
# The individual toolset keys the user selected must be present
|
||||
assert "web" in saved
|
||||
assert "terminal" in saved
|
||||
assert "browser" in saved
|
||||
|
||||
# Tools the user unchecked must NOT be present
|
||||
assert "image_gen" not in saved
|
||||
assert "homeassistant" not in saved
|
||||
assert "moa" not in saved
|
||||
|
||||
|
||||
def test_save_platform_tools_does_not_preserve_hermes_telegram():
|
||||
"""Same bug for Telegram — hermes-telegram must not be preserved."""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"telegram": [
|
||||
"browser", "file", "hermes-telegram", "terminal", "web",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
new_selection = {"browser", "file", "terminal", "web"}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "telegram", new_selection)
|
||||
|
||||
saved = config["platform_toolsets"]["telegram"]
|
||||
assert "hermes-telegram" not in saved
|
||||
assert "web" in saved
|
||||
|
||||
|
||||
def test_save_platform_tools_still_preserves_mcp_with_platform_default_present():
|
||||
"""MCP server names must still be preserved even when platform defaults
|
||||
are being stripped out."""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"cli": [
|
||||
"web", "terminal", "hermes-cli", "my-mcp-server", "github-tools",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
new_selection = {"web", "browser"}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "cli", new_selection)
|
||||
|
||||
saved = config["platform_toolsets"]["cli"]
|
||||
|
||||
# MCP servers preserved
|
||||
assert "my-mcp-server" in saved
|
||||
assert "github-tools" in saved
|
||||
|
||||
# Platform default stripped
|
||||
assert "hermes-cli" not in saved
|
||||
|
||||
# User selections present
|
||||
assert "web" in saved
|
||||
assert "browser" in saved
|
||||
|
||||
# Deselected configurable toolset removed
|
||||
assert "terminal" not in saved
|
||||
|
||||
132
tests/test_config_env_expansion.py
Normal file
132
tests/test_config_env_expansion.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for ${ENV_VAR} substitution in config.yaml values."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from hermes_cli.config import _expand_env_vars, load_config
|
||||
from unittest.mock import patch as mock_patch
|
||||
|
||||
|
||||
class TestExpandEnvVars:
|
||||
def test_simple_substitution(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("MY_KEY", "secret123")
|
||||
assert _expand_env_vars("${MY_KEY}") == "secret123"
|
||||
|
||||
def test_missing_var_kept_verbatim(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.delenv("UNDEFINED_VAR_XYZ", raising=False)
|
||||
assert _expand_env_vars("${UNDEFINED_VAR_XYZ}") == "${UNDEFINED_VAR_XYZ}"
|
||||
|
||||
def test_no_placeholder_unchanged(self):
|
||||
assert _expand_env_vars("plain-value") == "plain-value"
|
||||
|
||||
def test_dict_recursive(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("TOKEN", "tok-abc")
|
||||
result = _expand_env_vars({"key": "${TOKEN}", "other": "literal"})
|
||||
assert result == {"key": "tok-abc", "other": "literal"}
|
||||
|
||||
def test_nested_dict(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("API_KEY", "sk-xyz")
|
||||
result = _expand_env_vars({"model": {"api_key": "${API_KEY}"}})
|
||||
assert result["model"]["api_key"] == "sk-xyz"
|
||||
|
||||
def test_list_items(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("VAL", "hello")
|
||||
result = _expand_env_vars(["${VAL}", "literal", 42])
|
||||
assert result == ["hello", "literal", 42]
|
||||
|
||||
def test_non_string_values_untouched(self):
|
||||
assert _expand_env_vars(42) == 42
|
||||
assert _expand_env_vars(3.14) == 3.14
|
||||
assert _expand_env_vars(True) is True
|
||||
assert _expand_env_vars(None) is None
|
||||
|
||||
def test_multiple_placeholders_in_one_string(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("HOST", "localhost")
|
||||
mp.setenv("PORT", "5432")
|
||||
assert _expand_env_vars("${HOST}:${PORT}") == "localhost:5432"
|
||||
|
||||
def test_dict_keys_not_expanded(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("KEY", "value")
|
||||
result = _expand_env_vars({"${KEY}": "no-expand-key"})
|
||||
assert "${KEY}" in result
|
||||
|
||||
|
||||
class TestLoadConfigExpansion:
|
||||
def test_load_config_expands_env_vars(self, tmp_path, monkeypatch):
|
||||
config_yaml = (
|
||||
"model:\n"
|
||||
" api_key: ${GOOGLE_API_KEY}\n"
|
||||
"platforms:\n"
|
||||
" telegram:\n"
|
||||
" token: ${TELEGRAM_BOT_TOKEN}\n"
|
||||
"plain: no-substitution\n"
|
||||
)
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(config_yaml)
|
||||
|
||||
monkeypatch.setenv("GOOGLE_API_KEY", "gsk-test-key")
|
||||
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "1234567:ABC-token")
|
||||
monkeypatch.setattr("hermes_cli.config.get_config_path", lambda: config_file)
|
||||
|
||||
config = load_config()
|
||||
|
||||
assert config["model"]["api_key"] == "gsk-test-key"
|
||||
assert config["platforms"]["telegram"]["token"] == "1234567:ABC-token"
|
||||
assert config["plain"] == "no-substitution"
|
||||
|
||||
def test_load_config_unresolved_kept_verbatim(self, tmp_path, monkeypatch):
|
||||
config_yaml = "model:\n api_key: ${NOT_SET_XYZ_123}\n"
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(config_yaml)
|
||||
|
||||
monkeypatch.delenv("NOT_SET_XYZ_123", raising=False)
|
||||
monkeypatch.setattr("hermes_cli.config.get_config_path", lambda: config_file)
|
||||
|
||||
config = load_config()
|
||||
|
||||
assert config["model"]["api_key"] == "${NOT_SET_XYZ_123}"
|
||||
|
||||
|
||||
class TestLoadCliConfigExpansion:
|
||||
"""Verify that load_cli_config() also expands ${VAR} references."""
|
||||
|
||||
def test_cli_config_expands_auxiliary_api_key(self, tmp_path, monkeypatch):
|
||||
config_yaml = (
|
||||
"auxiliary:\n"
|
||||
" vision:\n"
|
||||
" api_key: ${TEST_VISION_KEY_XYZ}\n"
|
||||
)
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(config_yaml)
|
||||
|
||||
monkeypatch.setenv("TEST_VISION_KEY_XYZ", "vis-key-123")
|
||||
# Patch the hermes home so load_cli_config finds our test config
|
||||
monkeypatch.setattr("cli._hermes_home", tmp_path)
|
||||
|
||||
from cli import load_cli_config
|
||||
config = load_cli_config()
|
||||
|
||||
assert config["auxiliary"]["vision"]["api_key"] == "vis-key-123"
|
||||
|
||||
def test_cli_config_unresolved_kept_verbatim(self, tmp_path, monkeypatch):
|
||||
config_yaml = (
|
||||
"auxiliary:\n"
|
||||
" vision:\n"
|
||||
" api_key: ${UNSET_CLI_VAR_ABC}\n"
|
||||
)
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(config_yaml)
|
||||
|
||||
monkeypatch.delenv("UNSET_CLI_VAR_ABC", raising=False)
|
||||
monkeypatch.setattr("cli._hermes_home", tmp_path)
|
||||
|
||||
from cli import load_cli_config
|
||||
config = load_cli_config()
|
||||
|
||||
assert config["auxiliary"]["vision"]["api_key"] == "${UNSET_CLI_VAR_ABC}"
|
||||
168
tests/tools/test_ansi_strip.py
Normal file
168
tests/tools/test_ansi_strip.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Comprehensive tests for ANSI escape sequence stripping (ECMA-48).
|
||||
|
||||
The strip_ansi function in tools/ansi_strip.py is the source-level fix for
|
||||
ANSI codes leaking into the model's context via terminal/execute_code output.
|
||||
It must strip ALL terminal escape sequences while preserving legitimate text.
|
||||
"""
|
||||
|
||||
from tools.ansi_strip import strip_ansi
|
||||
|
||||
|
||||
class TestStripAnsiBasicSGR:
|
||||
"""Select Graphic Rendition — the most common ANSI sequences."""
|
||||
|
||||
def test_reset(self):
|
||||
assert strip_ansi("\x1b[0m") == ""
|
||||
|
||||
def test_color(self):
|
||||
assert strip_ansi("\x1b[31;1m") == ""
|
||||
|
||||
def test_truecolor_semicolon(self):
|
||||
assert strip_ansi("\x1b[38;2;255;0;0m") == ""
|
||||
|
||||
def test_truecolor_colon_separated(self):
|
||||
"""Modern terminals use colon-separated SGR params."""
|
||||
assert strip_ansi("\x1b[38:2:255:0:0m") == ""
|
||||
assert strip_ansi("\x1b[48:2:0:255:0m") == ""
|
||||
|
||||
|
||||
class TestStripAnsiCSIPrivateMode:
|
||||
"""CSI sequences with ? prefix (DEC private modes)."""
|
||||
|
||||
def test_cursor_show_hide(self):
|
||||
assert strip_ansi("\x1b[?25h") == ""
|
||||
assert strip_ansi("\x1b[?25l") == ""
|
||||
|
||||
def test_alt_screen(self):
|
||||
assert strip_ansi("\x1b[?1049h") == ""
|
||||
assert strip_ansi("\x1b[?1049l") == ""
|
||||
|
||||
def test_bracketed_paste(self):
|
||||
assert strip_ansi("\x1b[?2004h") == ""
|
||||
|
||||
|
||||
class TestStripAnsiCSIIntermediate:
|
||||
"""CSI sequences with intermediate bytes (space, etc.)."""
|
||||
|
||||
def test_cursor_shape(self):
|
||||
assert strip_ansi("\x1b[0 q") == ""
|
||||
assert strip_ansi("\x1b[2 q") == ""
|
||||
assert strip_ansi("\x1b[6 q") == ""
|
||||
|
||||
|
||||
class TestStripAnsiOSC:
|
||||
"""Operating System Command sequences."""
|
||||
|
||||
def test_bel_terminator(self):
|
||||
assert strip_ansi("\x1b]0;title\x07") == ""
|
||||
|
||||
def test_st_terminator(self):
|
||||
assert strip_ansi("\x1b]0;title\x1b\\") == ""
|
||||
|
||||
def test_hyperlink_preserves_text(self):
|
||||
assert strip_ansi(
|
||||
"\x1b]8;;https://example.com\x1b\\click\x1b]8;;\x1b\\"
|
||||
) == "click"
|
||||
|
||||
|
||||
class TestStripAnsiDECPrivate:
|
||||
"""DEC private / Fp escape sequences."""
|
||||
|
||||
def test_save_restore_cursor(self):
|
||||
assert strip_ansi("\x1b7") == ""
|
||||
assert strip_ansi("\x1b8") == ""
|
||||
|
||||
def test_keypad_modes(self):
|
||||
assert strip_ansi("\x1b=") == ""
|
||||
assert strip_ansi("\x1b>") == ""
|
||||
|
||||
|
||||
class TestStripAnsiFe:
|
||||
"""Fe (C1 as 7-bit) escape sequences."""
|
||||
|
||||
def test_reverse_index(self):
|
||||
assert strip_ansi("\x1bM") == ""
|
||||
|
||||
def test_reset_terminal(self):
|
||||
assert strip_ansi("\x1bc") == ""
|
||||
|
||||
def test_index_and_newline(self):
|
||||
assert strip_ansi("\x1bD") == ""
|
||||
assert strip_ansi("\x1bE") == ""
|
||||
|
||||
|
||||
class TestStripAnsiNF:
|
||||
"""nF (character set selection) sequences."""
|
||||
|
||||
def test_charset_selection(self):
|
||||
assert strip_ansi("\x1b(A") == ""
|
||||
assert strip_ansi("\x1b(B") == ""
|
||||
assert strip_ansi("\x1b(0") == ""
|
||||
|
||||
|
||||
class TestStripAnsiDCS:
|
||||
"""Device Control String sequences."""
|
||||
|
||||
def test_dcs(self):
|
||||
assert strip_ansi("\x1bP+q\x1b\\") == ""
|
||||
|
||||
|
||||
class TestStripAnsi8BitC1:
|
||||
"""8-bit C1 control characters."""
|
||||
|
||||
def test_8bit_csi(self):
|
||||
assert strip_ansi("\x9b31m") == ""
|
||||
assert strip_ansi("\x9b38;2;255;0;0m") == ""
|
||||
|
||||
def test_8bit_standalone(self):
|
||||
assert strip_ansi("\x9c") == ""
|
||||
assert strip_ansi("\x9d") == ""
|
||||
assert strip_ansi("\x90") == ""
|
||||
|
||||
|
||||
class TestStripAnsiRealWorld:
|
||||
"""Real-world contamination scenarios from bug reports."""
|
||||
|
||||
def test_colored_shebang(self):
|
||||
"""The original reported bug: shebang corrupted by color codes."""
|
||||
assert strip_ansi(
|
||||
"\x1b[32m#!/usr/bin/env python3\x1b[0m\nprint('hello')"
|
||||
) == "#!/usr/bin/env python3\nprint('hello')"
|
||||
|
||||
def test_stacked_sgr(self):
|
||||
assert strip_ansi(
|
||||
"\x1b[1m\x1b[31m\x1b[42mhello\x1b[0m"
|
||||
) == "hello"
|
||||
|
||||
def test_ansi_mid_code(self):
|
||||
assert strip_ansi(
|
||||
"def foo(\x1b[33m):\x1b[0m\n return 42"
|
||||
) == "def foo():\n return 42"
|
||||
|
||||
|
||||
class TestStripAnsiPassthrough:
|
||||
"""Clean content must pass through unmodified."""
|
||||
|
||||
def test_plain_text(self):
|
||||
assert strip_ansi("normal text") == "normal text"
|
||||
|
||||
def test_empty(self):
|
||||
assert strip_ansi("") == ""
|
||||
|
||||
def test_none(self):
|
||||
assert strip_ansi(None) is None
|
||||
|
||||
def test_whitespace_preserved(self):
|
||||
assert strip_ansi("line1\nline2\ttab") == "line1\nline2\ttab"
|
||||
|
||||
def test_unicode_safe(self):
|
||||
assert strip_ansi("emoji 🎉 and ñ café") == "emoji 🎉 and ñ café"
|
||||
|
||||
def test_backslash_in_code(self):
|
||||
code = "path = 'C:\\\\Users\\\\test'"
|
||||
assert strip_ansi(code) == code
|
||||
|
||||
def test_square_brackets_in_code(self):
|
||||
"""Array indexing must not be confused with CSI."""
|
||||
code = "arr[0] = arr[31]"
|
||||
assert strip_ansi(code) == code
|
||||
259
tests/tools/test_browser_homebrew_paths.py
Normal file
259
tests/tools/test_browser_homebrew_paths.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""Tests for macOS Homebrew PATH discovery in browser_tool.py."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, mock_open
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.browser_tool import (
|
||||
_discover_homebrew_node_dirs,
|
||||
_find_agent_browser,
|
||||
_run_browser_command,
|
||||
_SANE_PATH,
|
||||
)
|
||||
|
||||
|
||||
class TestSanePath:
|
||||
"""Verify _SANE_PATH includes Homebrew directories."""
|
||||
|
||||
def test_includes_homebrew_bin(self):
|
||||
assert "/opt/homebrew/bin" in _SANE_PATH
|
||||
|
||||
def test_includes_homebrew_sbin(self):
|
||||
assert "/opt/homebrew/sbin" in _SANE_PATH
|
||||
|
||||
def test_includes_standard_dirs(self):
|
||||
assert "/usr/local/bin" in _SANE_PATH
|
||||
assert "/usr/bin" in _SANE_PATH
|
||||
assert "/bin" in _SANE_PATH
|
||||
|
||||
|
||||
class TestDiscoverHomebrewNodeDirs:
|
||||
"""Tests for _discover_homebrew_node_dirs()."""
|
||||
|
||||
def test_returns_empty_when_no_homebrew(self):
|
||||
"""Non-macOS systems without /opt/homebrew/opt should return empty."""
|
||||
with patch("os.path.isdir", return_value=False):
|
||||
assert _discover_homebrew_node_dirs() == []
|
||||
|
||||
def test_finds_versioned_node_dirs(self):
|
||||
"""Should discover node@20/bin, node@24/bin etc."""
|
||||
entries = ["node@20", "node@24", "openssl", "node", "python@3.12"]
|
||||
|
||||
def mock_isdir(p):
|
||||
if p == "/opt/homebrew/opt":
|
||||
return True
|
||||
# node@20/bin and node@24/bin exist
|
||||
if p in (
|
||||
"/opt/homebrew/opt/node@20/bin",
|
||||
"/opt/homebrew/opt/node@24/bin",
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
with patch("os.path.isdir", side_effect=mock_isdir), \
|
||||
patch("os.listdir", return_value=entries):
|
||||
result = _discover_homebrew_node_dirs()
|
||||
|
||||
assert len(result) == 2
|
||||
assert "/opt/homebrew/opt/node@20/bin" in result
|
||||
assert "/opt/homebrew/opt/node@24/bin" in result
|
||||
|
||||
def test_excludes_plain_node(self):
|
||||
"""'node' (unversioned) should be excluded — covered by /opt/homebrew/bin."""
|
||||
with patch("os.path.isdir", return_value=True), \
|
||||
patch("os.listdir", return_value=["node"]):
|
||||
result = _discover_homebrew_node_dirs()
|
||||
assert result == []
|
||||
|
||||
def test_handles_oserror_gracefully(self):
|
||||
"""Should return empty list if listdir raises OSError."""
|
||||
with patch("os.path.isdir", return_value=True), \
|
||||
patch("os.listdir", side_effect=OSError("Permission denied")):
|
||||
assert _discover_homebrew_node_dirs() == []
|
||||
|
||||
|
||||
class TestFindAgentBrowser:
|
||||
"""Tests for _find_agent_browser() Homebrew path search."""
|
||||
|
||||
def test_finds_in_current_path(self):
|
||||
"""Should return result from shutil.which if available on current PATH."""
|
||||
with patch("shutil.which", return_value="/usr/local/bin/agent-browser"):
|
||||
assert _find_agent_browser() == "/usr/local/bin/agent-browser"
|
||||
|
||||
def test_finds_in_homebrew_bin(self):
|
||||
"""Should search Homebrew dirs when not found on current PATH."""
|
||||
def mock_which(cmd, path=None):
|
||||
if path and "/opt/homebrew/bin" in path and cmd == "agent-browser":
|
||||
return "/opt/homebrew/bin/agent-browser"
|
||||
return None
|
||||
|
||||
with patch("shutil.which", side_effect=mock_which), \
|
||||
patch("os.path.isdir", return_value=True), \
|
||||
patch(
|
||||
"tools.browser_tool._discover_homebrew_node_dirs",
|
||||
return_value=[],
|
||||
):
|
||||
result = _find_agent_browser()
|
||||
assert result == "/opt/homebrew/bin/agent-browser"
|
||||
|
||||
def test_finds_npx_in_homebrew(self):
|
||||
"""Should find npx in Homebrew paths as a fallback."""
|
||||
def mock_which(cmd, path=None):
|
||||
if cmd == "agent-browser":
|
||||
return None
|
||||
if cmd == "npx":
|
||||
if path and "/opt/homebrew/bin" in path:
|
||||
return "/opt/homebrew/bin/npx"
|
||||
return None
|
||||
return None
|
||||
|
||||
# Mock Path.exists() to prevent the local node_modules check from matching
|
||||
original_path_exists = Path.exists
|
||||
|
||||
def mock_path_exists(self):
|
||||
if "node_modules" in str(self) and "agent-browser" in str(self):
|
||||
return False
|
||||
return original_path_exists(self)
|
||||
|
||||
with patch("shutil.which", side_effect=mock_which), \
|
||||
patch("os.path.isdir", return_value=True), \
|
||||
patch.object(Path, "exists", mock_path_exists), \
|
||||
patch(
|
||||
"tools.browser_tool._discover_homebrew_node_dirs",
|
||||
return_value=[],
|
||||
):
|
||||
result = _find_agent_browser()
|
||||
assert result == "npx agent-browser"
|
||||
|
||||
def test_raises_when_not_found(self):
|
||||
"""Should raise FileNotFoundError when nothing works."""
|
||||
original_path_exists = Path.exists
|
||||
|
||||
def mock_path_exists(self):
|
||||
if "node_modules" in str(self) and "agent-browser" in str(self):
|
||||
return False
|
||||
return original_path_exists(self)
|
||||
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("os.path.isdir", return_value=False), \
|
||||
patch.object(Path, "exists", mock_path_exists), \
|
||||
patch(
|
||||
"tools.browser_tool._discover_homebrew_node_dirs",
|
||||
return_value=[],
|
||||
):
|
||||
with pytest.raises(FileNotFoundError, match="agent-browser CLI not found"):
|
||||
_find_agent_browser()
|
||||
|
||||
|
||||
class TestRunBrowserCommandPathConstruction:
|
||||
"""Verify _run_browser_command() includes Homebrew node dirs in subprocess PATH."""
|
||||
|
||||
def test_subprocess_path_includes_homebrew_node_dirs(self, tmp_path):
|
||||
"""When _discover_homebrew_node_dirs returns dirs, they should appear
|
||||
in the subprocess env PATH passed to Popen."""
|
||||
captured_env = {}
|
||||
|
||||
# Create a mock Popen that captures the env dict
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait.return_value = 0
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
captured_env.update(kwargs.get("env", {}))
|
||||
return mock_proc
|
||||
|
||||
fake_session = {
|
||||
"session_name": "test-session",
|
||||
"session_id": "test-id",
|
||||
"cdp_url": None,
|
||||
}
|
||||
|
||||
# Write fake JSON output to the stdout temp file
|
||||
fake_json = json.dumps({"success": True})
|
||||
stdout_file = tmp_path / "stdout"
|
||||
stdout_file.write_text(fake_json)
|
||||
|
||||
fake_homebrew_dirs = [
|
||||
"/opt/homebrew/opt/node@24/bin",
|
||||
"/opt/homebrew/opt/node@20/bin",
|
||||
]
|
||||
|
||||
# We need os.path.isdir to return True for our fake dirs
|
||||
# but we also need real isdir for tmp_path operations
|
||||
real_isdir = os.path.isdir
|
||||
|
||||
def selective_isdir(p):
|
||||
if p in fake_homebrew_dirs or p.startswith(str(tmp_path)):
|
||||
return True
|
||||
if "/opt/homebrew/" in p:
|
||||
return True # _SANE_PATH dirs
|
||||
return real_isdir(p)
|
||||
|
||||
with patch("tools.browser_tool._find_agent_browser", return_value="/usr/local/bin/agent-browser"), \
|
||||
patch("tools.browser_tool._get_session_info", return_value=fake_session), \
|
||||
patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)), \
|
||||
patch("tools.browser_tool._discover_homebrew_node_dirs", return_value=fake_homebrew_dirs), \
|
||||
patch("os.path.isdir", side_effect=selective_isdir), \
|
||||
patch("subprocess.Popen", side_effect=capture_popen), \
|
||||
patch("os.open", return_value=99), \
|
||||
patch("os.close"), \
|
||||
patch("tools.interrupt.is_interrupted", return_value=False), \
|
||||
patch.dict(os.environ, {"PATH": "/usr/bin:/bin", "HOME": "/home/test"}, clear=True):
|
||||
# The function reads from temp files for stdout/stderr
|
||||
with patch("builtins.open", mock_open(read_data=fake_json)):
|
||||
_run_browser_command("test-task", "navigate", ["https://example.com"])
|
||||
|
||||
# Verify Homebrew node dirs made it into the subprocess PATH
|
||||
result_path = captured_env.get("PATH", "")
|
||||
assert "/opt/homebrew/opt/node@24/bin" in result_path
|
||||
assert "/opt/homebrew/opt/node@20/bin" in result_path
|
||||
assert "/opt/homebrew/bin" in result_path # from _SANE_PATH
|
||||
|
||||
def test_subprocess_path_includes_sane_path_homebrew(self, tmp_path):
|
||||
"""_SANE_PATH Homebrew entries should appear even without versioned node dirs."""
|
||||
captured_env = {}
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait.return_value = 0
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
captured_env.update(kwargs.get("env", {}))
|
||||
return mock_proc
|
||||
|
||||
fake_session = {
|
||||
"session_name": "test-session",
|
||||
"session_id": "test-id",
|
||||
"cdp_url": None,
|
||||
}
|
||||
|
||||
fake_json = json.dumps({"success": True})
|
||||
real_isdir = os.path.isdir
|
||||
|
||||
def selective_isdir(p):
|
||||
if "/opt/homebrew/" in p:
|
||||
return True
|
||||
if p.startswith(str(tmp_path)):
|
||||
return True
|
||||
return real_isdir(p)
|
||||
|
||||
with patch("tools.browser_tool._find_agent_browser", return_value="/usr/local/bin/agent-browser"), \
|
||||
patch("tools.browser_tool._get_session_info", return_value=fake_session), \
|
||||
patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)), \
|
||||
patch("tools.browser_tool._discover_homebrew_node_dirs", return_value=[]), \
|
||||
patch("os.path.isdir", side_effect=selective_isdir), \
|
||||
patch("subprocess.Popen", side_effect=capture_popen), \
|
||||
patch("os.open", return_value=99), \
|
||||
patch("os.close"), \
|
||||
patch("tools.interrupt.is_interrupted", return_value=False), \
|
||||
patch.dict(os.environ, {"PATH": "/usr/bin:/bin", "HOME": "/home/test"}, clear=True):
|
||||
with patch("builtins.open", mock_open(read_data=fake_json)):
|
||||
_run_browser_command("test-task", "navigate", ["https://example.com"])
|
||||
|
||||
result_path = captured_env.get("PATH", "")
|
||||
assert "/opt/homebrew/bin" in result_path
|
||||
assert "/opt/homebrew/sbin" in result_path
|
||||
@@ -309,3 +309,6 @@ class TestSearchHints:
|
||||
raw = search_tool(pattern="foo", offset=50, limit=50)
|
||||
assert "[Hint:" in raw
|
||||
assert "offset=100" in raw
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -288,3 +288,34 @@ class TestBlocklistCoverage:
|
||||
"DAYTONA_API_KEY",
|
||||
}
|
||||
assert extras.issubset(_HERMES_PROVIDER_ENV_BLOCKLIST)
|
||||
|
||||
|
||||
class TestSanePathIncludesHomebrew:
|
||||
"""Verify _SANE_PATH includes macOS Homebrew directories."""
|
||||
|
||||
def test_sane_path_includes_homebrew_bin(self):
|
||||
from tools.environments.local import _SANE_PATH
|
||||
assert "/opt/homebrew/bin" in _SANE_PATH
|
||||
|
||||
def test_sane_path_includes_homebrew_sbin(self):
|
||||
from tools.environments.local import _SANE_PATH
|
||||
assert "/opt/homebrew/sbin" in _SANE_PATH
|
||||
|
||||
def test_make_run_env_appends_homebrew_on_minimal_path(self):
|
||||
"""When PATH is minimal (no /usr/bin), _make_run_env should append
|
||||
_SANE_PATH which now includes Homebrew dirs."""
|
||||
from tools.environments.local import _make_run_env
|
||||
minimal_env = {"PATH": "/some/custom/bin"}
|
||||
with patch.dict(os.environ, minimal_env, clear=True):
|
||||
result = _make_run_env({})
|
||||
assert "/opt/homebrew/bin" in result["PATH"]
|
||||
assert "/opt/homebrew/sbin" in result["PATH"]
|
||||
|
||||
def test_make_run_env_does_not_duplicate_on_full_path(self):
|
||||
"""When PATH already has /usr/bin, _make_run_env should not append."""
|
||||
from tools.environments.local import _make_run_env
|
||||
full_env = {"PATH": "/usr/bin:/bin"}
|
||||
with patch.dict(os.environ, full_env, clear=True):
|
||||
result = _make_run_env({})
|
||||
# Should keep existing PATH unchanged
|
||||
assert result["PATH"] == "/usr/bin:/bin"
|
||||
|
||||
176
tests/tools/test_url_safety.py
Normal file
176
tests/tools/test_url_safety.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Tests for SSRF protection in url_safety module."""
|
||||
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
from tools.url_safety import is_safe_url, _is_blocked_ip
|
||||
|
||||
import ipaddress
|
||||
import pytest
|
||||
|
||||
|
||||
class TestIsSafeUrl:
|
||||
def test_public_url_allowed(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("93.184.216.34", 0)),
|
||||
]):
|
||||
assert is_safe_url("https://example.com/image.png") is True
|
||||
|
||||
def test_localhost_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("127.0.0.1", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://localhost:8080/secret") is False
|
||||
|
||||
def test_loopback_ip_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("127.0.0.1", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://127.0.0.1/admin") is False
|
||||
|
||||
def test_private_10_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("10.0.0.1", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://internal-service.local/api") is False
|
||||
|
||||
def test_private_172_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("172.16.0.1", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://private.corp/data") is False
|
||||
|
||||
def test_private_192_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("192.168.1.1", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://router.local") is False
|
||||
|
||||
def test_link_local_169_254_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("169.254.169.254", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://169.254.169.254/latest/meta-data/") is False
|
||||
|
||||
def test_metadata_google_internal_blocked(self):
|
||||
assert is_safe_url("http://metadata.google.internal/computeMetadata/v1/") is False
|
||||
|
||||
def test_ipv6_loopback_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(10, 1, 6, "", ("::1", 0, 0, 0)),
|
||||
]):
|
||||
assert is_safe_url("http://[::1]:8080/") is False
|
||||
|
||||
def test_dns_failure_blocked(self):
|
||||
"""DNS failures now fail closed — block the request."""
|
||||
with patch("socket.getaddrinfo", side_effect=socket.gaierror("Name resolution failed")):
|
||||
assert is_safe_url("https://nonexistent.example.com") is False
|
||||
|
||||
def test_empty_url_blocked(self):
|
||||
assert is_safe_url("") is False
|
||||
|
||||
def test_no_hostname_blocked(self):
|
||||
assert is_safe_url("http://") is False
|
||||
|
||||
def test_public_ip_allowed(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("93.184.216.34", 0)),
|
||||
]):
|
||||
assert is_safe_url("https://example.com") is True
|
||||
|
||||
# ── New tests for hardened SSRF protection ──
|
||||
|
||||
def test_cgnat_100_64_blocked(self):
|
||||
"""100.64.0.0/10 (CGNAT/Shared Address Space) is NOT covered by
|
||||
ipaddress.is_private — must be blocked explicitly."""
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("100.64.0.1", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://some-cgnat-host.example/") is False
|
||||
|
||||
def test_cgnat_100_127_blocked(self):
|
||||
"""Upper end of CGNAT range (100.127.255.255)."""
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("100.127.255.254", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://tailscale-peer.example/") is False
|
||||
|
||||
def test_multicast_blocked(self):
|
||||
"""Multicast addresses (224.0.0.0/4) not caught by is_private."""
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("224.0.0.251", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://mdns-host.local/") is False
|
||||
|
||||
def test_multicast_ipv6_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(10, 1, 6, "", ("ff02::1", 0, 0, 0)),
|
||||
]):
|
||||
assert is_safe_url("http://[ff02::1]/") is False
|
||||
|
||||
def test_ipv4_mapped_ipv6_loopback_blocked(self):
|
||||
"""::ffff:127.0.0.1 — IPv4-mapped IPv6 loopback."""
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(10, 1, 6, "", ("::ffff:127.0.0.1", 0, 0, 0)),
|
||||
]):
|
||||
assert is_safe_url("http://[::ffff:127.0.0.1]/") is False
|
||||
|
||||
def test_ipv4_mapped_ipv6_metadata_blocked(self):
|
||||
"""::ffff:169.254.169.254 — IPv4-mapped IPv6 cloud metadata."""
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(10, 1, 6, "", ("::ffff:169.254.169.254", 0, 0, 0)),
|
||||
]):
|
||||
assert is_safe_url("http://[::ffff:169.254.169.254]/") is False
|
||||
|
||||
def test_unspecified_address_blocked(self):
|
||||
"""0.0.0.0 — unspecified address, can bind to all interfaces."""
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("0.0.0.0", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://0.0.0.0/") is False
|
||||
|
||||
def test_unexpected_error_fails_closed(self):
|
||||
"""Unexpected exceptions should block, not allow."""
|
||||
with patch("tools.url_safety.urlparse", side_effect=ValueError("bad url")):
|
||||
assert is_safe_url("http://evil.com/") is False
|
||||
|
||||
def test_metadata_goog_blocked(self):
|
||||
assert is_safe_url("http://metadata.goog/computeMetadata/v1/") is False
|
||||
|
||||
def test_ipv6_unique_local_blocked(self):
|
||||
"""fc00::/7 — IPv6 unique local addresses."""
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(10, 1, 6, "", ("fd12::1", 0, 0, 0)),
|
||||
]):
|
||||
assert is_safe_url("http://[fd12::1]/internal") is False
|
||||
|
||||
def test_non_cgnat_100_allowed(self):
|
||||
"""100.0.0.1 is NOT in CGNAT range (100.64.0.0/10), should be allowed."""
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("100.0.0.1", 0)),
|
||||
]):
|
||||
# 100.0.0.1 is a global IP, not in CGNAT range
|
||||
assert is_safe_url("http://legit-host.example/") is True
|
||||
|
||||
|
||||
class TestIsBlockedIp:
|
||||
"""Direct tests for the _is_blocked_ip helper."""
|
||||
|
||||
@pytest.mark.parametrize("ip_str", [
|
||||
"127.0.0.1", "10.0.0.1", "172.16.0.1", "192.168.1.1",
|
||||
"169.254.169.254", "0.0.0.0", "224.0.0.1", "255.255.255.255",
|
||||
"100.64.0.1", "100.100.100.100", "100.127.255.254",
|
||||
"::1", "fe80::1", "fc00::1", "fd12::1", "ff02::1",
|
||||
"::ffff:127.0.0.1", "::ffff:169.254.169.254",
|
||||
])
|
||||
def test_blocked_ips(self, ip_str):
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
assert _is_blocked_ip(ip) is True, f"{ip_str} should be blocked"
|
||||
|
||||
@pytest.mark.parametrize("ip_str", [
|
||||
"8.8.8.8", "93.184.216.34", "1.1.1.1", "100.0.0.1",
|
||||
"2606:4700::1", "2001:4860:4860::8888",
|
||||
])
|
||||
def test_allowed_ips(self, ip_str):
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
assert _is_blocked_ip(ip) is False, f"{ip_str} should be allowed"
|
||||
@@ -33,17 +33,30 @@ class TestValidateImageUrl:
|
||||
assert _validate_image_url("https://example.com/image.jpg") is True
|
||||
|
||||
def test_valid_http_url(self):
|
||||
assert _validate_image_url("http://cdn.example.org/photo.png") is True
|
||||
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("93.184.216.34", 0)),
|
||||
]):
|
||||
assert _validate_image_url("http://cdn.example.org/photo.png") is True
|
||||
|
||||
def test_valid_url_without_extension(self):
|
||||
"""CDN endpoints that redirect to images should still pass."""
|
||||
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
|
||||
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("93.184.216.34", 0)),
|
||||
]):
|
||||
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
|
||||
|
||||
def test_valid_url_with_query_params(self):
|
||||
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
|
||||
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("93.184.216.34", 0)),
|
||||
]):
|
||||
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
|
||||
|
||||
def test_localhost_url_blocked_by_ssrf(self):
|
||||
"""localhost URLs are now blocked by SSRF protection."""
|
||||
assert _validate_image_url("http://localhost:8080/image.png") is False
|
||||
|
||||
def test_valid_url_with_port(self):
|
||||
assert _validate_image_url("http://localhost:8080/image.png") is True
|
||||
assert _validate_image_url("http://example.com:8080/image.png") is True
|
||||
|
||||
def test_valid_url_with_path_only(self):
|
||||
assert _validate_image_url("https://example.com/") is True
|
||||
|
||||
@@ -343,6 +343,8 @@ def test_browser_navigate_allows_when_shared_file_missing(monkeypatch, tmp_path)
|
||||
async def test_web_extract_short_circuits_blocked_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr(
|
||||
web_tools,
|
||||
"check_website_access",
|
||||
@@ -389,6 +391,9 @@ def test_check_website_access_fails_open_on_malformed_config(tmp_path, monkeypat
|
||||
async def test_web_extract_blocks_redirected_final_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
|
||||
def fake_check(url):
|
||||
if url == "https://allowed.test":
|
||||
return None
|
||||
@@ -428,6 +433,8 @@ async def test_web_crawl_short_circuits_blocked_url(monkeypatch):
|
||||
|
||||
# web_crawl_tool checks for Firecrawl env before website policy
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr(
|
||||
web_tools,
|
||||
"check_website_access",
|
||||
@@ -457,6 +464,8 @@ async def test_web_crawl_blocks_redirected_final_url(monkeypatch):
|
||||
|
||||
# web_crawl_tool checks for Firecrawl env before website policy
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
|
||||
def fake_check(url):
|
||||
if url == "https://allowed.test":
|
||||
|
||||
44
tools/ansi_strip.py
Normal file
44
tools/ansi_strip.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Strip ANSI escape sequences from subprocess output.
|
||||
|
||||
Used by terminal_tool, code_execution_tool, and process_registry to clean
|
||||
command output before returning it to the model. This prevents ANSI codes
|
||||
from entering the model's context — which is the root cause of models
|
||||
copying escape sequences into file writes.
|
||||
|
||||
Covers the full ECMA-48 spec: CSI (including private-mode ``?`` prefix,
|
||||
colon-separated params, intermediate bytes), OSC (BEL and ST terminators),
|
||||
DCS/SOS/PM/APC string sequences, nF multi-byte escapes, Fp/Fe/Fs
|
||||
single-byte escapes, and 8-bit C1 control characters.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
_ANSI_ESCAPE_RE = re.compile(
|
||||
r"\x1b"
|
||||
r"(?:"
|
||||
r"\[[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # CSI sequence
|
||||
r"|\][\s\S]*?(?:\x07|\x1b\\)" # OSC (BEL or ST terminator)
|
||||
r"|[PX^_][\s\S]*?(?:\x1b\\)" # DCS/SOS/PM/APC strings
|
||||
r"|[\x20-\x2f]+[\x30-\x7e]" # nF escape sequences
|
||||
r"|[\x30-\x7e]" # Fp/Fe/Fs single-byte
|
||||
r")"
|
||||
r"|\x9b[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # 8-bit CSI
|
||||
r"|\x9d[\s\S]*?(?:\x07|\x9c)" # 8-bit OSC
|
||||
r"|[\x80-\x9f]", # Other 8-bit C1 controls
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Fast-path check — skip full regex when no escape-like bytes are present.
|
||||
_HAS_ESCAPE = re.compile(r"[\x1b\x80-\x9f]")
|
||||
|
||||
|
||||
def strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape sequences from text.
|
||||
|
||||
Returns the input unchanged (fast path) when no ESC or C1 bytes are
|
||||
present. Safe to call on any string — clean text passes through
|
||||
with negligible overhead.
|
||||
"""
|
||||
if not text or not _HAS_ESCAPE.search(text):
|
||||
return text
|
||||
return _ANSI_ESCAPE_RE.sub("", text)
|
||||
@@ -76,8 +76,35 @@ from tools.browser_providers.browser_use import BrowserUseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Standard PATH entries for environments with minimal PATH (e.g. systemd services)
|
||||
_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
# Standard PATH entries for environments with minimal PATH (e.g. systemd services).
|
||||
# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon).
|
||||
_SANE_PATH = (
|
||||
"/opt/homebrew/bin:/opt/homebrew/sbin:"
|
||||
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
)
|
||||
|
||||
|
||||
def _discover_homebrew_node_dirs() -> list[str]:
|
||||
"""Find Homebrew versioned Node.js bin directories (e.g. node@20, node@24).
|
||||
|
||||
When Node is installed via ``brew install node@24`` and NOT linked into
|
||||
/opt/homebrew/bin, the binary lives only in /opt/homebrew/opt/node@24/bin/.
|
||||
This function discovers those paths so they can be added to subprocess PATH.
|
||||
"""
|
||||
dirs: list[str] = []
|
||||
homebrew_opt = "/opt/homebrew/opt"
|
||||
if not os.path.isdir(homebrew_opt):
|
||||
return dirs
|
||||
try:
|
||||
for entry in os.listdir(homebrew_opt):
|
||||
if entry.startswith("node") and entry != "node":
|
||||
# e.g. node@20, node@24
|
||||
bin_dir = os.path.join(homebrew_opt, entry, "bin")
|
||||
if os.path.isdir(bin_dir):
|
||||
dirs.append(bin_dir)
|
||||
except OSError:
|
||||
pass
|
||||
return dirs
|
||||
|
||||
# Throttle screenshot cleanup to avoid repeated full directory scans.
|
||||
_last_screenshot_cleanup_by_dir: dict[str, float] = {}
|
||||
@@ -619,7 +646,8 @@ def _find_agent_browser() -> str:
|
||||
"""
|
||||
Find the agent-browser CLI executable.
|
||||
|
||||
Checks in order: PATH, local node_modules/.bin/, npx fallback.
|
||||
Checks in order: current PATH, Homebrew/common bin dirs, Hermes-managed
|
||||
node, local node_modules/.bin/, npx fallback.
|
||||
|
||||
Returns:
|
||||
Path to agent-browser executable
|
||||
@@ -632,15 +660,36 @@ def _find_agent_browser() -> str:
|
||||
which_result = shutil.which("agent-browser")
|
||||
if which_result:
|
||||
return which_result
|
||||
|
||||
|
||||
# Build an extended search PATH including Homebrew and Hermes-managed dirs.
|
||||
# This covers macOS where the process PATH may not include Homebrew paths.
|
||||
extra_dirs: list[str] = []
|
||||
for d in ["/opt/homebrew/bin", "/usr/local/bin"]:
|
||||
if os.path.isdir(d):
|
||||
extra_dirs.append(d)
|
||||
extra_dirs.extend(_discover_homebrew_node_dirs())
|
||||
|
||||
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
hermes_node_bin = str(hermes_home / "node" / "bin")
|
||||
if os.path.isdir(hermes_node_bin):
|
||||
extra_dirs.append(hermes_node_bin)
|
||||
|
||||
if extra_dirs:
|
||||
extended_path = os.pathsep.join(extra_dirs)
|
||||
which_result = shutil.which("agent-browser", path=extended_path)
|
||||
if which_result:
|
||||
return which_result
|
||||
|
||||
# Check local node_modules/.bin/ (npm install in repo root)
|
||||
repo_root = Path(__file__).parent.parent
|
||||
local_bin = repo_root / "node_modules" / ".bin" / "agent-browser"
|
||||
if local_bin.exists():
|
||||
return str(local_bin)
|
||||
|
||||
# Check common npx locations
|
||||
# Check common npx locations (also search extended dirs)
|
||||
npx_path = shutil.which("npx")
|
||||
if not npx_path and extra_dirs:
|
||||
npx_path = shutil.which("npx", path=os.pathsep.join(extra_dirs))
|
||||
if npx_path:
|
||||
return "npx agent-browser"
|
||||
|
||||
@@ -742,13 +791,18 @@ def _run_browser_command(
|
||||
|
||||
browser_env = {**os.environ}
|
||||
|
||||
# Ensure PATH includes Hermes-managed Node first, then standard system dirs.
|
||||
# Ensure PATH includes Hermes-managed Node first, Homebrew versioned
|
||||
# node dirs (for macOS ``brew install node@24``), then standard system dirs.
|
||||
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
hermes_node_bin = str(hermes_home / "node" / "bin")
|
||||
|
||||
existing_path = browser_env.get("PATH", "")
|
||||
path_parts = [p for p in existing_path.split(":") if p]
|
||||
candidate_dirs = [hermes_node_bin] + [p for p in _SANE_PATH.split(":") if p]
|
||||
candidate_dirs = (
|
||||
[hermes_node_bin]
|
||||
+ _discover_homebrew_node_dirs()
|
||||
+ [p for p in _SANE_PATH.split(":") if p]
|
||||
)
|
||||
|
||||
for part in reversed(candidate_dirs):
|
||||
if os.path.isdir(part) and part not in path_parts:
|
||||
|
||||
@@ -577,6 +577,12 @@ def execute_code(
|
||||
server_sock = None # prevent double close in finally
|
||||
rpc_thread.join(timeout=3)
|
||||
|
||||
# Strip ANSI escape sequences so the model never sees terminal
|
||||
# formatting — prevents it from copying escapes into file writes.
|
||||
from tools.ansi_strip import strip_ansi
|
||||
stdout_text = strip_ansi(stdout_text)
|
||||
stderr_text = strip_ansi(stderr_text)
|
||||
|
||||
# Build response
|
||||
result: Dict[str, Any] = {
|
||||
"status": status,
|
||||
|
||||
@@ -254,7 +254,12 @@ def _clean_shell_noise(output: str) -> str:
|
||||
return result
|
||||
|
||||
|
||||
_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
# Standard PATH entries for environments with minimal PATH (e.g. systemd services).
|
||||
# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon).
|
||||
_SANE_PATH = (
|
||||
"/opt/homebrew/bin:/opt/homebrew/sbin:"
|
||||
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
)
|
||||
|
||||
|
||||
def _make_run_env(env: dict) -> dict:
|
||||
|
||||
@@ -433,9 +433,13 @@ class ShellFileOperations(FileOperations):
|
||||
slash_idx = rest.find('/')
|
||||
username = rest[:slash_idx] if slash_idx >= 0 else rest
|
||||
if username and re.fullmatch(r'[a-zA-Z0-9._-]+', username):
|
||||
expand_result = self._exec(f"echo {path}")
|
||||
# Only expand ~username (not the full path) to avoid shell
|
||||
# injection via path suffixes like "~user/$(malicious)".
|
||||
expand_result = self._exec(f"echo ~{username}")
|
||||
if expand_result.exit_code == 0 and expand_result.stdout.strip():
|
||||
return expand_result.stdout.strip()
|
||||
user_home = expand_result.stdout.strip()
|
||||
suffix = path[1 + len(username):] # e.g. "/rest/of/path"
|
||||
return user_home + suffix
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import errno
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional
|
||||
from tools.file_operations import ShellFileOperations
|
||||
@@ -13,17 +12,6 @@ from agent.redact import redact_sensitive_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Regex to match ANSI escape sequences (CSI codes, OSC codes, simple escapes).
|
||||
# Models occasionally copy these from terminal output into file content.
|
||||
_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]|\x1b\][^\x07]*\x07|\x1b[()][A-B012]|\x1b[=>]")
|
||||
|
||||
|
||||
def _strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape sequences from text destined for file writes."""
|
||||
if not text or "\x1b" not in text:
|
||||
return text
|
||||
return _ANSI_ESCAPE_RE.sub("", text)
|
||||
|
||||
|
||||
_EXPECTED_WRITE_ERRNOS = {errno.EACCES, errno.EPERM, errno.EROFS}
|
||||
|
||||
@@ -301,7 +289,6 @@ def notify_other_tool_call(task_id: str = "default"):
|
||||
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
||||
"""Write content to a file."""
|
||||
try:
|
||||
content = _strip_ansi(content)
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.write_file(path, content)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
@@ -325,13 +312,10 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||
return json.dumps({"error": "path required"})
|
||||
if old_string is None or new_string is None:
|
||||
return json.dumps({"error": "old_string and new_string required"})
|
||||
old_string = _strip_ansi(old_string)
|
||||
new_string = _strip_ansi(new_string)
|
||||
result = file_ops.patch_replace(path, old_string, new_string, replace_all)
|
||||
elif mode == "patch":
|
||||
if not patch:
|
||||
return json.dumps({"error": "patch content required"})
|
||||
patch = _strip_ansi(patch)
|
||||
result = file_ops.patch_v4a(patch)
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown mode: {mode}"})
|
||||
|
||||
@@ -426,12 +426,14 @@ class ProcessRegistry:
|
||||
|
||||
def poll(self, session_id: str) -> dict:
|
||||
"""Check status and get new output for a background process."""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
|
||||
session = self.get(session_id)
|
||||
if session is None:
|
||||
return {"status": "not_found", "error": f"No process with ID {session_id}"}
|
||||
|
||||
with session._lock:
|
||||
output_preview = session.output_buffer[-1000:] if session.output_buffer else ""
|
||||
output_preview = strip_ansi(session.output_buffer[-1000:]) if session.output_buffer else ""
|
||||
|
||||
result = {
|
||||
"session_id": session.id,
|
||||
@@ -450,12 +452,14 @@ class ProcessRegistry:
|
||||
|
||||
def read_log(self, session_id: str, offset: int = 0, limit: int = 200) -> dict:
|
||||
"""Read the full output log with optional pagination by lines."""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
|
||||
session = self.get(session_id)
|
||||
if session is None:
|
||||
return {"status": "not_found", "error": f"No process with ID {session_id}"}
|
||||
|
||||
with session._lock:
|
||||
full_output = session.output_buffer
|
||||
full_output = strip_ansi(session.output_buffer)
|
||||
|
||||
lines = full_output.splitlines()
|
||||
total_lines = len(lines)
|
||||
@@ -486,6 +490,7 @@ class ProcessRegistry:
|
||||
dict with status ("exited", "timeout", "interrupted", "not_found")
|
||||
and output snapshot.
|
||||
"""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
|
||||
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
|
||||
@@ -513,7 +518,7 @@ class ProcessRegistry:
|
||||
result = {
|
||||
"status": "exited",
|
||||
"exit_code": session.exit_code,
|
||||
"output": session.output_buffer[-2000:],
|
||||
"output": strip_ansi(session.output_buffer[-2000:]),
|
||||
}
|
||||
if timeout_note:
|
||||
result["timeout_note"] = timeout_note
|
||||
@@ -522,7 +527,7 @@ class ProcessRegistry:
|
||||
if _interrupt_event.is_set():
|
||||
result = {
|
||||
"status": "interrupted",
|
||||
"output": session.output_buffer[-1000:],
|
||||
"output": strip_ansi(session.output_buffer[-1000:]),
|
||||
"note": "User sent a new message -- wait interrupted",
|
||||
}
|
||||
if timeout_note:
|
||||
@@ -533,7 +538,7 @@ class ProcessRegistry:
|
||||
|
||||
result = {
|
||||
"status": "timeout",
|
||||
"output": session.output_buffer[-1000:],
|
||||
"output": strip_ansi(session.output_buffer[-1000:]),
|
||||
}
|
||||
if timeout_note:
|
||||
result["timeout_note"] = timeout_note
|
||||
|
||||
@@ -1163,6 +1163,11 @@ def terminal_tool(
|
||||
)
|
||||
output = output[:head_chars] + truncated_notice + output[-tail_chars:]
|
||||
|
||||
# Strip ANSI escape sequences so the model never sees terminal
|
||||
# formatting — prevents it from copying escapes into file writes.
|
||||
from tools.ansi_strip import strip_ansi
|
||||
output = strip_ansi(output)
|
||||
|
||||
# Redact secrets from command output (catches env/printenv leaking keys)
|
||||
from agent.redact import redact_sensitive_text
|
||||
output = redact_sensitive_text(output.strip()) if output else ""
|
||||
|
||||
96
tools/url_safety.py
Normal file
96
tools/url_safety.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""URL safety checks — blocks requests to private/internal network addresses.
|
||||
|
||||
Prevents SSRF (Server-Side Request Forgery) where a malicious prompt or
|
||||
skill could trick the agent into fetching internal resources like cloud
|
||||
metadata endpoints (169.254.169.254), localhost services, or private
|
||||
network hosts.
|
||||
|
||||
Limitations (documented, not fixable at pre-flight level):
|
||||
- DNS rebinding (TOCTOU): an attacker-controlled DNS server with TTL=0
|
||||
can return a public IP for the check, then a private IP for the actual
|
||||
connection. Fixing this requires connection-level validation (e.g.
|
||||
Python's Champion library or an egress proxy like Stripe's Smokescreen).
|
||||
- Redirect-based bypass in vision_tools is mitigated by an httpx event
|
||||
hook that re-validates each redirect target. Web tools use third-party
|
||||
SDKs (Firecrawl/Tavily) where redirect handling is on their servers.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Hostnames that should always be blocked regardless of IP resolution
|
||||
_BLOCKED_HOSTNAMES = frozenset({
|
||||
"metadata.google.internal",
|
||||
"metadata.goog",
|
||||
})
|
||||
|
||||
# 100.64.0.0/10 (CGNAT / Shared Address Space, RFC 6598) is NOT covered by
|
||||
# ipaddress.is_private — it returns False for both is_private and is_global.
|
||||
# Must be blocked explicitly. Used by carrier-grade NAT, Tailscale/WireGuard
|
||||
# VPNs, and some cloud internal networks.
|
||||
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
|
||||
|
||||
|
||||
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
"""Return True if the IP should be blocked for SSRF protection."""
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
|
||||
return True
|
||||
if ip.is_multicast or ip.is_unspecified:
|
||||
return True
|
||||
# CGNAT range not covered by is_private
|
||||
if ip in _CGNAT_NETWORK:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_safe_url(url: str) -> bool:
|
||||
"""Return True if the URL target is not a private/internal address.
|
||||
|
||||
Resolves the hostname to an IP and checks against private ranges.
|
||||
Fails closed: DNS errors and unexpected exceptions block the request.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = (parsed.hostname or "").strip().lower()
|
||||
if not hostname:
|
||||
return False
|
||||
|
||||
# Block known internal hostnames
|
||||
if hostname in _BLOCKED_HOSTNAMES:
|
||||
logger.warning("Blocked request to internal hostname: %s", hostname)
|
||||
return False
|
||||
|
||||
# Try to resolve and check IP
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||
except socket.gaierror:
|
||||
# DNS resolution failed — fail closed. If DNS can't resolve it,
|
||||
# the HTTP client will also fail, so blocking loses nothing.
|
||||
logger.warning("Blocked request — DNS resolution failed for: %s", hostname)
|
||||
return False
|
||||
|
||||
for family, _, _, _, sockaddr in addr_info:
|
||||
ip_str = sockaddr[0]
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if _is_blocked_ip(ip):
|
||||
logger.warning(
|
||||
"Blocked request to private/internal address: %s -> %s",
|
||||
hostname, ip_str,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
# Fail closed on unexpected errors — don't let parsing edge cases
|
||||
# become SSRF bypass vectors
|
||||
logger.warning("Blocked request — URL safety check error for %s: %s", url, exc)
|
||||
return False
|
||||
@@ -69,7 +69,12 @@ def _validate_image_url(url: str) -> bool:
|
||||
if not parsed.netloc:
|
||||
return False
|
||||
|
||||
return True # Allow all well-formed HTTP/HTTPS URLs for flexibility
|
||||
# Block private/internal addresses to prevent SSRF
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
|
||||
@@ -92,12 +97,33 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
||||
# Create parent directories if they don't exist
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def _ssrf_redirect_guard(response):
|
||||
"""Re-validate each redirect target to prevent redirect-based SSRF.
|
||||
|
||||
Without this, an attacker can host a public URL that 302-redirects
|
||||
to http://169.254.169.254/ and bypass the pre-flight is_safe_url check.
|
||||
|
||||
Must be async because httpx.AsyncClient awaits event hooks.
|
||||
"""
|
||||
if response.is_redirect and response.next_request:
|
||||
redirect_url = str(response.next_request.url)
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(redirect_url):
|
||||
raise ValueError(
|
||||
f"Blocked redirect to private/internal address: {redirect_url}"
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Download the image with appropriate headers using async httpx
|
||||
# Enable follow_redirects to handle image CDNs that redirect (e.g., Imgur, Picsum)
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
# SSRF: event_hooks validates each redirect target against private IP ranges
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
) as client:
|
||||
response = await client.get(
|
||||
image_url,
|
||||
headers={
|
||||
|
||||
@@ -46,6 +46,7 @@ import httpx
|
||||
from firecrawl import Firecrawl
|
||||
from agent.auxiliary_client import async_call_llm
|
||||
from tools.debug_helpers import DebugSession
|
||||
from tools.url_safety import is_safe_url
|
||||
from tools.website_policy import check_website_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -861,136 +862,155 @@ async def web_extract_tool(
|
||||
try:
|
||||
logger.info("Extracting content from %d URL(s)", len(urls))
|
||||
|
||||
# Dispatch to the configured backend
|
||||
backend = _get_backend()
|
||||
|
||||
if backend == "parallel":
|
||||
results = await _parallel_extract(urls)
|
||||
elif backend == "tavily":
|
||||
logger.info("Tavily extract: %d URL(s)", len(urls))
|
||||
raw = _tavily_request("extract", {
|
||||
"urls": urls,
|
||||
"include_images": False,
|
||||
})
|
||||
results = _normalize_tavily_documents(raw, fallback_url=urls[0] if urls else "")
|
||||
else:
|
||||
# ── Firecrawl extraction ──
|
||||
# Determine requested formats for Firecrawl v2
|
||||
formats: List[str] = []
|
||||
if format == "markdown":
|
||||
formats = ["markdown"]
|
||||
elif format == "html":
|
||||
formats = ["html"]
|
||||
# ── SSRF protection — filter out private/internal URLs before any backend ──
|
||||
safe_urls = []
|
||||
ssrf_blocked: List[Dict[str, Any]] = []
|
||||
for url in urls:
|
||||
if not is_safe_url(url):
|
||||
ssrf_blocked.append({
|
||||
"url": url, "title": "", "content": "",
|
||||
"error": "Blocked: URL targets a private or internal network address",
|
||||
})
|
||||
else:
|
||||
# Default: request markdown for LLM-readiness and include html as backup
|
||||
formats = ["markdown", "html"]
|
||||
safe_urls.append(url)
|
||||
|
||||
# Always use individual scraping for simplicity and reliability
|
||||
# Batch scraping adds complexity without much benefit for small numbers of URLs
|
||||
results: List[Dict[str, Any]] = []
|
||||
# Dispatch only safe URLs to the configured backend
|
||||
if not safe_urls:
|
||||
results = []
|
||||
else:
|
||||
backend = _get_backend()
|
||||
|
||||
from tools.interrupt import is_interrupted as _is_interrupted
|
||||
for url in urls:
|
||||
if _is_interrupted():
|
||||
results.append({"url": url, "error": "Interrupted", "title": ""})
|
||||
continue
|
||||
if backend == "parallel":
|
||||
results = await _parallel_extract(safe_urls)
|
||||
elif backend == "tavily":
|
||||
logger.info("Tavily extract: %d URL(s)", len(safe_urls))
|
||||
raw = _tavily_request("extract", {
|
||||
"urls": safe_urls,
|
||||
"include_images": False,
|
||||
})
|
||||
results = _normalize_tavily_documents(raw, fallback_url=safe_urls[0] if safe_urls else "")
|
||||
else:
|
||||
# ── Firecrawl extraction ──
|
||||
# Determine requested formats for Firecrawl v2
|
||||
formats: List[str] = []
|
||||
if format == "markdown":
|
||||
formats = ["markdown"]
|
||||
elif format == "html":
|
||||
formats = ["html"]
|
||||
else:
|
||||
# Default: request markdown for LLM-readiness and include html as backup
|
||||
formats = ["markdown", "html"]
|
||||
|
||||
# Website policy check — block before fetching
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"])
|
||||
results.append({
|
||||
"url": url, "title": "", "content": "",
|
||||
"error": blocked["message"],
|
||||
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
|
||||
})
|
||||
continue
|
||||
# Always use individual scraping for simplicity and reliability
|
||||
# Batch scraping adds complexity without much benefit for small numbers of URLs
|
||||
results: List[Dict[str, Any]] = []
|
||||
|
||||
try:
|
||||
logger.info("Scraping: %s", url)
|
||||
scrape_result = _get_firecrawl_client().scrape(
|
||||
url=url,
|
||||
formats=formats
|
||||
)
|
||||
from tools.interrupt import is_interrupted as _is_interrupted
|
||||
for url in safe_urls:
|
||||
if _is_interrupted():
|
||||
results.append({"url": url, "error": "Interrupted", "title": ""})
|
||||
continue
|
||||
|
||||
# Process the result - properly handle object serialization
|
||||
metadata = {}
|
||||
title = ""
|
||||
content_markdown = None
|
||||
content_html = None
|
||||
|
||||
# Extract data from the scrape result
|
||||
if hasattr(scrape_result, 'model_dump'):
|
||||
# Pydantic model - use model_dump to get dict
|
||||
result_dict = scrape_result.model_dump()
|
||||
content_markdown = result_dict.get('markdown')
|
||||
content_html = result_dict.get('html')
|
||||
metadata = result_dict.get('metadata', {})
|
||||
elif hasattr(scrape_result, '__dict__'):
|
||||
# Regular object with attributes
|
||||
content_markdown = getattr(scrape_result, 'markdown', None)
|
||||
content_html = getattr(scrape_result, 'html', None)
|
||||
|
||||
# Handle metadata - convert to dict if it's an object
|
||||
metadata_obj = getattr(scrape_result, 'metadata', {})
|
||||
if hasattr(metadata_obj, 'model_dump'):
|
||||
metadata = metadata_obj.model_dump()
|
||||
elif hasattr(metadata_obj, '__dict__'):
|
||||
metadata = metadata_obj.__dict__
|
||||
elif isinstance(metadata_obj, dict):
|
||||
metadata = metadata_obj
|
||||
else:
|
||||
metadata = {}
|
||||
elif isinstance(scrape_result, dict):
|
||||
# Already a dictionary
|
||||
content_markdown = scrape_result.get('markdown')
|
||||
content_html = scrape_result.get('html')
|
||||
metadata = scrape_result.get('metadata', {})
|
||||
|
||||
# Ensure metadata is a dict (not an object)
|
||||
if not isinstance(metadata, dict):
|
||||
if hasattr(metadata, 'model_dump'):
|
||||
metadata = metadata.model_dump()
|
||||
elif hasattr(metadata, '__dict__'):
|
||||
metadata = metadata.__dict__
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Get title from metadata
|
||||
title = metadata.get("title", "")
|
||||
|
||||
# Re-check final URL after redirect
|
||||
final_url = metadata.get("sourceURL", url)
|
||||
final_blocked = check_website_access(final_url)
|
||||
if final_blocked:
|
||||
logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"])
|
||||
# Website policy check — block before fetching
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"])
|
||||
results.append({
|
||||
"url": final_url, "title": title, "content": "", "raw_content": "",
|
||||
"error": final_blocked["message"],
|
||||
"blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]},
|
||||
"url": url, "title": "", "content": "",
|
||||
"error": blocked["message"],
|
||||
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
|
||||
})
|
||||
continue
|
||||
|
||||
# Choose content based on requested format
|
||||
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""
|
||||
try:
|
||||
logger.info("Scraping: %s", url)
|
||||
scrape_result = _get_firecrawl_client().scrape(
|
||||
url=url,
|
||||
formats=formats
|
||||
)
|
||||
|
||||
results.append({
|
||||
"url": final_url,
|
||||
"title": title,
|
||||
"content": chosen_content,
|
||||
"raw_content": chosen_content,
|
||||
"metadata": metadata # Now guaranteed to be a dict
|
||||
})
|
||||
# Process the result - properly handle object serialization
|
||||
metadata = {}
|
||||
title = ""
|
||||
content_markdown = None
|
||||
content_html = None
|
||||
|
||||
except Exception as scrape_err:
|
||||
logger.debug("Scrape failed for %s: %s", url, scrape_err)
|
||||
results.append({
|
||||
"url": url,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": str(scrape_err)
|
||||
})
|
||||
# Extract data from the scrape result
|
||||
if hasattr(scrape_result, 'model_dump'):
|
||||
# Pydantic model - use model_dump to get dict
|
||||
result_dict = scrape_result.model_dump()
|
||||
content_markdown = result_dict.get('markdown')
|
||||
content_html = result_dict.get('html')
|
||||
metadata = result_dict.get('metadata', {})
|
||||
elif hasattr(scrape_result, '__dict__'):
|
||||
# Regular object with attributes
|
||||
content_markdown = getattr(scrape_result, 'markdown', None)
|
||||
content_html = getattr(scrape_result, 'html', None)
|
||||
|
||||
# Handle metadata - convert to dict if it's an object
|
||||
metadata_obj = getattr(scrape_result, 'metadata', {})
|
||||
if hasattr(metadata_obj, 'model_dump'):
|
||||
metadata = metadata_obj.model_dump()
|
||||
elif hasattr(metadata_obj, '__dict__'):
|
||||
metadata = metadata_obj.__dict__
|
||||
elif isinstance(metadata_obj, dict):
|
||||
metadata = metadata_obj
|
||||
else:
|
||||
metadata = {}
|
||||
elif isinstance(scrape_result, dict):
|
||||
# Already a dictionary
|
||||
content_markdown = scrape_result.get('markdown')
|
||||
content_html = scrape_result.get('html')
|
||||
metadata = scrape_result.get('metadata', {})
|
||||
|
||||
# Ensure metadata is a dict (not an object)
|
||||
if not isinstance(metadata, dict):
|
||||
if hasattr(metadata, 'model_dump'):
|
||||
metadata = metadata.model_dump()
|
||||
elif hasattr(metadata, '__dict__'):
|
||||
metadata = metadata.__dict__
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Get title from metadata
|
||||
title = metadata.get("title", "")
|
||||
|
||||
# Re-check final URL after redirect
|
||||
final_url = metadata.get("sourceURL", url)
|
||||
final_blocked = check_website_access(final_url)
|
||||
if final_blocked:
|
||||
logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"])
|
||||
results.append({
|
||||
"url": final_url, "title": title, "content": "", "raw_content": "",
|
||||
"error": final_blocked["message"],
|
||||
"blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]},
|
||||
})
|
||||
continue
|
||||
|
||||
# Choose content based on requested format
|
||||
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""
|
||||
|
||||
results.append({
|
||||
"url": final_url,
|
||||
"title": title,
|
||||
"content": chosen_content,
|
||||
"raw_content": chosen_content,
|
||||
"metadata": metadata # Now guaranteed to be a dict
|
||||
})
|
||||
|
||||
except Exception as scrape_err:
|
||||
logger.debug("Scrape failed for %s: %s", url, scrape_err)
|
||||
results.append({
|
||||
"url": url,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": str(scrape_err)
|
||||
})
|
||||
|
||||
# Merge any SSRF-blocked results back in
|
||||
if ssrf_blocked:
|
||||
results = ssrf_blocked + results
|
||||
|
||||
response = {"results": results}
|
||||
|
||||
@@ -1173,6 +1193,11 @@ async def web_crawl_tool(
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
url = f'https://{url}'
|
||||
|
||||
# SSRF protection — block private/internal addresses
|
||||
if not is_safe_url(url):
|
||||
return json.dumps({"results": [{"url": url, "title": "", "content": "",
|
||||
"error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False)
|
||||
|
||||
# Website policy check
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
@@ -1258,6 +1283,11 @@ async def web_crawl_tool(
|
||||
instructions_text = f" with instructions: '{instructions}'" if instructions else ""
|
||||
logger.info("Crawling %s%s", url, instructions_text)
|
||||
|
||||
# SSRF protection — block private/internal addresses
|
||||
if not is_safe_url(url):
|
||||
return json.dumps({"results": [{"url": url, "title": "", "content": "",
|
||||
"error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False)
|
||||
|
||||
# Website policy check — block before crawling
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
|
||||
Reference in New Issue
Block a user