Compare commits
5 Commits
fix/modal-
...
feat/file-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6aba50f5ba | ||
|
|
a562550af3 | ||
|
|
37c478cf2f | ||
|
|
27eeea0555 | ||
|
|
fd73937ec8 |
@@ -8458,23 +8458,11 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Centralized logging — agent.log (INFO+) and errors.log (WARNING+).
|
||||
# Centralized logging — agent.log (INFO+), errors.log (WARNING+),
|
||||
# and gateway.log (INFO+, gateway-component records only).
|
||||
# Idempotent, so repeated calls from AIAgent.__init__ won't duplicate.
|
||||
from hermes_logging import setup_logging
|
||||
log_dir = setup_logging(hermes_home=_hermes_home, mode="gateway")
|
||||
|
||||
# Gateway-specific rotating log — captures all gateway-level messages
|
||||
# (session management, platform adapters, slash commands, etc.).
|
||||
from agent.redact import RedactingFormatter
|
||||
from hermes_logging import _add_rotating_handler
|
||||
_add_rotating_handler(
|
||||
logging.getLogger(),
|
||||
log_dir / 'gateway.log',
|
||||
level=logging.INFO,
|
||||
max_bytes=5 * 1024 * 1024,
|
||||
backup_count=3,
|
||||
formatter=RedactingFormatter('%(asctime)s %(levelname)s %(name)s: %(message)s'),
|
||||
)
|
||||
setup_logging(hermes_home=_hermes_home, mode="gateway")
|
||||
|
||||
# Optional stderr handler — level driven by -v/-q flags on the CLI.
|
||||
# verbosity=None (-q/--quiet): no stderr output
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
"""``hermes logs`` — view and filter Hermes log files.
|
||||
|
||||
Supports tailing, following, session filtering, level filtering, and
|
||||
relative time ranges. All log files live under ``~/.hermes/logs/``.
|
||||
Supports tailing, following, session filtering, level filtering,
|
||||
component filtering, and relative time ranges. All log files live
|
||||
under ``~/.hermes/logs/``.
|
||||
|
||||
Usage examples::
|
||||
|
||||
hermes logs # last 50 lines of agent.log
|
||||
hermes logs -f # follow agent.log in real time
|
||||
hermes logs errors # last 50 lines of errors.log
|
||||
hermes logs gateway -n 100 # last 100 lines of gateway.log
|
||||
hermes logs gateway -n 100 # last 100 lines of gateway.log
|
||||
hermes logs --level WARNING # only WARNING+ lines
|
||||
hermes logs --session abc123 # filter by session ID substring
|
||||
hermes logs --component tools # only tool-related lines
|
||||
hermes logs --since 1h # lines from the last hour
|
||||
hermes logs --since 30m -f # follow, starting 30 min ago
|
||||
"""
|
||||
@@ -20,7 +22,7 @@ import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from hermes_constants import get_hermes_home, display_hermes_home
|
||||
|
||||
@@ -38,6 +40,15 @@ _TS_RE = re.compile(r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})")
|
||||
# Level extraction — matches " INFO ", " WARNING ", " ERROR ", " DEBUG ", " CRITICAL "
|
||||
_LEVEL_RE = re.compile(r"\s(DEBUG|INFO|WARNING|ERROR|CRITICAL)\s")
|
||||
|
||||
# Logger name extraction — after the level and an optional session tag, the
# next non-space token that is immediately followed by ":" is the logger name.
# Matches: "INFO gateway.run:" or "INFO [sess_abc] tools.terminal_tool:"
# The pattern is not anchored to the start of the line; the level keyword
# only has to be preceded by a single whitespace character.
_LOGGER_NAME_RE = re.compile(
    r"\s(?:DEBUG|INFO|WARNING|ERROR|CRITICAL)"  # level
    r"(?:\s+\[.*?\])?"                           # optional session tag
    r"\s+(\S+):"                                 # logger name
)
|
||||
|
||||
# Level ordering for >= filtering
|
||||
_LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARNING": 2, "ERROR": 3, "CRITICAL": 4}
|
||||
|
||||
@@ -79,12 +90,27 @@ def _extract_level(line: str) -> Optional[str]:
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def _extract_logger_name(line: str) -> Optional[str]:
    """Return the logger name found in *line*, or ``None`` when absent.

    The name is the first capture group of ``_LOGGER_NAME_RE``, which is
    searched for anywhere in the line.
    """
    found = _LOGGER_NAME_RE.search(line)
    if found is None:
        return None
    return found.group(1)
|
||||
|
||||
|
||||
def _line_matches_component(line: str, prefixes: Sequence[str]) -> bool:
    """Check whether a log line's logger belongs to one of *prefixes*.

    A logger belongs to a prefix when its name equals the prefix or sits
    below it in the dotted logger hierarchy (``gateway`` matches
    ``gateway`` and ``gateway.platforms.telegram``).  A plain
    ``str.startswith`` test is deliberately avoided: it would also accept
    unrelated loggers that merely share leading characters, such as
    ``click`` or ``client`` for the prefix ``cli``.

    Parameters
    ----------
    line
        A formatted log line.
    prefixes
        Logger-name prefixes to accept.  A trailing ``"."`` on a prefix is
        tolerated and ignored.

    Returns
    -------
    bool
        ``True`` if the extracted logger name matches any prefix;
        ``False`` if it matches none or no logger name can be extracted.
    """
    name = _extract_logger_name(line)
    if name is None:
        return False
    for prefix in prefixes:
        base = prefix.rstrip(".")
        # Exact logger, or a descendant separated by a dot.
        if name == base or name.startswith(base + "."):
            return True
    return False
|
||||
|
||||
|
||||
def _matches_filters(
|
||||
line: str,
|
||||
*,
|
||||
min_level: Optional[str] = None,
|
||||
session_filter: Optional[str] = None,
|
||||
since: Optional[datetime] = None,
|
||||
component_prefixes: Optional[Sequence[str]] = None,
|
||||
) -> bool:
|
||||
"""Check if a log line passes all active filters."""
|
||||
if since is not None:
|
||||
@@ -102,6 +128,10 @@ def _matches_filters(
|
||||
if session_filter not in line:
|
||||
return False
|
||||
|
||||
if component_prefixes is not None:
|
||||
if not _line_matches_component(line, component_prefixes):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -113,6 +143,7 @@ def tail_log(
|
||||
level: Optional[str] = None,
|
||||
session: Optional[str] = None,
|
||||
since: Optional[str] = None,
|
||||
component: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Read and display log lines, optionally following in real time.
|
||||
|
||||
@@ -130,6 +161,8 @@ def tail_log(
|
||||
Session ID substring to filter on.
|
||||
since
|
||||
Relative time string (e.g. ``"1h"``, ``"30m"``).
|
||||
component
|
||||
Component name to filter by (e.g. ``"gateway"``, ``"tools"``).
|
||||
"""
|
||||
filename = LOG_FILES.get(log_name)
|
||||
if filename is None:
|
||||
@@ -155,13 +188,29 @@ def tail_log(
|
||||
print(f"Invalid --level: {level!r}. Use DEBUG, INFO, WARNING, ERROR, or CRITICAL.")
|
||||
sys.exit(1)
|
||||
|
||||
has_filters = min_level is not None or session is not None or since_dt is not None
|
||||
# Resolve component to logger name prefixes
|
||||
component_prefixes = None
|
||||
if component:
|
||||
from hermes_logging import COMPONENT_PREFIXES
|
||||
component_lower = component.lower()
|
||||
if component_lower not in COMPONENT_PREFIXES:
|
||||
available = ", ".join(sorted(COMPONENT_PREFIXES))
|
||||
print(f"Unknown component: {component!r}. Available: {available}")
|
||||
sys.exit(1)
|
||||
component_prefixes = COMPONENT_PREFIXES[component_lower]
|
||||
|
||||
has_filters = (
|
||||
min_level is not None
|
||||
or session is not None
|
||||
or since_dt is not None
|
||||
or component_prefixes is not None
|
||||
)
|
||||
|
||||
# Read and display the tail
|
||||
try:
|
||||
lines = _read_tail(log_path, num_lines, has_filters=has_filters,
|
||||
min_level=min_level, session_filter=session,
|
||||
since=since_dt)
|
||||
since=since_dt, component_prefixes=component_prefixes)
|
||||
except PermissionError:
|
||||
print(f"Permission denied: {log_path}")
|
||||
sys.exit(1)
|
||||
@@ -172,6 +221,8 @@ def tail_log(
|
||||
filter_parts.append(f"level>={min_level}")
|
||||
if session:
|
||||
filter_parts.append(f"session={session}")
|
||||
if component:
|
||||
filter_parts.append(f"component={component}")
|
||||
if since:
|
||||
filter_parts.append(f"since={since}")
|
||||
filter_desc = f" [{', '.join(filter_parts)}]" if filter_parts else ""
|
||||
@@ -190,7 +241,7 @@ def tail_log(
|
||||
# Follow mode — poll for new content
|
||||
try:
|
||||
_follow_log(log_path, min_level=min_level, session_filter=session,
|
||||
since=since_dt)
|
||||
since=since_dt, component_prefixes=component_prefixes)
|
||||
except KeyboardInterrupt:
|
||||
print("\n--- stopped ---")
|
||||
|
||||
@@ -203,6 +254,7 @@ def _read_tail(
|
||||
min_level: Optional[str] = None,
|
||||
session_filter: Optional[str] = None,
|
||||
since: Optional[datetime] = None,
|
||||
component_prefixes: Optional[Sequence[str]] = None,
|
||||
) -> list:
|
||||
"""Read the last *num_lines* matching lines from a log file.
|
||||
|
||||
@@ -215,7 +267,8 @@ def _read_tail(
|
||||
filtered = [
|
||||
l for l in raw_lines
|
||||
if _matches_filters(l, min_level=min_level,
|
||||
session_filter=session_filter, since=since)
|
||||
session_filter=session_filter, since=since,
|
||||
component_prefixes=component_prefixes)
|
||||
]
|
||||
return filtered[-num_lines:]
|
||||
else:
|
||||
@@ -284,6 +337,7 @@ def _follow_log(
|
||||
min_level: Optional[str] = None,
|
||||
session_filter: Optional[str] = None,
|
||||
since: Optional[datetime] = None,
|
||||
component_prefixes: Optional[Sequence[str]] = None,
|
||||
) -> None:
|
||||
"""Poll a log file for new content and print matching lines."""
|
||||
with open(path, "r", encoding="utf-8", errors="replace") as f:
|
||||
@@ -293,7 +347,8 @@ def _follow_log(
|
||||
line = f.readline()
|
||||
if line:
|
||||
if _matches_filters(line, min_level=min_level,
|
||||
session_filter=session_filter, since=since):
|
||||
session_filter=session_filter, since=since,
|
||||
component_prefixes=component_prefixes):
|
||||
print(line, end="")
|
||||
sys.stdout.flush()
|
||||
else:
|
||||
|
||||
@@ -4338,6 +4338,7 @@ def cmd_logs(args):
|
||||
level=getattr(args, "level", None),
|
||||
session=getattr(args, "session", None),
|
||||
since=getattr(args, "since", None),
|
||||
component=getattr(args, "component", None),
|
||||
)
|
||||
|
||||
|
||||
@@ -5737,6 +5738,7 @@ Examples:
|
||||
hermes logs gateway -n 100 Show last 100 lines of gateway.log
|
||||
hermes logs --level WARNING Only show WARNING and above
|
||||
hermes logs --session abc123 Filter by session ID
|
||||
hermes logs --component tools Only show tool-related lines
|
||||
hermes logs --since 1h Lines from the last hour
|
||||
hermes logs --since 30m -f Follow, starting from 30 min ago
|
||||
hermes logs list List available log files with sizes
|
||||
@@ -5766,6 +5768,10 @@ Examples:
|
||||
"--since", metavar="TIME",
|
||||
help="Show lines since TIME ago (e.g. 1h, 30m, 2d)",
|
||||
)
|
||||
logs_parser.add_argument(
|
||||
"--component", metavar="NAME",
|
||||
help="Filter by component: gateway, agent, tools, cli, cron",
|
||||
)
|
||||
logs_parser.set_defaults(func=cmd_logs)
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -7,16 +7,28 @@ gateway call early in their startup path. All log files live under
|
||||
Log files produced:
|
||||
agent.log — INFO+, all agent/tool/session activity (the main log)
|
||||
errors.log — WARNING+, errors and warnings only (quick triage)
|
||||
gateway.log — INFO+, gateway-only events (created when mode="gateway")
|
||||
|
||||
Both files use ``RotatingFileHandler`` with ``RedactingFormatter`` so
|
||||
All files use ``RotatingFileHandler`` with ``RedactingFormatter`` so
|
||||
secrets are never written to disk.
|
||||
|
||||
Component separation:
|
||||
gateway.log only receives records from ``gateway.*`` loggers —
|
||||
platform adapters, session management, slash commands, delivery.
|
||||
agent.log remains the catch-all (everything goes there).
|
||||
|
||||
Session context:
|
||||
Call ``set_session_context(session_id)`` at the start of a conversation
|
||||
and ``clear_session_context()`` when done. All log lines emitted on
|
||||
that thread will include ``[session_id]`` for filtering/correlation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from hermes_constants import get_config_path, get_hermes_home
|
||||
|
||||
@@ -25,9 +37,14 @@ from hermes_constants import get_config_path, get_hermes_home
|
||||
# unless ``force=True``.
|
||||
_logging_initialized = False
|
||||
|
||||
# Default log format — includes timestamp, level, logger name, and message.
|
||||
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||||
_LOG_FORMAT_VERBOSE = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
# Thread-local storage for per-conversation session context.
|
||||
_session_context = threading.local()
|
||||
|
||||
# Default log format — includes timestamp, level, optional session tag,
|
||||
# logger name, and message. The ``%(session_tag)s`` field is guaranteed to
|
||||
# exist on every LogRecord via _install_session_record_factory() below.
|
||||
_LOG_FORMAT = "%(asctime)s %(levelname)s%(session_tag)s %(name)s: %(message)s"
|
||||
_LOG_FORMAT_VERBOSE = "%(asctime)s - %(name)s - %(levelname)s%(session_tag)s - %(message)s"
|
||||
|
||||
# Third-party loggers that are noisy at DEBUG/INFO level.
|
||||
_NOISY_LOGGERS = (
|
||||
@@ -48,6 +65,99 @@ _NOISY_LOGGERS = (
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public session context API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def set_session_context(session_id: str) -> None:
    """Bind *session_id* to the calling thread's logging context.

    The value is stored on the module's thread-local ``_session_context``
    object; log records created on this thread afterwards carry a
    ``[session_id]`` tag in their formatted output.  Intended to be called
    at the start of ``run_conversation()``.
    """
    setattr(_session_context, "session_id", session_id)
|
||||
|
||||
|
||||
def clear_session_context() -> None:
    """Reset the calling thread's session ID to ``None``.

    Calling this is optional: ``set_session_context()`` simply overwrites
    whatever was stored before.  Explicit clearing only matters when the
    same thread goes on to do non-conversation work after
    ``run_conversation()`` has returned.
    """
    setattr(_session_context, "session_id", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Record factory — injects session_tag into every LogRecord at creation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _install_session_record_factory() -> None:
    """Wrap the global LogRecord factory so every record gets ``session_tag``.

    A record factory runs when each ``LogRecord`` is created, so the
    attribute is present no matter which logger emitted the record or
    which handler formats it.  That keeps ``%(session_tag)s`` safe to use
    in any format string, with no per-handler filter required.

    The installed wrapper carries a ``_hermes_session_injector`` marker;
    if the current factory already has it, nothing is done, so repeated
    calls (for example after a module reload) do not stack wrappers.
    """
    wrapped = logging.getLogRecordFactory()
    already_installed = getattr(wrapped, "_hermes_session_injector", False)
    if not already_installed:

        def _tagging_factory(*args, **kwargs):
            rec = wrapped(*args, **kwargs)
            session_id = getattr(_session_context, "session_id", None)
            # Leading space keeps the tag separated from the level name;
            # an empty string leaves the format untouched when unset.
            if session_id:
                rec.session_tag = f" [{session_id}]"  # type: ignore[attr-defined]
            else:
                rec.session_tag = ""  # type: ignore[attr-defined]
            return rec

        _tagging_factory._hermes_session_injector = True  # type: ignore[attr-defined]
        logging.setLogRecordFactory(_tagging_factory)
|
||||
|
||||
|
||||
# Install immediately on import — session_tag is available on all records
|
||||
# from this point forward, even before setup_logging() is called.
|
||||
_install_session_record_factory()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _ComponentFilter(logging.Filter):
|
||||
"""Only pass records whose logger name starts with one of *prefixes*.
|
||||
|
||||
Used to route gateway-specific records to ``gateway.log`` while
|
||||
keeping ``agent.log`` as the catch-all.
|
||||
"""
|
||||
|
||||
def __init__(self, prefixes: Sequence[str]) -> None:
|
||||
super().__init__()
|
||||
self._prefixes = tuple(prefixes)
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return record.name.startswith(self._prefixes)
|
||||
|
||||
|
||||
# Logger name prefixes that belong to each component.
# Used by _ComponentFilter and exposed for ``hermes logs --component``.
# Keys are lowercase component names; each value is a tuple of logger-name
# prefixes (single-element tuples need the trailing comma).
COMPONENT_PREFIXES = {
    "gateway": ("gateway",),
    "agent": ("agent", "run_agent", "model_tools", "batch_runner"),
    "tools": ("tools",),
    "cli": ("hermes_cli", "cli"),
    "cron": ("cron",),
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def setup_logging(
|
||||
*,
|
||||
hermes_home: Optional[Path] = None,
|
||||
@@ -78,8 +188,9 @@ def setup_logging(
|
||||
Number of rotated backup files to keep.
|
||||
Defaults to 3 or the value from config.yaml ``logging.backup_count``.
|
||||
mode
|
||||
Hint for the caller context: ``"cli"``, ``"gateway"``, ``"cron"``.
|
||||
Currently used only for log format tuning (gateway includes PID).
|
||||
Caller context: ``"cli"``, ``"gateway"``, ``"cron"``.
|
||||
When ``"gateway"``, an additional ``gateway.log`` file is created
|
||||
that receives only gateway-component records.
|
||||
force
|
||||
Re-run setup even if it has already been called.
|
||||
|
||||
@@ -130,6 +241,18 @@ def setup_logging(
|
||||
formatter=RedactingFormatter(_LOG_FORMAT),
|
||||
)
|
||||
|
||||
# --- gateway.log (INFO+, gateway component only) ------------------------
|
||||
if mode == "gateway":
|
||||
_add_rotating_handler(
|
||||
root,
|
||||
log_dir / "gateway.log",
|
||||
level=logging.INFO,
|
||||
max_bytes=5 * 1024 * 1024,
|
||||
backup_count=3,
|
||||
formatter=RedactingFormatter(_LOG_FORMAT),
|
||||
log_filter=_ComponentFilter(COMPONENT_PREFIXES["gateway"]),
|
||||
)
|
||||
|
||||
# Ensure root logger level is low enough for the handlers to fire.
|
||||
if root.level == logging.NOTSET or root.level > level:
|
||||
root.setLevel(level)
|
||||
@@ -218,9 +341,16 @@ def _add_rotating_handler(
|
||||
max_bytes: int,
|
||||
backup_count: int,
|
||||
formatter: logging.Formatter,
|
||||
log_filter: Optional[logging.Filter] = None,
|
||||
) -> None:
|
||||
"""Add a ``RotatingFileHandler`` to *logger*, skipping if one already
|
||||
exists for the same resolved file path (idempotent).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
log_filter
|
||||
Optional filter to attach to the handler (e.g. ``_ComponentFilter``
|
||||
for gateway.log).
|
||||
"""
|
||||
resolved = path.resolve()
|
||||
for existing in logger.handlers:
|
||||
@@ -236,6 +366,8 @@ def _add_rotating_handler(
|
||||
)
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(formatter)
|
||||
if log_filter is not None:
|
||||
handler.addFilter(log_filter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
|
||||
@@ -7529,6 +7529,11 @@ class AIAgent:
|
||||
# Installed once, transparent when streams are healthy, prevents crash on write.
|
||||
_install_safe_stdio()
|
||||
|
||||
# Tag all log records on this thread with the session ID so
|
||||
# ``hermes logs --session <id>`` can filter a single conversation.
|
||||
from hermes_logging import set_session_context
|
||||
set_session_context(self.session_id)
|
||||
|
||||
# If the previous turn activated fallback, restore the primary
|
||||
# runtime so this turn gets a fresh attempt with the preferred model.
|
||||
# No-op when _fallback_activated is False (gateway, first turn, etc.).
|
||||
|
||||
@@ -1,288 +1,255 @@
|
||||
"""Tests for hermes_cli/logs.py — log viewing and filtering."""
|
||||
"""Tests for hermes_cli.logs — log viewing and filtering."""
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
from datetime import datetime, timedelta
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.logs import (
|
||||
LOG_FILES,
|
||||
_extract_level,
|
||||
_extract_logger_name,
|
||||
_line_matches_component,
|
||||
_matches_filters,
|
||||
_parse_line_timestamp,
|
||||
_parse_since,
|
||||
_read_last_n_lines,
|
||||
list_logs,
|
||||
tail_log,
|
||||
_read_tail,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def log_dir(tmp_path, monkeypatch):
|
||||
"""Create a fake HERMES_HOME with a logs/ directory."""
|
||||
home = Path(os.environ["HERMES_HOME"])
|
||||
logs = home / "logs"
|
||||
logs.mkdir(parents=True, exist_ok=True)
|
||||
return logs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_log(log_dir):
|
||||
"""Write a realistic agent.log with mixed levels and sessions."""
|
||||
lines = textwrap.dedent("""\
|
||||
2026-04-05 10:00:00,000 INFO run_agent: conversation turn: session=sess_aaa model=claude provider=openrouter platform=cli history=0 msg='hello'
|
||||
2026-04-05 10:00:01,000 INFO run_agent: tool terminal completed (0.50s, 200 chars)
|
||||
2026-04-05 10:00:02,000 INFO run_agent: API call #1: model=claude provider=openrouter in=1000 out=200 total=1200 latency=1.5s
|
||||
2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): timeout
|
||||
2026-04-05 10:00:04,000 INFO run_agent: conversation turn: session=sess_bbb model=gpt-5 provider=openai platform=telegram history=5 msg='fix bug'
|
||||
2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited
|
||||
2026-04-05 10:00:06,000 INFO run_agent: tool read_file completed (0.01s, 500 chars)
|
||||
2026-04-05 10:00:07,000 DEBUG run_agent: verbose internal detail
|
||||
2026-04-05 10:00:08,000 INFO credential_pool: credential pool: marking key-1 exhausted (status=429), rotating
|
||||
2026-04-05 10:00:09,000 INFO credential_pool: credential pool: rotated to key-2
|
||||
""")
|
||||
path = log_dir / "agent.log"
|
||||
path.write_text(lines)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_errors_log(log_dir):
|
||||
"""Write a small errors.log."""
|
||||
lines = textwrap.dedent("""\
|
||||
2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): timeout
|
||||
2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited
|
||||
""")
|
||||
path = log_dir / "errors.log"
|
||||
path.write_text(lines)
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_since
|
||||
# Timestamp parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseSince:
|
||||
def test_hours(self):
|
||||
cutoff = _parse_since("2h")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(7200, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 7200) < 2
|
||||
|
||||
def test_minutes(self):
|
||||
cutoff = _parse_since("30m")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(1800, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 1800) < 2
|
||||
|
||||
def test_days(self):
|
||||
cutoff = _parse_since("1d")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(86400, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 86400) < 2
|
||||
|
||||
def test_seconds(self):
|
||||
cutoff = _parse_since("60s")
|
||||
cutoff = _parse_since("120s")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(60, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 120) < 2
|
||||
|
||||
def test_invalid_returns_none(self):
|
||||
assert _parse_since("abc") is None
|
||||
assert _parse_since("") is None
|
||||
assert _parse_since("10x") is None
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
cutoff = _parse_since(" 1h ")
|
||||
def test_whitespace_tolerance(self):
|
||||
cutoff = _parse_since(" 5m ")
|
||||
assert cutoff is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_line_timestamp
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseLineTimestamp:
|
||||
def test_standard_format(self):
|
||||
ts = _parse_line_timestamp("2026-04-05 10:00:00,123 INFO something")
|
||||
assert ts is not None
|
||||
assert ts.year == 2026
|
||||
assert ts.hour == 10
|
||||
ts = _parse_line_timestamp("2026-04-11 10:23:45 INFO gateway.run: msg")
|
||||
assert ts == datetime(2026, 4, 11, 10, 23, 45)
|
||||
|
||||
def test_no_timestamp(self):
|
||||
assert _parse_line_timestamp("just some text") is None
|
||||
assert _parse_line_timestamp("no timestamp here") is None
|
||||
|
||||
def test_continuation_line(self):
|
||||
assert _parse_line_timestamp(" at module.function (line 42)") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_level
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractLevel:
|
||||
def test_info(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 INFO run_agent: something") == "INFO"
|
||||
assert _extract_level("2026-01-01 00:00:00 INFO gateway.run: msg") == "INFO"
|
||||
|
||||
def test_warning(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 WARNING run_agent: bad") == "WARNING"
|
||||
assert _extract_level("2026-01-01 00:00:00 WARNING tools.file: msg") == "WARNING"
|
||||
|
||||
def test_error(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 ERROR run_agent: crash") == "ERROR"
|
||||
assert _extract_level("2026-01-01 00:00:00 ERROR run_agent: msg") == "ERROR"
|
||||
|
||||
def test_debug(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 DEBUG run_agent: detail") == "DEBUG"
|
||||
assert _extract_level("2026-01-01 00:00:00 DEBUG agent.aux: msg") == "DEBUG"
|
||||
|
||||
def test_no_level(self):
|
||||
assert _extract_level("just a plain line") is None
|
||||
assert _extract_level("random text") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _matches_filters
|
||||
# Logger name extraction (new for component filtering)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractLoggerName:
    """Logger-name extraction from formatted log lines."""

    def test_standard_line(self):
        extracted = _extract_logger_name(
            "2026-04-11 10:23:45 INFO gateway.run: Starting gateway"
        )
        assert extracted == "gateway.run"

    def test_nested_logger(self):
        extracted = _extract_logger_name(
            "2026-04-11 10:23:45 INFO gateway.platforms.telegram: connected"
        )
        assert extracted == "gateway.platforms.telegram"

    def test_warning_level(self):
        extracted = _extract_logger_name(
            "2026-04-11 10:23:45 WARNING tools.terminal_tool: timeout"
        )
        assert extracted == "tools.terminal_tool"

    def test_with_session_tag(self):
        # The bracketed session tag sits between the level and the name.
        extracted = _extract_logger_name(
            "2026-04-11 10:23:45 INFO [abc123] tools.file_tools: reading file"
        )
        assert extracted == "tools.file_tools"

    def test_with_session_tag_and_error(self):
        extracted = _extract_logger_name(
            "2026-04-11 10:23:45 ERROR [sess_xyz] agent.context_compressor: failed"
        )
        assert extracted == "agent.context_compressor"

    def test_top_level_module(self):
        extracted = _extract_logger_name(
            "2026-04-11 10:23:45 INFO run_agent: starting conversation"
        )
        assert extracted == "run_agent"

    def test_no_match(self):
        extracted = _extract_logger_name("random text")
        assert extracted is None
|
||||
|
||||
|
||||
class TestLineMatchesComponent:
    """Component membership checks on formatted log lines."""

    def test_gateway_component(self):
        gateway_only = ("gateway",)
        assert _line_matches_component(
            "2026-04-11 10:23:45 INFO gateway.run: msg", gateway_only
        )

    def test_gateway_nested(self):
        gateway_only = ("gateway",)
        assert _line_matches_component(
            "2026-04-11 10:23:45 INFO gateway.platforms.telegram: msg",
            gateway_only,
        )

    def test_tools_component(self):
        tools_only = ("tools",)
        assert _line_matches_component(
            "2026-04-11 10:23:45 INFO tools.terminal_tool: msg", tools_only
        )

    def test_agent_with_multiple_prefixes(self):
        prefixes = ("agent", "run_agent", "model_tools")
        # One logger per prefix; each must be accepted.
        for logger_name in ("agent.context_compressor", "run_agent", "model_tools"):
            assert _line_matches_component(
                f"2026-04-11 10:23:45 INFO {logger_name}: msg", prefixes
            )

    def test_no_match(self):
        matched = _line_matches_component(
            "2026-04-11 10:23:45 INFO tools.browser: msg", ("gateway",)
        )
        assert not matched

    def test_with_session_tag(self):
        assert _line_matches_component(
            "2026-04-11 10:23:45 INFO [abc] gateway.run: msg", ("gateway",)
        )

    def test_unparseable_line(self):
        matched = _line_matches_component("random text", ("gateway",))
        assert not matched
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Combined filter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatchesFilters:
|
||||
def test_no_filters_always_matches(self):
|
||||
assert _matches_filters("any line") is True
|
||||
def test_no_filters_passes_everything(self):
|
||||
assert _matches_filters("any line")
|
||||
|
||||
def test_level_filter_passes(self):
|
||||
def test_level_filter(self):
|
||||
assert _matches_filters(
|
||||
"2026-04-05 10:00:00 WARNING something",
|
||||
min_level="WARNING",
|
||||
) is True
|
||||
"2026-01-01 00:00:00 WARNING x: msg", min_level="WARNING")
|
||||
assert not _matches_filters(
|
||||
"2026-01-01 00:00:00 INFO x: msg", min_level="WARNING")
|
||||
|
||||
def test_level_filter_rejects(self):
|
||||
def test_session_filter(self):
|
||||
assert _matches_filters(
|
||||
"2026-04-05 10:00:00 INFO something",
|
||||
min_level="WARNING",
|
||||
) is False
|
||||
"2026-01-01 00:00:00 INFO [abc123] x: msg", session_filter="abc123")
|
||||
assert not _matches_filters(
|
||||
"2026-01-01 00:00:00 INFO [xyz789] x: msg", session_filter="abc123")
|
||||
|
||||
def test_session_filter_passes(self):
|
||||
def test_component_filter(self):
|
||||
assert _matches_filters(
|
||||
"session=sess_aaa model=claude",
|
||||
session_filter="sess_aaa",
|
||||
) is True
|
||||
|
||||
def test_session_filter_rejects(self):
|
||||
assert _matches_filters(
|
||||
"session=sess_aaa model=claude",
|
||||
session_filter="sess_bbb",
|
||||
) is False
|
||||
|
||||
def test_since_filter_passes(self):
|
||||
# Line from the future should always pass
|
||||
assert _matches_filters(
|
||||
"2099-01-01 00:00:00 INFO future",
|
||||
since=datetime.now(),
|
||||
) is True
|
||||
|
||||
def test_since_filter_rejects(self):
|
||||
assert _matches_filters(
|
||||
"2020-01-01 00:00:00 INFO past",
|
||||
since=datetime.now(),
|
||||
) is False
|
||||
"2026-01-01 00:00:00 INFO gateway.run: msg",
|
||||
component_prefixes=("gateway",))
|
||||
assert not _matches_filters(
|
||||
"2026-01-01 00:00:00 INFO tools.file: msg",
|
||||
component_prefixes=("gateway",))
|
||||
|
||||
def test_combined_filters(self):
|
||||
line = "2099-01-01 00:00:00 WARNING run_agent: session=abc error"
|
||||
"""All filters must pass for a line to match."""
|
||||
line = "2026-04-11 10:00:00 WARNING [sess_1] gateway.run: connection lost"
|
||||
assert _matches_filters(
|
||||
line, min_level="WARNING", session_filter="abc",
|
||||
since=datetime.now(),
|
||||
) is True
|
||||
# Fails session filter
|
||||
line,
|
||||
min_level="WARNING",
|
||||
session_filter="sess_1",
|
||||
component_prefixes=("gateway",),
|
||||
)
|
||||
# Fails component filter
|
||||
assert not _matches_filters(
|
||||
line,
|
||||
min_level="WARNING",
|
||||
session_filter="sess_1",
|
||||
component_prefixes=("tools",),
|
||||
)
|
||||
|
||||
def test_since_filter(self):
|
||||
# Line with a very old timestamp should be filtered out
|
||||
assert not _matches_filters(
|
||||
"2020-01-01 00:00:00 INFO x: old msg",
|
||||
since=datetime.now() - timedelta(hours=1))
|
||||
# Line with a recent timestamp should pass
|
||||
recent = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert _matches_filters(
|
||||
line, min_level="WARNING", session_filter="xyz",
|
||||
) is False
|
||||
f"{recent} INFO x: recent msg",
|
||||
since=datetime.now() - timedelta(hours=1))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_last_n_lines
|
||||
# File reading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadLastNLines:
|
||||
def test_reads_correct_count(self, sample_agent_log):
|
||||
lines = _read_last_n_lines(sample_agent_log, 3)
|
||||
assert len(lines) == 3
|
||||
class TestReadTail:
|
||||
def test_read_small_file(self, tmp_path):
|
||||
log_file = tmp_path / "test.log"
|
||||
lines = [f"2026-01-01 00:00:0{i} INFO x: line {i}\n" for i in range(10)]
|
||||
log_file.write_text("".join(lines))
|
||||
|
||||
def test_reads_all_when_fewer(self, sample_agent_log):
|
||||
lines = _read_last_n_lines(sample_agent_log, 100)
|
||||
assert len(lines) == 10 # sample has 10 lines
|
||||
result = _read_last_n_lines(log_file, 5)
|
||||
assert len(result) == 5
|
||||
assert "line 9" in result[-1]
|
||||
|
||||
def test_empty_file(self, log_dir):
|
||||
empty = log_dir / "empty.log"
|
||||
empty.write_text("")
|
||||
lines = _read_last_n_lines(empty, 10)
|
||||
assert lines == []
|
||||
def test_read_with_component_filter(self, tmp_path):
|
||||
log_file = tmp_path / "test.log"
|
||||
lines = [
|
||||
"2026-01-01 00:00:00 INFO gateway.run: gw msg\n",
|
||||
"2026-01-01 00:00:01 INFO tools.file: tool msg\n",
|
||||
"2026-01-01 00:00:02 INFO gateway.session: session msg\n",
|
||||
"2026-01-01 00:00:03 INFO agent.compressor: agent msg\n",
|
||||
]
|
||||
log_file.write_text("".join(lines))
|
||||
|
||||
def test_last_line_content(self, sample_agent_log):
|
||||
lines = _read_last_n_lines(sample_agent_log, 1)
|
||||
assert "rotated to key-2" in lines[0]
|
||||
result = _read_tail(
|
||||
log_file, 50,
|
||||
has_filters=True,
|
||||
component_prefixes=("gateway",),
|
||||
)
|
||||
assert len(result) == 2
|
||||
assert "gw msg" in result[0]
|
||||
assert "session msg" in result[1]
|
||||
|
||||
def test_empty_file(self, tmp_path):
|
||||
log_file = tmp_path / "empty.log"
|
||||
log_file.write_text("")
|
||||
result = _read_last_n_lines(log_file, 10)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tail_log
|
||||
# LOG_FILES registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTailLog:
    """Tests for ``tail_log``: plain tailing, filters, and error exits."""

    def test_basic_tail(self, sample_agent_log, capsys):
        """Tailing three lines prints a header line plus three content lines."""
        tail_log("agent", num_lines=3)
        captured = capsys.readouterr()
        assert "agent.log" in captured.out
        # Should have the header + 3 lines
        lines = captured.out.strip().split("\n")
        assert len(lines) == 4  # 1 header + 3 content

    def test_level_filter(self, sample_agent_log, capsys):
        """``level="ERROR"`` is echoed in the header and keeps one line."""
        tail_log("agent", num_lines=50, level="ERROR")
        captured = capsys.readouterr()
        assert "level>=ERROR" in captured.out
        # Only the ERROR line should appear; header lines start with "---".
        content_lines = [
            line for line in captured.out.strip().split("\n")
            if not line.startswith("---")
        ]
        assert len(content_lines) == 1
        assert "API call failed" in content_lines[0]

    def test_session_filter(self, sample_agent_log, capsys):
        """``session="sess_bbb"`` keeps only the line mentioning that session."""
        tail_log("agent", num_lines=50, session="sess_bbb")
        captured = capsys.readouterr()
        content_lines = [
            line for line in captured.out.strip().split("\n")
            if not line.startswith("---")
        ]
        assert len(content_lines) == 1
        assert "sess_bbb" in content_lines[0]

    def test_errors_log(self, sample_errors_log, capsys):
        """The ``errors`` log name resolves to errors.log."""
        tail_log("errors", num_lines=10)
        captured = capsys.readouterr()
        assert "errors.log" in captured.out
        assert "WARNING" in captured.out or "ERROR" in captured.out

    def test_unknown_log_exits(self):
        """An unknown log name terminates with ``SystemExit``."""
        with pytest.raises(SystemExit):
            tail_log("nonexistent")

    def test_missing_file_exits(self, log_dir):
        """A known log name whose file is absent terminates with ``SystemExit``."""
        with pytest.raises(SystemExit):
            tail_log("agent")  # agent.log doesn't exist in clean log_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListLogs:
    """Tests for ``list_logs`` output."""

    def test_lists_files(self, sample_agent_log, sample_errors_log, capsys):
        """Both sample log files are named in the listing."""
        list_logs()
        captured = capsys.readouterr()
        assert "agent.log" in captured.out
        assert "errors.log" in captured.out

    def test_empty_dir(self, log_dir, capsys):
        """With no log files present, a placeholder message is printed."""
        list_logs()
        captured = capsys.readouterr()
        assert "no log files yet" in captured.out

    def test_shows_sizes(self, sample_agent_log, capsys):
        """The listing includes a size unit for the sample file."""
        list_logs()
        captured = capsys.readouterr()
        # File is small, should show as bytes or KB.  "KB" itself contains
        # "B", so a single substring check covers both units.
        assert "B" in captured.out
|
||||
class TestLogFiles:
    """Tests for the ``LOG_FILES`` registry."""

    def test_known_log_files(self):
        """The registry exposes the agent, errors, and gateway log names."""
        for name in ("agent", "errors", "gateway"):
            assert name in LOG_FILES
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
import threading
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
@@ -34,6 +35,8 @@ def _reset_logging_state():
|
||||
h.close()
|
||||
else:
|
||||
pre_existing.append(h)
|
||||
# Ensure the record factory is installed (it's idempotent).
|
||||
hermes_logging._install_session_record_factory()
|
||||
yield
|
||||
# Restore — remove any handlers added during the test.
|
||||
for h in list(root.handlers):
|
||||
@@ -41,6 +44,7 @@ def _reset_logging_state():
|
||||
root.removeHandler(h)
|
||||
h.close()
|
||||
hermes_logging._logging_initialized = False
|
||||
hermes_logging.clear_session_context()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -220,6 +224,294 @@ class TestSetupLogging:
|
||||
]
|
||||
assert agent_handlers[0].level == logging.WARNING
|
||||
|
||||
def test_record_factory_installed(self, hermes_home):
    """The custom record factory injects session_tag on all records."""
    hermes_logging.setup_logging(hermes_home=hermes_home)
    factory = logging.getLogRecordFactory()
    # The installed factory is identified by a marker attribute.
    assert getattr(factory, "_hermes_session_injector", False), (
        "Record factory should have _hermes_session_injector marker"
    )
    # Verify session_tag exists on a fresh record
    record = factory("test", logging.INFO, "", 0, "msg", (), None)
    assert hasattr(record, "session_tag")
|
||||
|
||||
|
||||
class TestGatewayMode:
    """setup_logging(mode='gateway') creates a filtered gateway.log."""

    @staticmethod
    def _gateway_handlers():
        """Return root rotating file handlers whose file name contains gateway.log."""
        return [
            h for h in logging.getLogger().handlers
            if isinstance(h, RotatingFileHandler)
            and "gateway.log" in getattr(h, "baseFilename", "")
        ]

    @staticmethod
    def _flush_root_handlers():
        """Flush every root handler so log files can be read back."""
        for h in logging.getLogger().handlers:
            h.flush()

    def test_gateway_log_created(self, hermes_home):
        """Gateway mode attaches exactly one gateway.log handler."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
        assert len(self._gateway_handlers()) == 1

    def test_gateway_log_not_created_in_cli_mode(self, hermes_home):
        """CLI mode attaches no gateway.log handler."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="cli")
        assert len(self._gateway_handlers()) == 0

    def test_gateway_log_receives_gateway_records(self, hermes_home):
        """gateway.log captures records from gateway.* loggers."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")

        gw_logger = logging.getLogger("gateway.platforms.telegram")
        gw_logger.info("telegram connected")
        self._flush_root_handlers()

        gw_log = hermes_home / "logs" / "gateway.log"
        assert gw_log.exists()
        assert "telegram connected" in gw_log.read_text()

    def test_gateway_log_rejects_non_gateway_records(self, hermes_home):
        """gateway.log does NOT capture records from tools.*, agent.*, etc."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")

        tool_logger = logging.getLogger("tools.terminal_tool")
        tool_logger.info("running command")

        agent_logger = logging.getLogger("agent.context_compressor")
        agent_logger.info("compressing context")
        self._flush_root_handlers()

        gw_log = hermes_home / "logs" / "gateway.log"
        # The file is only inspected if it exists; an absent file trivially
        # contains neither message.
        if gw_log.exists():
            content = gw_log.read_text()
            assert "running command" not in content
            assert "compressing context" not in content

    def test_agent_log_still_receives_all(self, hermes_home):
        """agent.log (catch-all) still receives gateway AND tool records."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")

        logging.getLogger("gateway.run").info("gateway msg")
        logging.getLogger("tools.file_tools").info("file msg")
        self._flush_root_handlers()

        agent_log = hermes_home / "logs" / "agent.log"
        content = agent_log.read_text()
        assert "gateway msg" in content
        assert "file msg" in content
|
||||
|
||||
|
||||
class TestSessionContext:
    """set_session_context / clear_session_context + _SessionFilter."""

    @staticmethod
    def _flush_root_handlers():
        """Flush every root handler so agent.log can be read back."""
        for h in logging.getLogger().handlers:
            h.flush()

    def test_session_tag_in_log_output(self, hermes_home):
        """When session context is set, log lines include [session_id]."""
        hermes_logging.setup_logging(hermes_home=hermes_home)
        hermes_logging.set_session_context("abc123")

        test_logger = logging.getLogger("test.session_tag")
        test_logger.info("tagged message")
        self._flush_root_handlers()

        agent_log = hermes_home / "logs" / "agent.log"
        content = agent_log.read_text()
        assert "[abc123]" in content
        assert "tagged message" in content

    def test_no_session_tag_without_context(self, hermes_home):
        """Without session context, log lines have no session tag."""
        hermes_logging.setup_logging(hermes_home=hermes_home)
        hermes_logging.clear_session_context()

        test_logger = logging.getLogger("test.no_session")
        test_logger.info("untagged message")
        self._flush_root_handlers()

        agent_log = hermes_home / "logs" / "agent.log"
        content = agent_log.read_text()
        assert "untagged message" in content
        # Should not have any [xxx] session tag
        import re
        for line in content.splitlines():
            if "untagged message" in line:
                # Only the span between the level name and the logger name
                # is inspected for a bracketed tag.
                between = line.split("INFO")[1].split("test.no_session")[0]
                assert not re.search(r"\[.+?\]", between)

    def test_clear_session_context(self, hermes_home):
        """After clearing, session tag disappears."""
        hermes_logging.setup_logging(hermes_home=hermes_home)
        hermes_logging.set_session_context("xyz789")
        hermes_logging.clear_session_context()

        test_logger = logging.getLogger("test.cleared")
        test_logger.info("after clear")
        self._flush_root_handlers()

        agent_log = hermes_home / "logs" / "agent.log"
        content = agent_log.read_text()
        # Make sure the record was written, so the negative check below
        # cannot pass just because nothing was logged.
        assert "after clear" in content
        assert "[xyz789]" not in content

    def test_session_context_thread_isolated(self, hermes_home):
        """Session context is per-thread — one thread's context doesn't leak."""
        hermes_logging.setup_logging(hermes_home=hermes_home)

        def thread_a():
            hermes_logging.set_session_context("thread_a_session")
            logging.getLogger("test.thread_a").info("from thread A")
            for h in logging.getLogger().handlers:
                h.flush()

        def thread_b():
            hermes_logging.set_session_context("thread_b_session")
            logging.getLogger("test.thread_b").info("from thread B")
            for h in logging.getLogger().handlers:
                h.flush()

        # The threads run one after the other, not concurrently.
        ta = threading.Thread(target=thread_a)
        tb = threading.Thread(target=thread_b)
        ta.start()
        ta.join()
        tb.start()
        tb.join()

        agent_log = hermes_home / "logs" / "agent.log"
        content = agent_log.read_text()

        log_lines = content.splitlines()
        a_lines = [ln for ln in log_lines if "from thread A" in ln]
        b_lines = [ln for ln in log_lines if "from thread B" in ln]
        # Guard against a vacuous pass: both records must have been written.
        assert a_lines and b_lines

        # Each thread's message should have its own session tag
        for ln in a_lines:
            assert "[thread_a_session]" in ln
            assert "[thread_b_session]" not in ln
        for ln in b_lines:
            assert "[thread_b_session]" in ln
            assert "[thread_a_session]" not in ln
|
||||
|
||||
|
||||
class TestRecordFactory:
    """Unit tests for the custom LogRecord factory."""

    @staticmethod
    def _make_record():
        """Build a record through whichever factory is currently installed."""
        factory = logging.getLogRecordFactory()
        return factory("test", logging.INFO, "", 0, "msg", (), None)

    def test_record_has_session_tag(self):
        """Every record gets a session_tag attribute."""
        assert hasattr(self._make_record(), "session_tag")

    def test_empty_tag_without_context(self):
        """With no session context the tag is the empty string."""
        hermes_logging.clear_session_context()
        assert self._make_record().session_tag == ""

    def test_tag_with_context(self):
        """With a session context the tag is ' [<id>]' (leading space)."""
        hermes_logging.set_session_context("sess_42")
        assert self._make_record().session_tag == " [sess_42]"

    def test_idempotent_install(self):
        """Calling _install_session_record_factory() twice doesn't double-wrap."""
        hermes_logging._install_session_record_factory()
        factory_a = logging.getLogRecordFactory()
        hermes_logging._install_session_record_factory()
        factory_b = logging.getLogRecordFactory()
        assert factory_a is factory_b

    def test_works_with_any_handler(self):
        """A handler using %(session_tag)s works even without _SessionFilter."""
        hermes_logging.set_session_context("any_handler_test")
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(session_tag)s %(message)s"))

        logger = logging.getLogger("_test_any_handler")
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        try:
            # Should not raise KeyError
            logger.info("hello")
        finally:
            logger.removeHandler(handler)
|
||||
|
||||
|
||||
class TestComponentFilter:
    """Unit tests for _ComponentFilter."""

    @staticmethod
    def _record(name, msg="msg"):
        """Create an INFO-level ``LogRecord`` for logger *name*."""
        return logging.LogRecord(name, logging.INFO, "", 0, msg, (), None)

    def test_passes_matching_prefix(self):
        """A logger directly under the prefix passes."""
        f = hermes_logging._ComponentFilter(("gateway",))
        assert f.filter(self._record("gateway.run")) is True

    def test_passes_nested_matching_prefix(self):
        """A deeper logger under the prefix passes too."""
        f = hermes_logging._ComponentFilter(("gateway",))
        assert f.filter(self._record("gateway.platforms.telegram")) is True

    def test_blocks_non_matching(self):
        """A logger from another component is rejected."""
        f = hermes_logging._ComponentFilter(("gateway",))
        assert f.filter(self._record("tools.terminal_tool")) is False

    def test_multiple_prefixes(self):
        """Any one of several prefixes is enough to pass."""
        f = hermes_logging._ComponentFilter(("agent", "run_agent", "model_tools"))
        assert f.filter(self._record("agent.compressor", ""))
        assert f.filter(self._record("run_agent", ""))
        assert f.filter(self._record("model_tools", ""))
        assert not f.filter(self._record("tools.browser", ""))
|
||||
|
||||
|
||||
class TestComponentPrefixes:
    """COMPONENT_PREFIXES covers the expected components."""

    def test_gateway_prefix(self):
        """The gateway component maps to the single 'gateway' prefix."""
        assert "gateway" in hermes_logging.COMPONENT_PREFIXES
        assert hermes_logging.COMPONENT_PREFIXES["gateway"] == ("gateway",)

    def test_agent_prefix(self):
        """The agent component spans several logger-name prefixes."""
        prefixes = hermes_logging.COMPONENT_PREFIXES["agent"]
        assert "agent" in prefixes
        assert "run_agent" in prefixes
        assert "model_tools" in prefixes

    def test_tools_prefix(self):
        """The tools component maps to the single 'tools' prefix."""
        assert hermes_logging.COMPONENT_PREFIXES["tools"] == ("tools",)

    def test_cli_prefix(self):
        """The cli component includes both 'hermes_cli' and 'cli'."""
        prefixes = hermes_logging.COMPONENT_PREFIXES["cli"]
        assert "hermes_cli" in prefixes
        assert "cli" in prefixes

    def test_cron_prefix(self):
        """The cron component maps to the single 'cron' prefix."""
        assert hermes_logging.COMPONENT_PREFIXES["cron"] == ("cron",)
|
||||
|
||||
|
||||
class TestSetupVerboseLogging:
|
||||
"""setup_verbose_logging() adds a DEBUG-level console handler."""
|
||||
@@ -301,6 +593,59 @@ class TestAddRotatingHandler:
|
||||
logger.removeHandler(h)
|
||||
h.close()
|
||||
|
||||
def test_log_filter_attached(self, tmp_path):
    """Optional log_filter is attached to the handler."""
    log_path = tmp_path / "filtered.log"
    logger = logging.getLogger("_test_rotating_filter")
    formatter = logging.Formatter("%(message)s")
    component_filter = hermes_logging._ComponentFilter(("test",))

    hermes_logging._add_rotating_handler(
        logger, log_path,
        level=logging.INFO, max_bytes=1024, backup_count=1,
        formatter=formatter,
        log_filter=component_filter,
    )

    try:
        handlers = [h for h in logger.handlers if isinstance(h, RotatingFileHandler)]
        assert len(handlers) == 1
        assert component_filter in handlers[0].filters
    finally:
        # Clean up even when an assertion fails, so the named logger does
        # not keep an open file handler across tests.
        for h in list(logger.handlers):
            if isinstance(h, RotatingFileHandler):
                logger.removeHandler(h)
                h.close()
|
||||
|
||||
def test_no_session_filter_on_handler(self, tmp_path):
    """Handlers rely on record factory, not per-handler _SessionFilter."""
    log_path = tmp_path / "no_session_filter.log"
    logger = logging.getLogger("_test_no_session_filter")
    formatter = logging.Formatter("%(session_tag)s%(message)s")

    hermes_logging._add_rotating_handler(
        logger, log_path,
        level=logging.INFO, max_bytes=1024, backup_count=1,
        formatter=formatter,
    )

    try:
        handlers = [h for h in logger.handlers if isinstance(h, RotatingFileHandler)]
        assert len(handlers) == 1
        # No _SessionFilter on the handler — record factory handles it
        assert len(handlers[0].filters) == 0

        # But session_tag still works (via record factory)
        hermes_logging.set_session_context("factory_test")
        logger.info("test msg")
        handlers[0].flush()
        content = log_path.read_text()
        assert "[factory_test]" in content
    finally:
        # Clean up even when an assertion fails, so the named logger does
        # not keep an open file handler across tests.
        for h in list(logger.handlers):
            if isinstance(h, RotatingFileHandler):
                logger.removeHandler(h)
                h.close()
|
||||
|
||||
def test_managed_mode_initial_open_sets_group_writable(self, tmp_path):
|
||||
log_path = tmp_path / "managed-open.log"
|
||||
logger = logging.getLogger("_test_rotating_managed_open")
|
||||
|
||||
306
tests/tools/test_file_sync_back.py
Normal file
306
tests/tools/test_file_sync_back.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""Tests for FileSyncManager.sync_back() — pull remote changes to host."""
|
||||
|
||||
import fcntl
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import tarfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.file_sync import (
|
||||
FileSyncManager,
|
||||
_sha256_file,
|
||||
_SYNC_BACK_BACKOFF,
|
||||
_SYNC_BACK_MAX_RETRIES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tar(files: dict[str, bytes], dest: Path):
    """Write a tar archive containing the given arcname->content pairs.

    Args:
        files: Mapping of archive member name to the member's raw bytes.
        dest: Path of the (uncompressed) tar file to create or overwrite.
    """
    with tarfile.open(dest, "w") as tar:
        for arcname, content in files.items():
            info = tarfile.TarInfo(name=arcname)
            # TarInfo.size defaults to 0; it must match the payload length
            # or addfile() would store an empty member.
            info.size = len(content)
            tar.addfile(info, io.BytesIO(content))
|
||||
|
||||
|
||||
def _make_download_fn(files: dict[str, bytes]):
    """Return a bulk_download_fn that writes a tar of the given files.

    The returned callable takes the destination path and fills it with an
    archive built by ``_make_tar`` from *files*.
    """
    def download(dest: Path):
        _make_tar(files, dest)

    return download
|
||||
|
||||
|
||||
def _sha256_bytes(data: bytes) -> str:
    """Compute SHA-256 hex digest of raw bytes (for test convenience)."""
    # Imported locally, as in the original helper, so the block does not
    # depend on a module-level hashlib import.
    import hashlib

    return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def _write_file(path: Path, content: bytes) -> str:
    """Write bytes to *path*, creating parents, and return the string path."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(content)
    return str(path)
|
||||
|
||||
|
||||
def _make_manager(
    tmp_path: Path,
    file_mapping: list[tuple[str, str]] | None = None,
    bulk_download_fn=None,
) -> FileSyncManager:
    """Create a FileSyncManager wired for testing.

    *file_mapping* is a list of (host_path, remote_path) tuples that
    ``get_files_fn`` returns.  If *None* an empty list is used.

    *tmp_path* is accepted but not used by this helper.  Upload and delete
    callbacks are ``MagicMock`` instances; *bulk_download_fn* is passed
    through unchanged.
    """
    mapping = file_mapping or []
    return FileSyncManager(
        get_files_fn=lambda: mapping,
        upload_fn=MagicMock(),
        delete_fn=MagicMock(),
        bulk_download_fn=bulk_download_fn,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncBackNoop:
    """sync_back() is a no-op when there is no download function."""

    def test_sync_back_noop_without_download_fn(self, tmp_path):
        """Calling sync_back() without a bulk_download_fn must not raise."""
        mgr = _make_manager(tmp_path, bulk_download_fn=None)
        # Should return immediately without error
        mgr.sync_back(hermes_home=tmp_path / ".hermes")
        # Nothing to assert beyond "no exception raised"
|
||||
|
||||
|
||||
class TestSyncBackNoChanges:
    """When all remote files match pushed hashes, nothing is applied."""

    def test_sync_back_no_changes(self, tmp_path):
        """A remote file identical to the pushed version leaves the host file as is."""
        host_file = tmp_path / "host" / "cred.json"
        host_content = b'{"key": "val"}'
        _write_file(host_file, host_content)

        remote_path = "/root/.hermes/cred.json"
        mapping = [(str(host_file), remote_path)]

        # Remote tar contains the same content as was pushed
        download_fn = _make_download_fn({
            "root/.hermes/cred.json": host_content,
        })

        mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
        # Simulate that we already pushed this file with this hash
        mgr._pushed_hashes[remote_path] = _sha256_bytes(host_content)

        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # Host file should be unchanged (same content, same bytes)
        assert host_file.read_bytes() == host_content
|
||||
|
||||
|
||||
class TestSyncBackAppliesChanged:
    """Remote file differs from pushed version -- gets copied to host."""

    def test_sync_back_applies_changed_file(self, tmp_path):
        """The host file takes the remote content after sync_back()."""
        host_file = tmp_path / "host" / "skill.py"
        original_content = b"print('v1')"
        _write_file(host_file, original_content)

        remote_path = "/root/.hermes/skill.py"
        mapping = [(str(host_file), remote_path)]

        remote_content = b"print('v2 - edited on remote')"
        download_fn = _make_download_fn({
            "root/.hermes/skill.py": remote_content,
        })

        mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
        # Record the hash of what was "pushed", so the remote copy counts
        # as changed.
        mgr._pushed_hashes[remote_path] = _sha256_bytes(original_content)

        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        assert host_file.read_bytes() == remote_content
|
||||
|
||||
|
||||
class TestSyncBackNewRemoteFile:
    """File created on remote (not in _pushed_hashes) is applied via _infer_host_path."""

    def test_sync_back_detects_new_remote_file(self, tmp_path):
        """A never-pushed remote file lands next to its mapped sibling on the host."""
        # Existing mapping gives _infer_host_path a prefix to work with
        existing_host = tmp_path / "host" / "skills" / "existing.py"
        _write_file(existing_host, b"existing")
        mapping = [(str(existing_host), "/root/.hermes/skills/existing.py")]

        # Remote has a NEW file in the same directory that was never pushed
        new_remote_content = b"# brand new skill created on remote"
        download_fn = _make_download_fn({
            "root/.hermes/skills/new_skill.py": new_remote_content,
        })

        mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
        # No entry in _pushed_hashes for the new file

        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # The new file should have been inferred and written to the host
        expected_host_path = tmp_path / "host" / "skills" / "new_skill.py"
        assert expected_host_path.exists()
        assert expected_host_path.read_bytes() == new_remote_content
|
||||
|
||||
|
||||
class TestSyncBackConflict:
    """Host AND remote both changed since push -- warning logged, remote wins."""

    def test_sync_back_conflict_warns(self, tmp_path, caplog):
        """A two-sided edit logs a conflict warning and keeps the remote bytes."""
        host_file = tmp_path / "host" / "config.json"
        original_content = b'{"v": 1}'
        _write_file(host_file, original_content)

        remote_path = "/root/.hermes/config.json"
        mapping = [(str(host_file), remote_path)]

        # Host was modified after push
        host_file.write_bytes(b'{"v": 2, "host-edit": true}')

        # Remote was also modified
        remote_content = b'{"v": 3, "remote-edit": true}'
        download_fn = _make_download_fn({
            "root/.hermes/config.json": remote_content,
        })

        mgr = _make_manager(tmp_path, file_mapping=mapping, bulk_download_fn=download_fn)
        # The pushed hash is that of the ORIGINAL content, so both sides
        # now differ from it.
        mgr._pushed_hashes[remote_path] = _sha256_bytes(original_content)

        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # Conflict warning was logged
        assert any("conflict" in r.message.lower() for r in caplog.records)

        # Remote version wins (last-write-wins)
        assert host_file.read_bytes() == remote_content
|
||||
|
||||
|
||||
class TestSyncBackRetries:
    """Retry behaviour with exponential backoff."""

    @patch("tools.environments.file_sync.time.sleep")
    def test_sync_back_retries_on_failure(self, mock_sleep, tmp_path):
        """Two failed downloads are retried; the third attempt succeeds."""
        call_count = 0

        def flaky_download(dest: Path):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise RuntimeError(f"network error #{call_count}")
            # Third attempt succeeds -- write a valid (empty) tar
            _make_tar({}, dest)

        mgr = _make_manager(tmp_path, bulk_download_fn=flaky_download)
        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        assert call_count == 3
        # Sleep called twice (between attempt 1->2 and 2->3)
        assert mock_sleep.call_count == 2
        mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[0])
        mock_sleep.assert_any_call(_SYNC_BACK_BACKOFF[1])

    @patch("tools.environments.file_sync.time.sleep")
    def test_sync_back_all_retries_exhausted(self, mock_sleep, tmp_path, caplog):
        """Persistent failure is logged after the last retry, never raised."""
        def always_fail(dest: Path):
            raise RuntimeError("persistent failure")

        mgr = _make_manager(tmp_path, bulk_download_fn=always_fail)

        with caplog.at_level(logging.WARNING, logger="tools.environments.file_sync"):
            # Should NOT raise -- failures are logged, not propagated
            mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # All retries were attempted
        assert mock_sleep.call_count == _SYNC_BACK_MAX_RETRIES - 1

        # Final "all attempts failed" warning was logged
        assert any(
            "all" in r.message.lower() and "failed" in r.message.lower()
            for r in caplog.records
        )
|
||||
|
||||
|
||||
class TestPushedHashesPopulated:
    """_pushed_hashes is populated during sync() and cleared on delete."""

    def test_pushed_hashes_populated_on_sync(self, tmp_path):
        """After a forced sync the remote path maps to the host file's hash."""
        host_file = tmp_path / "data.txt"
        host_file.write_bytes(b"hello world")

        remote_path = "/root/.hermes/data.txt"
        mapping = [(str(host_file), remote_path)]

        mgr = FileSyncManager(
            get_files_fn=lambda: mapping,
            upload_fn=MagicMock(),
            delete_fn=MagicMock(),
        )

        mgr.sync(force=True)

        assert remote_path in mgr._pushed_hashes
        assert mgr._pushed_hashes[remote_path] == _sha256_file(str(host_file))

    def test_pushed_hashes_cleared_on_delete(self, tmp_path):
        """Dropping a file from the mapping removes its pushed hash on the next sync."""
        host_file = tmp_path / "deleteme.txt"
        host_file.write_bytes(b"to be deleted")

        remote_path = "/root/.hermes/deleteme.txt"
        # Mutable list captured by the lambda below; clearing it later
        # changes what get_files_fn reports.
        current_mapping = [(str(host_file), remote_path)]

        mgr = FileSyncManager(
            get_files_fn=lambda: current_mapping,
            upload_fn=MagicMock(),
            delete_fn=MagicMock(),
        )

        # Sync to populate hashes
        mgr.sync(force=True)
        assert remote_path in mgr._pushed_hashes

        # Remove the file from the mapping (simulates local deletion)
        host_file.unlink()
        current_mapping.clear()

        mgr.sync(force=True)

        # Hash should be cleaned up
        assert remote_path not in mgr._pushed_hashes
|
||||
|
||||
|
||||
class TestSyncBackFileLock:
    """Verify that fcntl.flock is used during sync-back."""

    @patch("tools.environments.file_sync.fcntl.flock")
    def test_sync_back_file_lock(self, mock_flock, tmp_path):
        """sync_back() takes an exclusive lock and releases it."""
        download_fn = _make_download_fn({})
        mgr = _make_manager(tmp_path, bulk_download_fn=download_fn)

        mgr.sync_back(hermes_home=tmp_path / ".hermes")

        # flock should have been called at least twice: LOCK_EX to acquire, LOCK_UN to release
        assert mock_flock.call_count >= 2

        # The second positional argument of flock(fd, operation) is the op.
        lock_ops = [c.args[1] for c in mock_flock.call_args_list]
        assert fcntl.LOCK_EX in lock_ops
        assert fcntl.LOCK_UN in lock_ops
|
||||
295
tests/tools/test_modal_bulk_upload.py
Normal file
295
tests/tools/test_modal_bulk_upload.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""Tests for Modal bulk upload via tar/base64 archive."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import modal as modal_env
|
||||
|
||||
|
||||
def _make_mock_modal_env(monkeypatch, tmp_path):
    """Create a minimal mock ModalEnvironment for testing upload methods.

    Returns a ModalEnvironment-like object with _sandbox and _worker mocked.
    We don't call __init__ because it requires the Modal SDK.

    ``monkeypatch`` and ``tmp_path`` are accepted so tests can pass their
    fixtures straight through; this helper does not use either of them.
    """
    env = object.__new__(modal_env.ModalEnvironment)
    env._sandbox = MagicMock()
    env._worker = MagicMock()
    env._persistent = False
    env._task_id = "test"
    env._sync_manager = None
    return env
|
||||
|
||||
|
||||
def _make_mock_stdin():
|
||||
"""Create a mock stdin that captures written data."""
|
||||
stdin = MagicMock()
|
||||
written_chunks = []
|
||||
|
||||
def mock_write(data):
|
||||
written_chunks.append(data)
|
||||
|
||||
stdin.write = mock_write
|
||||
stdin.write_eof = MagicMock()
|
||||
stdin.drain = MagicMock()
|
||||
stdin.drain.aio = AsyncMock()
|
||||
stdin._written_chunks = written_chunks
|
||||
return stdin
|
||||
|
||||
|
||||
def _wire_async_exec(env, exec_calls=None):
    """Wire mock sandbox.exec.aio and a real run_coroutine on the env.

    Optionally captures exec call args into *exec_calls* list.
    Returns (exec_calls, run_kwargs, stdin_mock).

    The fake process always reports exit code 0 and empty stderr.
    ``run_kwargs`` collects the keyword arguments handed to
    ``env._worker.run_coroutine`` so tests can inspect them afterwards.
    """
    if exec_calls is None:
        exec_calls = []
    run_kwargs: dict = {}
    stdin_mock = _make_mock_stdin()

    async def mock_exec_fn(*args, **kwargs):
        # Only positional args are recorded; kwargs are accepted and dropped.
        exec_calls.append(args)
        proc = MagicMock()
        proc.wait = MagicMock()
        proc.wait.aio = AsyncMock(return_value=0)
        proc.stdin = stdin_mock
        proc.stderr = MagicMock()
        proc.stderr.read = MagicMock()
        proc.stderr.read.aio = AsyncMock(return_value="")
        return proc

    env._sandbox.exec = MagicMock()
    env._sandbox.exec.aio = mock_exec_fn

    def real_run_coroutine(coro, **kwargs):
        run_kwargs.update(kwargs)
        # Run the coroutine to completion on a private loop so the code under
        # test really executes instead of being swallowed by a MagicMock.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    env._worker.run_coroutine = real_run_coroutine
    return exec_calls, run_kwargs, stdin_mock
|
||||
|
||||
|
||||
class TestModalBulkUpload:
    """Test _modal_bulk_upload method."""

    def test_empty_files_is_noop(self, monkeypatch, tmp_path):
        """Empty file list should not call worker.run_coroutine."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)
        env._modal_bulk_upload([])
        env._worker.run_coroutine.assert_not_called()

    def test_tar_archive_contains_all_files(self, monkeypatch, tmp_path):
        """The tar archive sent via stdin should contain all files."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src_a = tmp_path / "a.json"
        src_b = tmp_path / "b.py"
        src_a.write_text("cred_content")
        src_b.write_text("skill_content")

        files = [
            (str(src_a), "/root/.hermes/credentials/a.json"),
            (str(src_b), "/root/.hermes/skills/b.py"),
        ]

        exec_calls, _, stdin_mock = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # Verify the command reads from stdin (no echo with embedded payload)
        assert len(exec_calls) == 1
        args = exec_calls[0]
        assert args[0] == "bash"
        assert args[1] == "-c"
        cmd = args[2]
        assert "mkdir -p" in cmd
        assert "base64 -d" in cmd
        assert "tar xzf" in cmd
        assert "-C /" in cmd

        # Reassemble the base64 payload from stdin chunks and verify tar contents
        payload = "".join(stdin_mock._written_chunks)
        tar_data = base64.b64decode(payload)
        buf = io.BytesIO(tar_data)
        with tarfile.open(fileobj=buf, mode="r:gz") as tar:
            # Member names carry no leading slash.
            names = tar.getnames()
            assert "root/.hermes/credentials/a.json" in names
            assert "root/.hermes/skills/b.py" in names

            # Verify content
            a_content = tar.extractfile("root/.hermes/credentials/a.json").read()
            assert a_content == b"cred_content"
            b_content = tar.extractfile("root/.hermes/skills/b.py").read()
            assert b_content == b"skill_content"

        # Verify stdin was closed
        stdin_mock.write_eof.assert_called_once()

    def test_mkdir_includes_all_parents(self, monkeypatch, tmp_path):
        """Remote parent directories should be pre-created in the command."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("data")

        files = [
            (str(src), "/root/.hermes/credentials/f.txt"),
            (str(src), "/root/.hermes/skills/deep/nested/f.txt"),
        ]

        exec_calls, _, _ = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        cmd = exec_calls[0][2]
        assert "/root/.hermes/credentials" in cmd
        assert "/root/.hermes/skills/deep/nested" in cmd

    def test_single_exec_call(self, monkeypatch, tmp_path):
        """Bulk upload should use exactly one exec call regardless of file count."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        files = []
        for i in range(20):
            src = tmp_path / f"file_{i}.txt"
            src.write_text(f"content_{i}")
            files.append((str(src), f"/root/.hermes/cache/file_{i}.txt"))

        exec_calls, _, _ = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # Should be exactly 1 exec call, not 20
        assert len(exec_calls) == 1

    def test_bulk_upload_wired_in_filesyncmanager(self, monkeypatch):
        """Verify ModalEnvironment passes bulk_upload_fn to FileSyncManager."""
        captured_kwargs = {}

        def capture_fsm(**kwargs):
            captured_kwargs.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)

        # Create a minimal env without full __init__
        env = object.__new__(modal_env.ModalEnvironment)
        env._sandbox = MagicMock()
        env._worker = MagicMock()
        env._persistent = False
        env._task_id = "test"

        # Manually call the part of __init__ that wires FileSyncManager
        # NOTE(review): the test itself passes bulk_upload_fn here, so the
        # assertions below mainly prove that ``_modal_bulk_upload`` exists and
        # is callable — confirm whether the real __init__ wiring is covered
        # elsewhere.
        from tools.environments.file_sync import iter_sync_files
        env._sync_manager = modal_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files("/root/.hermes"),
            upload_fn=env._modal_upload,
            delete_fn=env._modal_delete,
            bulk_upload_fn=env._modal_bulk_upload,
        )

        assert "bulk_upload_fn" in captured_kwargs
        assert captured_kwargs["bulk_upload_fn"] is not None
        assert callable(captured_kwargs["bulk_upload_fn"])

    def test_timeout_set_to_120(self, monkeypatch, tmp_path):
        """Bulk upload uses a 120s timeout (not the per-file 15s)."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("data")
        files = [(str(src), "/root/.hermes/f.txt")]

        _, run_kwargs, _ = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        assert run_kwargs.get("timeout") == 120

    def test_nonzero_exit_raises(self, monkeypatch, tmp_path):
        """Non-zero exit code from remote exec should raise RuntimeError."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("data")
        files = [(str(src), "/root/.hermes/f.txt")]

        stdin_mock = _make_mock_stdin()

        # Wired by hand rather than via _wire_async_exec because that helper
        # always reports exit code 0 and empty stderr.
        async def mock_exec_fn(*args, **kwargs):
            proc = MagicMock()
            proc.wait = MagicMock()
            proc.wait.aio = AsyncMock(return_value=1)  # non-zero exit
            proc.stdin = stdin_mock
            proc.stderr = MagicMock()
            proc.stderr.read = MagicMock()
            proc.stderr.read.aio = AsyncMock(return_value="tar: error")
            return proc

        env._sandbox.exec = MagicMock()
        env._sandbox.exec.aio = mock_exec_fn

        def real_run_coroutine(coro, **kwargs):
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(coro)
            finally:
                loop.close()

        env._worker.run_coroutine = real_run_coroutine

        with pytest.raises(RuntimeError, match="Modal bulk upload failed"):
            env._modal_bulk_upload(files)

    def test_payload_not_in_command_string(self, monkeypatch, tmp_path):
        """The base64 payload must NOT appear in the bash -c argument.

        This is the core ARG_MAX fix: the payload goes through stdin,
        not embedded in the command string.
        """
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("some data to upload")
        files = [(str(src), "/root/.hermes/f.txt")]

        exec_calls, _, stdin_mock = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # The command should NOT contain an echo with the payload
        cmd = exec_calls[0][2]
        assert "echo" not in cmd
        # The payload should go through stdin
        assert len(stdin_mock._written_chunks) > 0

    def test_stdin_chunked_for_large_payloads(self, monkeypatch, tmp_path):
        """Payloads larger than _STDIN_CHUNK_SIZE should be split into multiple writes."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        # Use random bytes so gzip cannot compress them -- ensures the
        # base64 payload exceeds one 1 MB chunk.
        import os as _os
        src = tmp_path / "large.bin"
        src.write_bytes(_os.urandom(1024 * 1024 + 512 * 1024))
        files = [(str(src), "/root/.hermes/large.bin")]

        _, _, stdin_mock = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # Should have multiple stdin write chunks
        assert len(stdin_mock._written_chunks) >= 2

        # Reassembled payload should still decode to valid tar
        payload = "".join(stdin_mock._written_chunks)
        tar_data = base64.b64decode(payload)
        buf = io.BytesIO(tar_data)
        with tarfile.open(fileobj=buf, mode="r:gz") as tar:
            names = tar.getnames()
        assert "root/.hermes/large.bin" in names
|
||||
517
tests/tools/test_ssh_bulk_upload.py
Normal file
517
tests/tools/test_ssh_bulk_upload.py
Normal file
@@ -0,0 +1,517 @@
|
||||
"""Tests for SSH bulk upload via tar pipe."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import ssh as ssh_env
|
||||
from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
|
||||
|
||||
def _mock_proc(*, returncode=0, poll_return=0, communicate_return=(b"", b""),
|
||||
stderr_read=b""):
|
||||
"""Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes."""
|
||||
m = MagicMock()
|
||||
m.stdout = MagicMock()
|
||||
m.returncode = returncode
|
||||
m.poll.return_value = poll_return
|
||||
m.communicate.return_value = communicate_return
|
||||
m.stderr = MagicMock()
|
||||
m.stderr.read.return_value = stderr_read
|
||||
return m
|
||||
|
||||
|
||||
@pytest.fixture
def mock_env(monkeypatch):
    """Create an SSHEnvironment with mocked connection/sync.

    ``shutil.which`` is patched to report an ssh binary, the connection and
    remote-setup methods are replaced with stubs, and ``FileSyncManager`` is
    replaced by a factory returning an object whose ``sync`` does nothing.
    """
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
    monkeypatch.setattr(
        ssh_env, "FileSyncManager",
        lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
    )
    return SSHEnvironment(host="example.com", user="testuser")
|
||||
|
||||
|
||||
class TestSSHBulkUpload:
    """Unit tests for _ssh_bulk_upload — tar pipe mechanics.

    Successful fake processes come from the module-level ``_mock_proc``
    helper. Mocks that deliberately set fewer attributes (the ssh side of the
    failure tests, both sides of the timeout test) are still built by hand so
    their attribute set is unchanged.
    """

    def test_empty_files_is_noop(self, mock_env):
        """Empty file list should not spawn any subprocesses."""
        with patch.object(subprocess, "run") as mock_run, \
                patch.object(subprocess, "Popen") as mock_popen:
            mock_env._ssh_bulk_upload([])
            mock_run.assert_not_called()
            mock_popen.assert_not_called()

    def test_mkdir_batched_into_single_call(self, mock_env, tmp_path):
        """All parent directories should be created in one SSH call."""
        # Create test files
        f1 = tmp_path / "a.txt"
        f1.write_text("aaa")
        f2 = tmp_path / "b.txt"
        f2.write_text("bbb")

        files = [
            (str(f1), "/home/testuser/.hermes/skills/a.txt"),
            (str(f2), "/home/testuser/.hermes/credentials/b.txt"),
        ]

        # Mock subprocess.run for mkdir and Popen for tar pipe
        mock_run = MagicMock(return_value=subprocess.CompletedProcess([], 0))

        def make_proc(cmd, **kwargs):
            return _mock_proc()

        with patch.object(subprocess, "run", mock_run), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            mock_env._ssh_bulk_upload(files)

        # Exactly one subprocess.run call for mkdir
        assert mock_run.call_count == 1
        mkdir_cmd = mock_run.call_args[0][0]
        # Should contain mkdir -p with both parent dirs
        mkdir_str = " ".join(mkdir_cmd)
        assert "mkdir -p" in mkdir_str
        assert "/home/testuser/.hermes/skills" in mkdir_str
        assert "/home/testuser/.hermes/credentials" in mkdir_str

    def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path):
        """Symlinks in staging dir should mirror the remote path structure."""
        f1 = tmp_path / "local_a.txt"
        f1.write_text("content a")

        files = [
            (str(f1), "/home/testuser/.hermes/skills/my_skill.md"),
        ]

        staging_paths = []

        def capture_tar_cmd(cmd, **kwargs):
            # The staging directory is inspected here, at Popen time, while
            # the upload is still in progress.
            if cmd[0] == "tar":
                # Capture the staging dir from -C argument
                c_idx = cmd.index("-C")
                staging_dir = cmd[c_idx + 1]
                # Check the symlink exists
                expected = os.path.join(
                    staging_dir, "home/testuser/.hermes/skills/my_skill.md"
                )
                staging_paths.append(expected)
                assert os.path.islink(expected), f"Expected symlink at {expected}"
                assert os.readlink(expected) == os.path.abspath(str(f1))

            return _mock_proc()

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_tar_cmd):
            mock_env._ssh_bulk_upload(files)

        assert len(staging_paths) == 1, "tar command should have been called"

    def test_tar_pipe_commands(self, mock_env, tmp_path):
        """Verify tar and SSH commands are wired correctly."""
        f1 = tmp_path / "x.txt"
        f1.write_text("x")

        files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")]

        popen_cmds = []

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            return _mock_proc()

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            mock_env._ssh_bulk_upload(files)

        assert len(popen_cmds) == 2, "Should spawn tar + ssh processes"

        tar_cmd = popen_cmds[0]
        ssh_cmd = popen_cmds[1]

        # tar: create, dereference symlinks, to stdout
        assert tar_cmd[0] == "tar"
        assert "-chf" in tar_cmd
        assert "-" in tar_cmd  # stdout
        assert "-C" in tar_cmd

        # ssh: extract from stdin at /
        ssh_str = " ".join(ssh_cmd)
        assert "ssh" in ssh_str
        assert "tar xf - -C /" in ssh_str
        assert "testuser@example.com" in ssh_str

    def test_mkdir_failure_raises(self, mock_env, tmp_path):
        """mkdir failure should raise RuntimeError before tar pipe."""
        f1 = tmp_path / "y.txt"
        f1.write_text("y")
        files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")]

        failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied")
        with patch.object(subprocess, "run", return_value=failed_run):
            with pytest.raises(RuntimeError, match="remote mkdir failed"):
                mock_env._ssh_bulk_upload(files)

    def test_tar_create_failure_raises(self, mock_env, tmp_path):
        """tar create failure should raise RuntimeError."""
        f1 = tmp_path / "z.txt"
        f1.write_text("z")
        files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")]

        mock_tar = _mock_proc(
            returncode=1,
            poll_return=1,
            communicate_return=(b"tar: error", b""),
            stderr_read=b"tar: error",
        )

        mock_ssh = MagicMock()
        mock_ssh.communicate.return_value = (b"", b"")
        mock_ssh.returncode = 0

        def popen_side_effect(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=popen_side_effect):
            with pytest.raises(RuntimeError, match="tar create failed"):
                mock_env._ssh_bulk_upload(files)

    def test_ssh_extract_failure_raises(self, mock_env, tmp_path):
        """SSH tar extract failure should raise RuntimeError."""
        f1 = tmp_path / "w.txt"
        f1.write_text("w")
        files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")]

        mock_tar = _mock_proc()

        mock_ssh = MagicMock()
        mock_ssh.communicate.return_value = (b"", b"Permission denied")
        mock_ssh.returncode = 1

        def popen_side_effect(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=popen_side_effect):
            with pytest.raises(RuntimeError, match="tar extract over SSH failed"):
                mock_env._ssh_bulk_upload(files)

    def test_ssh_command_uses_control_socket(self, mock_env, tmp_path):
        """SSH command for tar extract should reuse ControlMaster socket."""
        f1 = tmp_path / "c.txt"
        f1.write_text("c")
        files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")]

        popen_cmds = []

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            return _mock_proc()

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            mock_env._ssh_bulk_upload(files)

        # The SSH command (second Popen call) should include ControlPath
        ssh_cmd = popen_cmds[1]
        assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd)

    def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path):
        """Bulk upload SSH command should include custom port and key."""
        # Same patching as the mock_env fixture, repeated here because this
        # test needs its own port and key_path.
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
        monkeypatch.setattr(
            ssh_env, "FileSyncManager",
            lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
        )
        env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key")

        f1 = tmp_path / "d.txt"
        f1.write_text("d")
        files = [(str(f1), "/home/u/.hermes/skills/d.txt")]

        run_cmds = []
        popen_cmds = []

        def capture_run(cmd, **kwargs):
            run_cmds.append(cmd)
            return subprocess.CompletedProcess([], 0)

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            return _mock_proc()

        with patch.object(subprocess, "run", side_effect=capture_run), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            env._ssh_bulk_upload(files)

        # Check mkdir SSH call includes port and key
        assert len(run_cmds) == 1
        mkdir_cmd = run_cmds[0]
        assert "-p" in mkdir_cmd and "2222" in mkdir_cmd
        assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd

        # Check tar extract SSH call includes port and key
        ssh_cmd = popen_cmds[1]
        assert "-p" in ssh_cmd and "2222" in ssh_cmd
        assert "-i" in ssh_cmd and "/my/key" in ssh_cmd

    def test_parent_dirs_deduplicated(self, mock_env, tmp_path):
        """Multiple files in the same dir should produce one mkdir entry."""
        f1 = tmp_path / "a.txt"
        f1.write_text("a")
        f2 = tmp_path / "b.txt"
        f2.write_text("b")
        f3 = tmp_path / "c.txt"
        f3.write_text("c")

        files = [
            (str(f1), "/home/testuser/.hermes/skills/a.txt"),
            (str(f2), "/home/testuser/.hermes/skills/b.txt"),
            (str(f3), "/home/testuser/.hermes/credentials/c.txt"),
        ]

        run_cmds = []

        def capture_run(cmd, **kwargs):
            run_cmds.append(cmd)
            return subprocess.CompletedProcess([], 0)

        def make_mock_proc(cmd, **kwargs):
            return _mock_proc()

        with patch.object(subprocess, "run", side_effect=capture_run), \
                patch.object(subprocess, "Popen", side_effect=make_mock_proc):
            mock_env._ssh_bulk_upload(files)

        # Only one mkdir call
        assert len(run_cmds) == 1
        mkdir_str = " ".join(run_cmds[0])
        # skills dir should appear exactly once despite two files
        assert mkdir_str.count("/home/testuser/.hermes/skills") == 1
        assert "/home/testuser/.hermes/credentials" in mkdir_str

    def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path):
        """tar_proc.stdout must be closed so SIGPIPE propagates correctly."""
        f1 = tmp_path / "s.txt"
        f1.write_text("s")
        files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")]

        mock_tar_stdout = MagicMock()

        def make_proc(cmd, **kwargs):
            proc = _mock_proc()
            if cmd[0] == "tar":
                # Hand tar a known stdout object so its close() can be checked.
                proc.stdout = mock_tar_stdout
            return proc

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            mock_env._ssh_bulk_upload(files)

        mock_tar_stdout.close.assert_called_once()

    def test_timeout_kills_both_processes(self, mock_env, tmp_path):
        """TimeoutExpired during communicate should kill both processes."""
        f1 = tmp_path / "t.txt"
        f1.write_text("t")
        files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")]

        # Both processes report "still running" (returncode/poll are None).
        mock_tar = MagicMock()
        mock_tar.stdout = MagicMock()
        mock_tar.returncode = None
        mock_tar.poll.return_value = None

        mock_ssh = MagicMock()
        mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120)
        mock_ssh.returncode = None

        def make_proc(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            with pytest.raises(RuntimeError, match="SSH bulk upload timed out"):
                mock_env._ssh_bulk_upload(files)

        mock_tar.kill.assert_called_once()
        mock_ssh.kill.assert_called_once()
|
||||
|
||||
|
||||
class TestSSHBulkUploadWiring:
    """Verify bulk_upload_fn is wired into FileSyncManager."""

    def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch):
        """SSHEnvironment should pass _ssh_bulk_upload to FileSyncManager."""
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        captured_kwargs = {}

        class FakeSyncManager:
            """Stand-in that records the keyword arguments it is built with."""

            def __init__(self, **kwargs):
                captured_kwargs.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", FakeSyncManager)

        # Constructed only for its side effect of instantiating the
        # (patched) FileSyncManager; the instance itself is not needed.
        SSHEnvironment(host="h", user="u")

        assert "bulk_upload_fn" in captured_kwargs
        assert captured_kwargs["bulk_upload_fn"] is not None
        # Should be the bound method
        assert callable(captured_kwargs["bulk_upload_fn"])
|
||||
|
||||
|
||||
class TestSharedHelpers:
    """Direct unit tests for file_sync.py helpers."""

    def test_quoted_mkdir_command_basic(self):
        """Plain paths are joined after ``mkdir -p`` without quoting."""
        result = quoted_mkdir_command(["/a", "/b/c"])
        assert result == "mkdir -p /a /b/c"

    def test_quoted_mkdir_command_quotes_special_chars(self):
        """Paths with spaces or single quotes must survive shell parsing."""
        import shlex

        dirs = ["/path/with spaces", "/path/'quotes'"]
        result = quoted_mkdir_command(dirs)
        assert "mkdir -p" in result
        # shlex.quote wraps in single quotes
        assert "'/path/with spaces'" in result
        # Round-trip through POSIX shell splitting so the path containing
        # single quotes is verified too, independent of the escaping form.
        assert shlex.split(result) == ["mkdir", "-p", *dirs]

    def test_quoted_mkdir_command_empty(self):
        """An empty list yields the bare prefix, trailing space included."""
        result = quoted_mkdir_command([])
        assert result == "mkdir -p "

    def test_unique_parent_dirs_deduplicates(self):
        """Two files in the same remote directory produce one entry."""
        files = [
            ("/local/a.txt", "/remote/dir/a.txt"),
            ("/local/b.txt", "/remote/dir/b.txt"),
            ("/local/c.txt", "/remote/other/c.txt"),
        ]
        result = unique_parent_dirs(files)
        assert result == ["/remote/dir", "/remote/other"]

    def test_unique_parent_dirs_sorted(self):
        """The result is sorted regardless of input order."""
        files = [
            ("/local/z.txt", "/z/file.txt"),
            ("/local/a.txt", "/a/file.txt"),
        ]
        result = unique_parent_dirs(files)
        assert result == ["/a", "/z"]

    def test_unique_parent_dirs_empty(self):
        """No files means no parent directories."""
        assert unique_parent_dirs([]) == []
|
||||
|
||||
|
||||
class TestSSHBulkUploadEdgeCases:
    """Edge cases for _ssh_bulk_upload."""

    def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path):
        """If SSH Popen raises, tar process must be killed and cleaned up."""
        f1 = tmp_path / "e.txt"
        f1.write_text("e")
        files = [(str(f1), "/home/testuser/.hermes/skills/e.txt")]

        mock_tar = _mock_proc()

        call_count = 0

        def failing_ssh_popen(cmd, **kwargs):
            # First Popen call is treated as tar and succeeds; every later
            # call raises.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return mock_tar  # tar Popen succeeds
            raise OSError("SSH binary not found")

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=failing_ssh_popen):
            with pytest.raises(OSError, match="SSH binary not found"):
                mock_env._ssh_bulk_upload(files)

        mock_tar.kill.assert_called_once()
        mock_tar.wait.assert_called_once()
|
||||
485
tests/tools/test_sync_back_backends.py
Normal file
485
tests/tools/test_sync_back_backends.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""Tests for backend-specific bulk download implementations and cleanup() wiring."""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import ssh as ssh_env
|
||||
from tools.environments import modal as modal_env
|
||||
from tools.environments import daytona as daytona_env
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
|
||||
|
||||
# ── SSH helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def ssh_mock_env(monkeypatch):
    """Create an SSHEnvironment with mocked connection/sync.

    The stand-in sync manager exposes both ``sync`` and ``sync_back`` as
    do-nothing methods.
    """
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
    monkeypatch.setattr(
        ssh_env, "FileSyncManager",
        lambda **kw: type("M", (), {
            "sync": lambda self, **k: None,
            "sync_back": lambda self: None,
        })(),
    )
    return SSHEnvironment(host="example.com", user="testuser")
|
||||
|
||||
|
||||
# ── Modal helpers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mock_modal_env():
    """Return a bare ``ModalEnvironment`` with mocked collaborators.

    ``__init__`` is bypassed via ``object.__new__``; only the attributes the
    tests in this module touch are populated.
    """
    env = object.__new__(modal_env.ModalEnvironment)
    stub_state = (
        ("_sandbox", MagicMock()),
        ("_worker", MagicMock()),
        ("_persistent", False),
        ("_task_id", "test"),
        ("_sync_manager", None),
    )
    for attr_name, attr_value in stub_state:
        setattr(env, attr_name, attr_value)
    return env
|
||||
|
||||
|
||||
def _wire_modal_download(env, *, tar_bytes=b"fake-tar-data", exit_code=0):
|
||||
"""Wire sandbox.exec.aio to return mock tar output for download tests.
|
||||
|
||||
Returns the exec_calls list for assertion.
|
||||
"""
|
||||
exec_calls = []
|
||||
|
||||
async def mock_exec_fn(*args, **kwargs):
|
||||
exec_calls.append(args)
|
||||
proc = MagicMock()
|
||||
proc.stdout = MagicMock()
|
||||
proc.stdout.read = MagicMock()
|
||||
proc.stdout.read.aio = AsyncMock(return_value=tar_bytes)
|
||||
proc.wait = MagicMock()
|
||||
proc.wait.aio = AsyncMock(return_value=exit_code)
|
||||
return proc
|
||||
|
||||
env._sandbox.exec = MagicMock()
|
||||
env._sandbox.exec.aio = mock_exec_fn
|
||||
|
||||
def real_run_coroutine(coro, **kwargs):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
env._worker.run_coroutine = real_run_coroutine
|
||||
return exec_calls
|
||||
|
||||
|
||||
# ── Daytona helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mock_daytona_env():
    """Create a minimal ``DaytonaEnvironment`` without calling ``__init__``.

    The instance is allocated with ``object.__new__`` and given only the
    attributes the tests in this module rely on: a mocked sandbox and
    Daytona client, ``/root`` as remote home, a real lock, and no sync
    manager.
    """
    # Function-scope import instead of ``__import__("threading")``: same
    # effect, but readable and visible to linters/type checkers.
    import threading

    env = object.__new__(daytona_env.DaytonaEnvironment)
    env._sandbox = MagicMock()
    env._remote_home = "/root"
    env._sync_manager = None
    env._lock = threading.Lock()
    env._persistent = True
    env._task_id = "test"
    env._daytona = MagicMock()
    return env
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# SSH bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestSSHBulkDownload:
    """Unit tests for ``_ssh_bulk_download`` with ``subprocess.run`` patched."""

    def test_ssh_bulk_download_runs_tar_over_ssh(self, ssh_mock_env, tmp_path):
        """The command given to ``subprocess.run`` runs ``tar cf -`` over SSH."""
        dest = tmp_path / "backup.tar"
        completed = subprocess.CompletedProcess([], 0)

        with patch.object(subprocess, "run", return_value=completed) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        mock_run.assert_called_once()
        cmd_str = " ".join(mock_run.call_args[0][0])
        assert "tar cf -" in cmd_str
        assert "-C /" in cmd_str
        assert "home/testuser/.hermes" in cmd_str
        assert "ssh" in cmd_str
        assert "testuser@example.com" in cmd_str

    def test_ssh_bulk_download_writes_to_dest(self, ssh_mock_env, tmp_path):
        """The destination file exists once the (mocked) download returns."""
        dest = tmp_path / "backup.tar"
        completed = subprocess.CompletedProcess([], 0)

        with patch.object(subprocess, "run", return_value=completed):
            ssh_mock_env._ssh_bulk_download(dest)

        # subprocess.run is mocked and writes nothing, so the file can only
        # exist if the implementation created it itself -- presumably by
        # opening ``dest`` for writing and passing the handle as ``stdout``.
        assert dest.exists()

    def test_ssh_bulk_download_raises_on_failure(self, ssh_mock_env, tmp_path):
        """A non-zero return code surfaces as ``RuntimeError``."""
        dest = tmp_path / "backup.tar"
        failed = subprocess.CompletedProcess([], 1, stderr=b"Permission denied")

        with patch.object(subprocess, "run", return_value=failed):
            with pytest.raises(RuntimeError, match="SSH bulk download failed"):
                ssh_mock_env._ssh_bulk_download(dest)

    def test_ssh_bulk_download_uses_120s_timeout(self, ssh_mock_env, tmp_path):
        """``subprocess.run`` is invoked with ``timeout=120``."""
        dest = tmp_path / "backup.tar"
        completed = subprocess.CompletedProcess([], 0)

        with patch.object(subprocess, "run", return_value=completed) as mock_run:
            ssh_mock_env._ssh_bulk_download(dest)

        # ``call_args.kwargs`` and ``call_args[1]`` are the same mapping, so a
        # single lookup is sufficient.
        assert mock_run.call_args.kwargs.get("timeout") == 120
|
||||
|
||||
|
||||
class TestSSHCleanup:
    """Verify SSH ``cleanup()`` calls ``sync_back()`` before closing ControlMaster."""

    @staticmethod
    def _make_tracked_env(monkeypatch, call_order):
        """Return an ``SSHEnvironment`` whose sync manager logs into *call_order*.

        Connection, remote-dir and session hooks are stubbed out, and
        ``FileSyncManager`` is replaced by a class whose ``sync_back`` appends
        ``"sync_back"`` to *call_order*.
        """
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        class TrackingSyncManager:
            def __init__(self, **kwargs):
                pass

            def sync(self, **kw):
                pass

            def sync_back(self):
                call_order.append("sync_back")

        monkeypatch.setattr(ssh_env, "FileSyncManager", TrackingSyncManager)
        return SSHEnvironment(host="h", user="u")

    def test_ssh_cleanup_calls_sync_back(self, monkeypatch):
        """``cleanup()`` invokes ``sync_back()`` even when no control socket exists."""
        call_order = []
        env = self._make_tracked_env(monkeypatch, call_order)
        # Point at a path that does not exist so cleanup skips the SSH exit call.
        env.control_socket = Path("/nonexistent/socket")

        env.cleanup()

        assert "sync_back" in call_order

    def test_ssh_cleanup_calls_sync_back_before_control_exit(self, monkeypatch, tmp_path):
        """``sync_back()`` must run before the ControlMaster exit command."""
        call_order = []
        env = self._make_tracked_env(monkeypatch, call_order)

        # Create a fake control socket so cleanup tries the SSH exit.  It lives
        # under ``tmp_path`` so pytest removes it; the previous
        # ``NamedTemporaryFile(delete=False)`` left a stray file per run.
        control_socket = tmp_path / "control.sock"
        control_socket.touch()
        env.control_socket = control_socket

        def mock_run(cmd, **kwargs):
            cmd_str = " ".join(cmd)
            if "-O" in cmd and "exit" in cmd_str:
                call_order.append("control_exit")
            return subprocess.CompletedProcess([], 0)

        with patch.object(subprocess, "run", side_effect=mock_run):
            env.cleanup()

        assert call_order.index("sync_back") < call_order.index("control_exit")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Modal bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestModalBulkDownload:
    """Exercise ``_modal_bulk_download`` against a mocked Modal sandbox."""

    def test_modal_bulk_download_command(self, tmp_path):
        """One ``bash -c`` exec runs ``tar cf -`` with ``-C / root/.hermes``."""
        env = _make_mock_modal_env()
        recorded = _wire_modal_download(env, tar_bytes=b"tar-content")
        target = tmp_path / "backup.tar"

        env._modal_bulk_download(target)

        assert len(recorded) == 1
        exec_args = recorded[0]
        assert exec_args[0] == "bash"
        assert exec_args[1] == "-c"
        script = exec_args[2]
        assert "tar cf -" in script
        assert "-C / root/.hermes" in script

    def test_modal_bulk_download_writes_to_dest(self, tmp_path):
        """The bytes read from the sandbox end up in the destination file."""
        payload = b"some-tar-archive-bytes"
        env = _make_mock_modal_env()
        _wire_modal_download(env, tar_bytes=payload)
        target = tmp_path / "backup.tar"

        env._modal_bulk_download(target)

        assert target.exists()
        assert target.read_bytes() == payload

    def test_modal_bulk_download_handles_str_output(self, tmp_path):
        """``str`` output from stdout is written out as its encoded bytes."""
        env = _make_mock_modal_env()
        # Simulate the Modal SDK handing back text instead of bytes.
        _wire_modal_download(env, tar_bytes="string-tar-data")
        target = tmp_path / "backup.tar"

        env._modal_bulk_download(target)

        assert target.read_bytes() == b"string-tar-data"

    def test_modal_bulk_download_raises_on_failure(self, tmp_path):
        """A non-zero exit code is reported as ``RuntimeError``."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, exit_code=1)
        target = tmp_path / "backup.tar"

        with pytest.raises(RuntimeError, match="Modal bulk download failed"):
            env._modal_bulk_download(target)

    def test_modal_bulk_download_uses_120s_timeout(self, tmp_path):
        """``run_coroutine`` receives ``timeout=120``."""
        env = _make_mock_modal_env()
        _wire_modal_download(env, tar_bytes=b"data")

        seen_kwargs = {}
        delegate = env._worker.run_coroutine

        def spying_run(coro, **kwargs):
            # Record the keyword arguments, then defer to the wired runner.
            seen_kwargs.update(kwargs)
            return delegate(coro, **kwargs)

        env._worker.run_coroutine = spying_run
        target = tmp_path / "backup.tar"

        env._modal_bulk_download(target)

        assert seen_kwargs.get("timeout") == 120
|
||||
|
||||
|
||||
class TestModalCleanup:
    """Verify Modal ``cleanup()`` calls ``sync_back()`` before terminate."""

    def test_modal_cleanup_calls_sync_back(self):
        """``cleanup()`` should call ``sync_back()`` before ``sandbox.terminate``."""
        env = _make_mock_modal_env()

        call_order = []
        sync_mgr = MagicMock()
        sync_mgr.sync_back = lambda: call_order.append("sync_back")
        env._sync_manager = sync_mgr

        async def mock_terminate():
            pass

        def tracking_run(coro, **kw):
            """Record the call, then run *coro* on a loop that is closed afterwards."""
            call_order.append("terminate")
            loop = asyncio.new_event_loop()
            try:
                # Return the coroutine's own result, as a real runner would.
                return loop.run_until_complete(coro)
            finally:
                # The previous one-line lambda never closed its loop.
                loop.close()

        env._sandbox.terminate = MagicMock()
        env._sandbox.terminate.aio = mock_terminate
        env._worker.run_coroutine = tracking_run
        env._worker.stop = lambda: None

        env.cleanup()

        assert "sync_back" in call_order
        assert call_order.index("sync_back") < call_order.index("terminate")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Daytona bulk download
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestDaytonaBulkDownload:
    """Exercise ``_daytona_bulk_download`` against a mocked Daytona sandbox."""

    def test_daytona_bulk_download_creates_tar_and_downloads(self, tmp_path):
        """One exec builds the remote tar; one ``download_file`` fetches it."""
        env = _make_mock_daytona_env()
        target = tmp_path / "backup.tar"

        env._daytona_bulk_download(target)

        process_exec = env._sandbox.process.exec
        process_exec.assert_called_once()
        tar_command = process_exec.call_args[0][0]
        for fragment in ("tar cf", "/tmp/.hermes_sync.tar", ".hermes"):
            assert fragment in tar_command

        env._sandbox.fs.download_file.assert_called_once_with(
            "/tmp/.hermes_sync.tar", str(target)
        )

    def test_daytona_bulk_download_uses_remote_home(self, tmp_path):
        """The tar command is derived from the environment's ``_remote_home``."""
        env = _make_mock_daytona_env()
        env._remote_home = "/home/daytona"
        target = tmp_path / "backup.tar"

        env._daytona_bulk_download(target)

        tar_command = env._sandbox.process.exec.call_args[0][0]
        assert "home/daytona/.hermes" in tar_command
|
||||
|
||||
|
||||
class TestDaytonaCleanup:
    """Check that Daytona ``cleanup()`` syncs back before stopping the sandbox."""

    def test_daytona_cleanup_calls_sync_back(self):
        """``sync_back()`` is recorded ahead of ``sandbox.stop()``."""
        events = []

        def note_sync_back():
            events.append("sync_back")

        def note_stop():
            events.append("stop")

        env = _make_mock_daytona_env()
        manager = MagicMock()
        manager.sync_back = note_sync_back
        env._sync_manager = manager
        env._sandbox.stop = note_stop

        env.cleanup()

        assert "sync_back" in events
        assert "stop" in events
        sync_position = events.index("sync_back")
        stop_position = events.index("stop")
        assert sync_position < stop_position
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# FileSyncManager wiring: bulk_download_fn passed by each backend
|
||||
# =====================================================================
|
||||
|
||||
|
||||
class TestBulkDownloadWiring:
    """Verify each backend passes ``bulk_download_fn`` to ``FileSyncManager``."""

    @staticmethod
    def _capturing_factory(captured_kwargs):
        """Return a ``FileSyncManager`` stand-in that records its keyword args."""

        def capture_fsm(**kwargs):
            captured_kwargs.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        return capture_fsm

    def test_ssh_passes_bulk_download_fn(self, monkeypatch):
        """``SSHEnvironment.__init__`` hands a callable ``bulk_download_fn`` over."""
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        captured_kwargs = {}

        class CaptureSyncManager:
            def __init__(self, **kwargs):
                captured_kwargs.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", CaptureSyncManager)

        SSHEnvironment(host="h", user="u")

        assert "bulk_download_fn" in captured_kwargs
        assert callable(captured_kwargs["bulk_download_fn"])

    def test_modal_passes_bulk_download_fn(self, monkeypatch):
        """A Modal env exposes the callables needed to wire ``bulk_download_fn``.

        NOTE(review): this replicates the wiring from ``__init__`` by hand
        rather than running ``__init__``, so it would not catch the real
        constructor dropping the argument.
        """
        captured_kwargs = {}
        monkeypatch.setattr(
            modal_env, "FileSyncManager", self._capturing_factory(captured_kwargs)
        )

        # Reuse the module helper instead of rebuilding the bare env inline.
        env = _make_mock_modal_env()

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files

        env._sync_manager = modal_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files("/root/.hermes"),
            upload_fn=env._modal_upload,
            delete_fn=env._modal_delete,
            bulk_upload_fn=env._modal_bulk_upload,
            bulk_download_fn=env._modal_bulk_download,
        )

        assert "bulk_download_fn" in captured_kwargs
        assert callable(captured_kwargs["bulk_download_fn"])

    def test_daytona_passes_bulk_download_fn(self, monkeypatch):
        """A Daytona env exposes the callables needed to wire ``bulk_download_fn``.

        NOTE(review): as with the Modal test, the wiring is replicated by
        hand instead of exercising ``__init__``.
        """
        captured_kwargs = {}
        monkeypatch.setattr(
            daytona_env, "FileSyncManager", self._capturing_factory(captured_kwargs)
        )

        # Reuse the module helper; it also replaces the former
        # ``__import__("threading").Lock()`` construction.
        env = _make_mock_daytona_env()

        # Replicate the wiring done in __init__
        from tools.environments.file_sync import iter_sync_files

        env._sync_manager = daytona_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files(f"{env._remote_home}/.hermes"),
            upload_fn=env._daytona_upload,
            delete_fn=env._daytona_delete,
            bulk_upload_fn=env._daytona_bulk_upload,
            bulk_download_fn=env._daytona_bulk_download,
        )

        assert "bulk_download_fn" in captured_kwargs
        assert callable(captured_kwargs["bulk_download_fn"])
|
||||
@@ -15,7 +15,14 @@ from tools.environments.base import (
|
||||
BaseEnvironment,
|
||||
_ThreadedProcessHandle,
|
||||
)
|
||||
from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command
|
||||
from tools.environments.file_sync import (
|
||||
BulkDownloadFn,
|
||||
FileSyncManager,
|
||||
iter_sync_files,
|
||||
quoted_mkdir_command,
|
||||
quoted_rm_command,
|
||||
unique_parent_dirs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -128,6 +135,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
upload_fn=self._daytona_upload,
|
||||
delete_fn=self._daytona_delete,
|
||||
bulk_upload_fn=self._daytona_bulk_upload,
|
||||
bulk_download_fn=self._daytona_bulk_download,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
self.init_session()
|
||||
@@ -150,11 +158,9 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
if not files:
|
||||
return
|
||||
|
||||
# Pre-create all unique parent directories in one shell call
|
||||
parents = sorted({str(Path(remote).parent) for _, remote in files})
|
||||
parents = unique_parent_dirs(files)
|
||||
if parents:
|
||||
mkdir_cmd = "mkdir -p " + " ".join(shlex.quote(p) for p in parents)
|
||||
self._sandbox.process.exec(mkdir_cmd)
|
||||
self._sandbox.process.exec(quoted_mkdir_command(parents))
|
||||
|
||||
uploads = [
|
||||
FileUpload(source=host_path, destination=remote_path)
|
||||
@@ -162,6 +168,14 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
]
|
||||
self._sandbox.fs.upload_files(uploads)
|
||||
|
||||
def _daytona_bulk_download(self, dest: Path) -> None:
|
||||
"""Download remote .hermes/ as a tar archive."""
|
||||
rel_base = f"{self._remote_home}/.hermes".lstrip("/")
|
||||
self._sandbox.process.exec(
|
||||
f"tar cf /tmp/.hermes_sync.tar -C / {shlex.quote(rel_base)}"
|
||||
)
|
||||
self._sandbox.fs.download_file("/tmp/.hermes_sync.tar", str(dest))
|
||||
|
||||
def _daytona_delete(self, remote_paths: list[str]) -> None:
    """Remove *remote_paths* in the sandbox with a single exec call."""
    removal_command = quoted_rm_command(remote_paths)
    self._sandbox.process.exec(removal_command)
|
||||
@@ -209,6 +223,10 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
|
||||
def cleanup(self):
|
||||
if self._sync_manager:
|
||||
logger.info("Daytona: syncing files from sandbox...")
|
||||
self._sync_manager.sync_back()
|
||||
|
||||
with self._lock:
|
||||
if self._sandbox is None:
|
||||
return
|
||||
|
||||
@@ -6,12 +6,21 @@ and Daytona. Docker and Singularity use bind mounts (live host FS
|
||||
view) and don't need this.
|
||||
"""
|
||||
|
||||
import fcntl
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import signal
|
||||
import tarfile
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.environments.base import _file_mtime_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,6 +31,7 @@ _FORCE_SYNC_ENV = "HERMES_FORCE_FILE_SYNC"
|
||||
# Transport callbacks provided by each backend
|
||||
UploadFn = Callable[[str, str], None] # (host_path, remote_path) -> raises on failure
|
||||
BulkUploadFn = Callable[[list[tuple[str, str]]], None] # [(host_path, remote_path), ...] -> raises on failure
|
||||
BulkDownloadFn = Callable[[Path], None] # (dest_tar_path) -> writes tar archive, raises on failure
|
||||
DeleteFn = Callable[[list[str]], None] # (remote_paths) -> raises on failure
|
||||
GetFilesFn = Callable[[], list[tuple[str, str]]] # () -> [(host_path, remote_path), ...]
|
||||
|
||||
@@ -60,6 +70,29 @@ def quoted_rm_command(remote_paths: list[str]) -> str:
|
||||
return "rm -f " + " ".join(shlex.quote(p) for p in remote_paths)
|
||||
|
||||
|
||||
def quoted_mkdir_command(dirs: list[str]) -> str:
    """Return a ``mkdir -p`` shell command with every directory shell-quoted."""
    quoted_dirs = shlex.join(dirs)
    return f"mkdir -p {quoted_dirs}"
|
||||
|
||||
|
||||
def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]:
    """Return the sorted, de-duplicated parent directories of the remote paths.

    *files* holds ``(host_path, remote_path)`` pairs; only the remote side
    is inspected.
    """
    parents: set[str] = set()
    for _host_path, remote_path in files:
        parents.add(str(Path(remote_path).parent))
    return sorted(parents)
|
||||
|
||||
|
||||
def _sha256_file(path: str) -> str:
|
||||
"""Return hex SHA-256 digest of a file."""
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(65536), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
# Total number of attempts sync_back() makes before giving up.
_SYNC_BACK_MAX_RETRIES = 3
# Indexed by the zero-based attempt number; a delay is only taken after a
# failed attempt that is not the last one.
_SYNC_BACK_BACKOFF = (2, 4, 8)  # seconds between retries
|
||||
|
||||
|
||||
class FileSyncManager:
|
||||
"""Tracks local file changes and syncs to a remote environment.
|
||||
|
||||
@@ -78,12 +111,15 @@ class FileSyncManager:
|
||||
delete_fn: DeleteFn,
|
||||
sync_interval: float = _SYNC_INTERVAL_SECONDS,
|
||||
bulk_upload_fn: BulkUploadFn | None = None,
|
||||
bulk_download_fn: BulkDownloadFn | None = None,
|
||||
):
|
||||
self._get_files_fn = get_files_fn
|
||||
self._upload_fn = upload_fn
|
||||
self._bulk_upload_fn = bulk_upload_fn
|
||||
self._bulk_download_fn = bulk_download_fn
|
||||
self._delete_fn = delete_fn
|
||||
self._synced_files: dict[str, tuple[float, int]] = {} # remote_path -> (mtime, size)
|
||||
self._pushed_hashes: dict[str, str] = {} # remote_path -> sha256 hex digest
|
||||
self._last_sync_time: float = 0.0 # monotonic; 0 ensures first sync runs
|
||||
self._sync_interval = sync_interval
|
||||
|
||||
@@ -125,6 +161,7 @@ class FileSyncManager:
|
||||
|
||||
# Snapshot for rollback (only when there's work to do)
|
||||
prev_files = dict(self._synced_files)
|
||||
prev_hashes = dict(self._pushed_hashes)
|
||||
|
||||
if to_upload:
|
||||
logger.debug("file_sync: uploading %d file(s)", len(to_upload))
|
||||
@@ -145,13 +182,176 @@ class FileSyncManager:
|
||||
logger.debug("file_sync: deleted %s", to_delete)
|
||||
|
||||
# --- Commit (all succeeded) ---
|
||||
for host_path, remote_path in to_upload:
|
||||
self._pushed_hashes[remote_path] = _sha256_file(host_path)
|
||||
|
||||
for p in to_delete:
|
||||
new_files.pop(p, None)
|
||||
self._pushed_hashes.pop(p, None)
|
||||
|
||||
self._synced_files = new_files
|
||||
self._last_sync_time = time.monotonic()
|
||||
|
||||
except Exception as exc:
|
||||
self._synced_files = prev_files
|
||||
self._pushed_hashes = prev_hashes
|
||||
self._last_sync_time = time.monotonic()
|
||||
logger.warning("file_sync: sync failed, rolled back state: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sync-back: pull remote changes to host on teardown
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def sync_back(self, hermes_home: Path | None = None) -> None:
|
||||
"""Pull remote changes back to the host filesystem.
|
||||
|
||||
Downloads the remote ``.hermes/`` directory as a tar archive,
|
||||
unpacks it, and applies only files that differ from what was
|
||||
originally pushed (based on SHA-256 content hashes).
|
||||
|
||||
Protected against SIGINT (defers the signal until complete) and
|
||||
serialized across concurrent gateway sandboxes via file lock.
|
||||
"""
|
||||
if self._bulk_download_fn is None:
|
||||
return
|
||||
|
||||
lock_path = (hermes_home or get_hermes_home()) / ".sync.lock"
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(_SYNC_BACK_MAX_RETRIES):
|
||||
try:
|
||||
self._sync_back_once(lock_path)
|
||||
return
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if attempt < _SYNC_BACK_MAX_RETRIES - 1:
|
||||
delay = _SYNC_BACK_BACKOFF[attempt]
|
||||
logger.warning(
|
||||
"sync_back: attempt %d failed (%s), retrying in %ds",
|
||||
attempt + 1, exc, delay,
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
logger.warning("sync_back: all %d attempts failed: %s", _SYNC_BACK_MAX_RETRIES, last_exc)
|
||||
|
||||
def _sync_back_once(self, lock_path: Path) -> None:
|
||||
"""Single sync-back attempt with SIGINT protection and file lock."""
|
||||
# signal.signal() only works from the main thread. In gateway
|
||||
# contexts cleanup() may run from a worker thread — skip SIGINT
|
||||
# deferral there rather than crashing.
|
||||
on_main_thread = threading.current_thread() is threading.main_thread()
|
||||
|
||||
deferred_sigint: list[object] = []
|
||||
original_handler = None
|
||||
if on_main_thread:
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
|
||||
def _defer_sigint(signum, frame):
|
||||
deferred_sigint.append((signum, frame))
|
||||
logger.debug("sync_back: SIGINT deferred until sync completes")
|
||||
|
||||
signal.signal(signal.SIGINT, _defer_sigint)
|
||||
try:
|
||||
self._sync_back_locked(lock_path)
|
||||
finally:
|
||||
if on_main_thread and original_handler is not None:
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
if deferred_sigint:
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
def _sync_back_locked(self, lock_path: Path) -> None:
|
||||
"""Sync-back under file lock (serializes concurrent gateways)."""
|
||||
lock_fd = open(lock_path, "w")
|
||||
try:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_EX)
|
||||
self._sync_back_impl()
|
||||
finally:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_UN)
|
||||
lock_fd.close()
|
||||
|
||||
def _sync_back_impl(self) -> None:
    """Fetch the remote archive, compare it to what was pushed, copy changes.

    The bulk-download callback writes a tar archive to a temporary file,
    which is unpacked (with the ``data`` extraction filter) into a staging
    directory.  Every staged file whose SHA-256 differs from the hash
    recorded at push time is copied onto the host path obtained from
    ``_resolve_host_path`` or, failing that, ``_infer_host_path``.  Files
    without any host mapping are skipped.  If the host copy also changed
    since the push, a warning is logged and the remote version still wins.
    """
    assert self._bulk_download_fn is not None

    with tempfile.NamedTemporaryFile(suffix=".tar") as archive:
        archive_path = archive.name
        self._bulk_download_fn(Path(archive_path))

        with tempfile.TemporaryDirectory(prefix="hermes-sync-back-") as staging_dir:
            with tarfile.open(archive_path) as bundle:
                bundle.extractall(staging_dir, filter="data")

            changed_count = 0
            for root, _subdirs, names in os.walk(staging_dir):
                for name in names:
                    staged = os.path.join(root, name)
                    # Archive members are relative to "/", so prefixing a
                    # slash rebuilds the absolute remote path.
                    remote = "/" + os.path.relpath(staged, staging_dir)

                    staged_digest = _sha256_file(staged)
                    baseline_digest = self._pushed_hashes.get(remote)
                    if staged_digest == baseline_digest:
                        continue

                    # Known files come from the mapping; files created on
                    # the remote fall back to prefix inference.
                    target = self._resolve_host_path(remote)
                    if target is None:
                        target = self._infer_host_path(remote)
                    if target is None:
                        logger.debug(
                            "sync_back: skipping %s (no host mapping)",
                            remote,
                        )
                        continue

                    host_diverged = (
                        baseline_digest is not None
                        and os.path.exists(target)
                        and _sha256_file(target) != baseline_digest
                    )
                    if host_diverged:
                        logger.warning(
                            "sync_back: conflict on %s — host modified "
                            "since push, remote also changed. Applying "
                            "remote version (last-write-wins).",
                            remote,
                        )

                    os.makedirs(os.path.dirname(target), exist_ok=True)
                    shutil.copy2(staged, target)
                    changed_count += 1

            if not changed_count:
                logger.debug("sync_back: no remote changes detected")
            else:
                logger.info("sync_back: applied %d changed file(s)", changed_count)
|
||||
|
||||
def _resolve_host_path(self, remote_path: str) -> str | None:
|
||||
"""Find the host path for a known remote path from the file mapping."""
|
||||
try:
|
||||
for host, remote in self._get_files_fn():
|
||||
if remote == remote_path:
|
||||
return host
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _infer_host_path(self, remote_path: str) -> str | None:
|
||||
"""Infer a host path for a new remote file by matching path prefixes.
|
||||
|
||||
Uses the existing file mapping to find a remote->host directory
|
||||
pair, then applies the same prefix substitution to the new file.
|
||||
For example, if the mapping has ``/root/.hermes/skills/a.md`` →
|
||||
``~/.hermes/skills/a.md``, a new remote file at
|
||||
``/root/.hermes/skills/b.md`` maps to ``~/.hermes/skills/b.md``.
|
||||
"""
|
||||
try:
|
||||
for host, remote in self._get_files_fn():
|
||||
remote_dir = str(Path(remote).parent)
|
||||
if remote_path.startswith(remote_dir + "/"):
|
||||
host_dir = str(Path(host).parent)
|
||||
suffix = remote_path[len(remote_dir):]
|
||||
return host_dir + suffix
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
@@ -5,8 +5,11 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import shlex
|
||||
import tarfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
@@ -18,7 +21,14 @@ from tools.environments.base import (
|
||||
_load_json_store,
|
||||
_save_json_store,
|
||||
)
|
||||
from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command
|
||||
from tools.environments.file_sync import (
|
||||
BulkDownloadFn,
|
||||
FileSyncManager,
|
||||
iter_sync_files,
|
||||
quoted_mkdir_command,
|
||||
quoted_rm_command,
|
||||
unique_parent_dirs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -259,26 +269,102 @@ class ModalEnvironment(BaseEnvironment):
|
||||
get_files_fn=lambda: iter_sync_files("/root/.hermes"),
|
||||
upload_fn=self._modal_upload,
|
||||
delete_fn=self._modal_delete,
|
||||
bulk_upload_fn=self._modal_bulk_upload,
|
||||
bulk_download_fn=self._modal_bulk_download,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
self.init_session()
|
||||
|
||||
def _modal_upload(self, host_path: str, remote_path: str) -> None:
    """Upload a single file via base64 piped through stdin.

    The host file is base64-encoded and streamed into a remote
    ``mkdir -p <parent> && base64 -d > <remote_path>`` command in
    ``_STDIN_CHUNK_SIZE`` pieces, so the payload never travels in the
    exec argument list.

    Raises:
        RuntimeError: if the remote command exits with a non-zero status.
    """
    encoded = base64.b64encode(Path(host_path).read_bytes()).decode("ascii")
    container_dir = str(Path(remote_path).parent)
    cmd = (
        f"mkdir -p {shlex.quote(container_dir)} && "
        f"base64 -d > {shlex.quote(remote_path)}"
    )

    async def _write():
        proc = await self._sandbox.exec.aio("bash", "-c", cmd)

        # Write in bounded chunks and flush each one so a single write
        # never exceeds the SDK's stdin buffer limit.
        chunk_size = self._STDIN_CHUNK_SIZE
        for offset in range(0, len(encoded), chunk_size):
            proc.stdin.write(encoded[offset:offset + chunk_size])
            await proc.stdin.drain.aio()

        proc.stdin.write_eof()
        await proc.stdin.drain.aio()

        # Surface remote failures (unwritable path, full disk, ...) instead
        # of silently reporting success, matching the bulk-upload path.
        exit_code = await proc.wait.aio()
        if exit_code != 0:
            stderr_text = await proc.stderr.read.aio()
            raise RuntimeError(
                f"Modal upload failed (exit {exit_code}): {stderr_text}"
            )

    self._worker.run_coroutine(_write(), timeout=30)
|
||||
|
||||
# Size of each stdin write when streaming a payload into a sandbox process.
# The Modal SDK limits the stdin buffer: 2 MB on the legacy server path,
# 16 MB on the command-router path. We must stay under the smaller 2 MB cap
# for compatibility, so chunks are written below this threshold and flushed
# individually via drain().
_STDIN_CHUNK_SIZE = 1 * 1024 * 1024  # 1 MB — safe for both transport paths
|
||||
|
||||
def _modal_bulk_upload(self, files: list[tuple[str, str]]) -> None:
    """Upload many files via tar archive piped through stdin.

    Builds a gzipped tar archive in memory and streams it into a
    ``base64 -d | tar xzf -`` pipeline via the process's stdin,
    avoiding the Modal SDK's 64 KB ``ARG_MAX_BYTES`` exec-arg limit.

    Args:
        files: ``(host_path, remote_path)`` pairs. Each remote path is
            stored in the archive relative to ``/`` and extracted there.

    Raises:
        RuntimeError: if the remote pipeline exits with a non-zero status.
    """
    if not files:
        return

    buf = io.BytesIO()
    # dereference=True archives the *content* of symlinked host files rather
    # than a link pointing at a host-only path, matching the single-file
    # upload (read_bytes) and the SSH bulk upload (tar -h).
    with tarfile.open(fileobj=buf, mode="w:gz", dereference=True) as tar:
        for host_path, remote_path in files:
            tar.add(host_path, arcname=remote_path.lstrip("/"))
    payload = base64.b64encode(buf.getvalue()).decode("ascii")

    extract_cmd = "base64 -d | tar xzf - -C /"
    parents = unique_parent_dirs(files)
    # Only prepend mkdir when there is something to create, as the SSH
    # bulk upload does.
    if parents:
        cmd = f"{quoted_mkdir_command(parents)} && {extract_cmd}"
    else:
        cmd = extract_cmd

    async def _bulk():
        proc = await self._sandbox.exec.aio("bash", "-c", cmd)

        # Stream payload through stdin in chunks to stay under the
        # SDK's per-write buffer limit (2 MB legacy / 16 MB router).
        chunk_size = self._STDIN_CHUNK_SIZE
        for offset in range(0, len(payload), chunk_size):
            proc.stdin.write(payload[offset:offset + chunk_size])
            await proc.stdin.drain.aio()

        proc.stdin.write_eof()
        await proc.stdin.drain.aio()

        exit_code = await proc.wait.aio()
        if exit_code != 0:
            stderr_text = await proc.stderr.read.aio()
            raise RuntimeError(
                f"Modal bulk upload failed (exit {exit_code}): {stderr_text}"
            )

    self._worker.run_coroutine(_bulk(), timeout=120)
|
||||
|
||||
def _modal_bulk_download(self, dest: Path) -> None:
    """Download remote .hermes/ as a tar archive.

    Runs ``tar cf - -C / root/.hermes`` in the sandbox and writes the
    captured stdout to *dest*.

    Raises:
        RuntimeError: if the remote ``tar`` exits with a non-zero status;
            the message carries the exit code and the remote stderr.
    """
    async def _download():
        proc = await self._sandbox.exec.aio(
            "bash", "-c", "tar cf - -C / root/.hermes"
        )
        data = await proc.stdout.read.aio()
        exit_code = await proc.wait.aio()
        if exit_code != 0:
            # Include stderr so the failure is diagnosable, as the bulk
            # upload and the SSH download already do.
            stderr_text = await proc.stderr.read.aio()
            raise RuntimeError(
                f"Modal bulk download failed (exit {exit_code}): {stderr_text}"
            )
        return data

    tar_bytes = self._worker.run_coroutine(_download(), timeout=120)
    if isinstance(tar_bytes, str):
        # NOTE(review): a str here means stdout was decoded as text; an
        # archive holding non-text file content may not round-trip through
        # that decode — confirm against the SDK's stream mode.
        tar_bytes = tar_bytes.encode()
    dest.write_bytes(tar_bytes)
|
||||
|
||||
def _modal_delete(self, remote_paths: list[str]) -> None:
|
||||
"""Batch-delete remote files via exec."""
|
||||
@@ -337,6 +423,10 @@ class ModalEnvironment(BaseEnvironment):
|
||||
if self._sandbox is None:
|
||||
return
|
||||
|
||||
if self._sync_manager:
|
||||
logger.info("Modal: syncing files from sandbox...")
|
||||
self._sync_manager.sync_back()
|
||||
|
||||
if self._persistent:
|
||||
try:
|
||||
async def _snapshot():
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""SSH remote execution environment with ControlMaster connection persistence."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
@@ -8,7 +9,14 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||
from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command
|
||||
from tools.environments.file_sync import (
|
||||
BulkDownloadFn,
|
||||
FileSyncManager,
|
||||
iter_sync_files,
|
||||
quoted_mkdir_command,
|
||||
quoted_rm_command,
|
||||
unique_parent_dirs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,6 +58,8 @@ class SSHEnvironment(BaseEnvironment):
|
||||
get_files_fn=lambda: iter_sync_files(f"{self._remote_home}/.hermes"),
|
||||
upload_fn=self._scp_upload,
|
||||
delete_fn=self._ssh_delete,
|
||||
bulk_upload_fn=self._ssh_bulk_upload,
|
||||
bulk_download_fn=self._ssh_bulk_download,
|
||||
)
|
||||
self._sync_manager.sync(force=True)
|
||||
|
||||
@@ -107,9 +117,8 @@ class SSHEnvironment(BaseEnvironment):
|
||||
"""Create base ~/.hermes directory tree on remote in one SSH call."""
|
||||
base = f"{self._remote_home}/.hermes"
|
||||
dirs = [base, f"{base}/skills", f"{base}/credentials", f"{base}/cache"]
|
||||
mkdir_cmd = "mkdir -p " + " ".join(shlex.quote(d) for d in dirs)
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(mkdir_cmd)
|
||||
cmd.append(quoted_mkdir_command(dirs))
|
||||
subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
|
||||
# _get_sync_files provided via iter_sync_files in FileSyncManager init
|
||||
@@ -131,6 +140,96 @@ class SSHEnvironment(BaseEnvironment):
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"scp failed: {result.stderr.strip()}")
|
||||
|
||||
def _ssh_bulk_upload(self, files: list[tuple[str, str]]) -> None:
    """Upload many files in a single tar-over-SSH stream.

    Pipes ``tar c`` on the local side through an SSH connection to
    ``tar x`` on the remote, transferring all files in one TCP stream
    instead of spawning a subprocess per file. Directory creation is
    batched into a single ``mkdir -p`` call beforehand.

    Typical improvement: ~580 files go from O(N) scp round-trips
    to a single streaming transfer.

    Args:
        files: ``(host_path, remote_path)`` pairs to transfer.

    Raises:
        RuntimeError: if the remote ``mkdir`` fails, the transfer times
            out, or either side of the tar pipe exits non-zero.
    """
    if not files:
        return

    parents = unique_parent_dirs(files)
    if parents:
        mkdir_cmd = self._build_ssh_command()
        mkdir_cmd.append(quoted_mkdir_command(parents))
        result = subprocess.run(mkdir_cmd, capture_output=True, text=True, timeout=30)
        if result.returncode != 0:
            raise RuntimeError(f"remote mkdir failed: {result.stderr.strip()}")

    # Symlink staging avoids fragile GNU tar --transform rules.
    with tempfile.TemporaryDirectory(prefix="hermes-ssh-bulk-") as staging:
        for host_path, remote_path in files:
            staged = os.path.join(staging, remote_path.lstrip("/"))
            os.makedirs(os.path.dirname(staged), exist_ok=True)
            os.symlink(os.path.abspath(host_path), staged)

        # -h dereferences the staged symlinks so file contents are archived.
        tar_cmd = ["tar", "-chf", "-", "-C", staging, "."]
        ssh_cmd = self._build_ssh_command()
        ssh_cmd.append("tar xf - -C /")

        # tar's stderr goes to a temp file rather than a pipe: nothing reads
        # it while we wait on ssh, so a pipe could fill up and stall tar
        # (and with it the whole transfer).
        with tempfile.TemporaryFile() as tar_stderr:
            tar_proc = subprocess.Popen(
                tar_cmd, stdout=subprocess.PIPE, stderr=tar_stderr
            )
            try:
                ssh_proc = subprocess.Popen(
                    ssh_cmd, stdin=tar_proc.stdout, stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
            except Exception:
                tar_proc.kill()
                tar_proc.wait()
                raise

            # Allow tar_proc to receive SIGPIPE if ssh_proc exits early
            tar_proc.stdout.close()

            try:
                _, ssh_stderr = ssh_proc.communicate(timeout=120)
                tar_proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                tar_proc.kill()
                ssh_proc.kill()
                tar_proc.wait()
                ssh_proc.wait()
                raise RuntimeError("SSH bulk upload timed out")

            tar_stderr.seek(0)
            tar_error_text = tar_stderr.read().decode(errors="replace").strip()

        # Report every failing side. When ssh dies first, local tar is
        # usually killed by the closed pipe; listing the ssh error first
        # keeps the root cause visible instead of a bare "tar create failed".
        failures = []
        if ssh_proc.returncode != 0:
            failures.append(
                f"tar extract over SSH failed (rc={ssh_proc.returncode}): "
                f"{ssh_stderr.decode(errors='replace').strip()}"
            )
        if tar_proc.returncode != 0:
            failures.append(
                f"tar create failed (rc={tar_proc.returncode}): "
                f"{tar_error_text}"
            )
        if failures:
            raise RuntimeError("; ".join(failures))

    logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files))
|
||||
|
||||
def _ssh_bulk_download(self, dest: Path) -> None:
    """Fetch the remote ``.hermes`` tree into *dest* as a tar archive.

    Runs ``tar cf -`` on the remote host over SSH and streams its stdout
    straight into the file at *dest*.

    Raises:
        RuntimeError: if the SSH/tar command exits with a non-zero status.
    """
    # Archive relative to / using the full path, so entries keep their
    # absolute layout (e.g. home/user/.hermes/skills/f.py) and line up with
    # the _pushed_hashes keys.
    remote_root = f"{self._remote_home}/.hermes"
    archive_member = remote_root.lstrip("/")

    command = self._build_ssh_command()
    command.append(f"tar cf - -C / {shlex.quote(archive_member)}")

    with open(dest, "wb") as archive:
        completed = subprocess.run(
            command, stdout=archive, stderr=subprocess.PIPE, timeout=120
        )

    if completed.returncode != 0:
        raise RuntimeError(f"SSH bulk download failed: {completed.stderr.decode(errors='replace').strip()}")
|
||||
|
||||
def _ssh_delete(self, remote_paths: list[str]) -> None:
|
||||
"""Batch-delete remote files in one SSH call."""
|
||||
cmd = self._build_ssh_command()
|
||||
@@ -160,6 +259,10 @@ class SSHEnvironment(BaseEnvironment):
|
||||
return _popen_bash(cmd, stdin_data)
|
||||
|
||||
def cleanup(self):
|
||||
if self._sync_manager:
|
||||
logger.info("SSH: syncing files from sandbox...")
|
||||
self._sync_manager.sync_back()
|
||||
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
|
||||
Reference in New Issue
Block a user