Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 499c7346c5 | |||
| eb2127c1dc | |||
| 5a1e2a307a | |||
| 41d9d08078 | |||
| b7bcae49c6 | |||
| 915df02bbf | |||
| 75fcbc44ce | |||
| be416cdfa9 | |||
| b8b1f24fd7 | |||
| a2847ea7f0 | |||
| 58ca875e19 | |||
| 3f95e741a7 | |||
| 03396627a6 | |||
| 22cfad157b | |||
| 867eefdd9f | |||
| a8df7f9964 | |||
| 1519c4d477 | |||
| 005786c55d | |||
| ad764d3513 | |||
| f008ee1019 | |||
| 60fdb58ce4 | |||
| 18d28c63a7 | |||
| 3c57eaf744 | |||
| 2d232c9991 | |||
| 0375b2a0d7 | |||
| 08fa326bb0 | |||
| bde45f5a2a | |||
| 716e616d28 | |||
| bdccdd67a1 | |||
| 148f46620f |
@@ -59,6 +59,7 @@ _OAUTH_ONLY_BETAS = [
|
||||
# The version must stay reasonably current — Anthropic rejects OAuth requests
|
||||
# when the spoofed user-agent version is too far behind the actual release.
|
||||
_CLAUDE_CODE_VERSION_FALLBACK = "2.1.74"
|
||||
_claude_code_version_cache: Optional[str] = None
|
||||
|
||||
|
||||
def _detect_claude_code_version() -> str:
|
||||
@@ -86,11 +87,18 @@ def _detect_claude_code_version() -> str:
|
||||
return _CLAUDE_CODE_VERSION_FALLBACK
|
||||
|
||||
|
||||
_CLAUDE_CODE_VERSION = _detect_claude_code_version()
|
||||
_CLAUDE_CODE_SYSTEM_PREFIX = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
_MCP_TOOL_PREFIX = "mcp_"
|
||||
|
||||
|
||||
def _get_claude_code_version() -> str:
|
||||
"""Lazily detect the installed Claude Code version when OAuth headers need it."""
|
||||
global _claude_code_version_cache
|
||||
if _claude_code_version_cache is None:
|
||||
_claude_code_version_cache = _detect_claude_code_version()
|
||||
return _claude_code_version_cache
|
||||
|
||||
|
||||
def _is_oauth_token(key: str) -> bool:
|
||||
"""Check if the key is an OAuth/setup token (not a regular Console API key).
|
||||
|
||||
@@ -132,7 +140,7 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
kwargs["auth_token"] = api_key
|
||||
kwargs["default_headers"] = {
|
||||
"anthropic-beta": ",".join(all_betas),
|
||||
"user-agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
"user-agent": f"claude-cli/{_get_claude_code_version()} (external, cli)",
|
||||
"x-app": "cli",
|
||||
}
|
||||
else:
|
||||
@@ -241,7 +249,7 @@ def _refresh_oauth_token(creds: Dict[str, Any]) -> Optional[str]:
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
"User-Agent": f"claude-cli/{_get_claude_code_version()} (external, cli)",
|
||||
}
|
||||
|
||||
for endpoint in token_endpoints:
|
||||
@@ -706,14 +714,21 @@ def convert_messages_to_anthropic(
|
||||
result.append({"role": "user", "content": [tool_result]})
|
||||
continue
|
||||
|
||||
# Regular user message
|
||||
# Regular user message — validate non-empty content (Anthropic rejects empty)
|
||||
if isinstance(content, list):
|
||||
converted_blocks = _convert_content_to_anthropic(content)
|
||||
result.append({
|
||||
"role": "user",
|
||||
"content": converted_blocks or [{"type": "text", "text": ""}],
|
||||
})
|
||||
# Check if all text blocks are empty
|
||||
if not converted_blocks or all(
|
||||
b.get("text", "").strip() == ""
|
||||
for b in converted_blocks
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
):
|
||||
converted_blocks = [{"type": "text", "text": "(empty message)"}]
|
||||
result.append({"role": "user", "content": converted_blocks})
|
||||
else:
|
||||
# Validate string content is non-empty
|
||||
if not content or (isinstance(content, str) and not content.strip()):
|
||||
content = "(empty message)"
|
||||
result.append({"role": "user", "content": content})
|
||||
|
||||
# Strip orphaned tool_use blocks (no matching tool_result follows)
|
||||
|
||||
@@ -693,7 +693,13 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
is_oauth = _is_oauth_token(token)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get("anthropic", "claude-haiku-4-5-20251001")
|
||||
logger.debug("Auxiliary client: Anthropic native (%s) at %s (oauth=%s)", model, base_url, is_oauth)
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
try:
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
except ImportError:
|
||||
# The anthropic_adapter module imports fine but the SDK itself is
|
||||
# missing — build_anthropic_client raises ImportError at call time
|
||||
# when _anthropic_sdk is None. Treat as unavailable.
|
||||
return None, None
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
|
||||
|
||||
|
||||
@@ -1131,7 +1137,13 @@ def resolve_vision_provider_client(
|
||||
return "custom", client, final_model
|
||||
|
||||
if requested == "auto":
|
||||
for candidate in get_available_vision_backends():
|
||||
ordered = list(_VISION_AUTO_PROVIDER_ORDER)
|
||||
preferred = _preferred_main_vision_provider()
|
||||
if preferred in ordered:
|
||||
ordered.remove(preferred)
|
||||
ordered.insert(0, preferred)
|
||||
|
||||
for candidate in ordered:
|
||||
sync_client, default_model = _resolve_strict_vision_backend(candidate)
|
||||
if sync_client is not None:
|
||||
return _finalize(candidate, sync_client, default_model)
|
||||
|
||||
+16
-2
@@ -231,7 +231,7 @@ class KawaiiSpinner:
|
||||
"analyzing", "computing", "synthesizing", "formulating", "brainstorming",
|
||||
]
|
||||
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots'):
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots', print_fn=None):
|
||||
self.message = message
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
|
||||
self.running = False
|
||||
@@ -239,12 +239,26 @@ class KawaiiSpinner:
|
||||
self.frame_idx = 0
|
||||
self.start_time = None
|
||||
self.last_line_len = 0
|
||||
# Optional callable to route all output through (e.g. a no-op for silent
|
||||
# background agents). When set, bypasses self._out entirely so that
|
||||
# agents with _print_fn overridden remain fully silent.
|
||||
self._print_fn = print_fn
|
||||
# Capture stdout NOW, before any redirect_stdout(devnull) from
|
||||
# child agents can replace sys.stdout with a black hole.
|
||||
self._out = sys.stdout
|
||||
|
||||
def _write(self, text: str, end: str = '\n', flush: bool = False):
|
||||
"""Write to the stdout captured at spinner creation time."""
|
||||
"""Write to the stdout captured at spinner creation time.
|
||||
|
||||
If a print_fn was supplied at construction, all output is routed through
|
||||
it instead — allowing callers to silence the spinner with a no-op lambda.
|
||||
"""
|
||||
if self._print_fn is not None:
|
||||
try:
|
||||
self._print_fn(text)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
try:
|
||||
self._out.write(text + end)
|
||||
if flush:
|
||||
|
||||
@@ -688,6 +688,12 @@ display:
|
||||
# Toggle at runtime with /verbose in the CLI
|
||||
tool_progress: all
|
||||
|
||||
# What Enter does when Hermes is already busy in the CLI.
|
||||
# interrupt: Interrupt the current run and redirect Hermes (default)
|
||||
# queue: Queue your message for the next turn
|
||||
# Ctrl+C always interrupts regardless of this setting.
|
||||
busy_input_mode: interrupt
|
||||
|
||||
# Background process notifications (gateway/messaging only).
|
||||
# Controls how chatty the process watcher is when you use
|
||||
# terminal(background=true, check_interval=...) from Telegram/Discord/etc.
|
||||
|
||||
@@ -205,6 +205,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"resume_display": "full",
|
||||
"show_reasoning": False,
|
||||
"streaming": True,
|
||||
"busy_input_mode": "interrupt",
|
||||
|
||||
"skin": "default",
|
||||
},
|
||||
@@ -1035,13 +1036,18 @@ class HermesCLI:
|
||||
self.config = CLI_CONFIG
|
||||
self.compact = compact if compact is not None else CLI_CONFIG["display"].get("compact", False)
|
||||
# tool_progress: "off", "new", "all", "verbose" (from config.yaml display section)
|
||||
self.tool_progress_mode = CLI_CONFIG["display"].get("tool_progress", "all")
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise to string.
|
||||
_raw_tp = CLI_CONFIG["display"].get("tool_progress", "all")
|
||||
self.tool_progress_mode = "off" if _raw_tp is False else str(_raw_tp)
|
||||
# resume_display: "full" (show history) | "minimal" (one-liner only)
|
||||
self.resume_display = CLI_CONFIG["display"].get("resume_display", "full")
|
||||
# bell_on_complete: play terminal bell (\a) when agent finishes a response
|
||||
self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False)
|
||||
# show_reasoning: display model thinking/reasoning before the response
|
||||
self.show_reasoning = CLI_CONFIG["display"].get("show_reasoning", False)
|
||||
# busy_input_mode: "interrupt" (Enter interrupts current run) or "queue" (Enter queues for next turn)
|
||||
_bim = CLI_CONFIG["display"].get("busy_input_mode", "interrupt")
|
||||
self.busy_input_mode = "queue" if str(_bim).strip().lower() == "queue" else "interrupt"
|
||||
|
||||
self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose")
|
||||
|
||||
@@ -1329,7 +1335,12 @@ class HermesCLI:
|
||||
def _build_status_bar_text(self, width: Optional[int] = None) -> str:
|
||||
try:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
width = width or shutil.get_terminal_size((80, 24)).columns
|
||||
if width is None:
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
width = get_app().output.get_size().columns
|
||||
except Exception:
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
percent = snapshot["context_percent"]
|
||||
percent_label = f"{percent}%" if percent is not None else "--"
|
||||
duration_label = snapshot["duration"]
|
||||
@@ -1359,7 +1370,16 @@ class HermesCLI:
|
||||
return []
|
||||
try:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
# Use prompt_toolkit's own terminal width when running inside the
|
||||
# TUI — shutil.get_terminal_size() can return stale or fallback
|
||||
# values (especially on SSH) that differ from what prompt_toolkit
|
||||
# actually renders, causing the fragments to overflow to a second
|
||||
# line and produce duplicated status bar rows over long sessions.
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
width = get_app().output.get_size().columns
|
||||
except Exception:
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
duration_label = snapshot["duration"]
|
||||
|
||||
if width < 52:
|
||||
@@ -1594,6 +1614,7 @@ class HermesCLI:
|
||||
if not text:
|
||||
return
|
||||
self._reasoning_stream_started = True
|
||||
self._reasoning_shown_this_turn = True
|
||||
if getattr(self, "_stream_box_opened", False):
|
||||
return
|
||||
|
||||
@@ -2929,6 +2950,82 @@ class HermesCLI:
|
||||
if not silent:
|
||||
print("(^_^)v New session started!")
|
||||
|
||||
def _handle_resume_command(self, cmd_original: str) -> None:
|
||||
"""Handle /resume <session_id_or_title> — switch to a previous session mid-conversation."""
|
||||
parts = cmd_original.split(None, 1)
|
||||
target = parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
if not target:
|
||||
_cprint(" Usage: /resume <session_id_or_title>")
|
||||
_cprint(" Tip: Use /history or `hermes sessions list` to find sessions.")
|
||||
return
|
||||
|
||||
if not self._session_db:
|
||||
_cprint(" Session database not available.")
|
||||
return
|
||||
|
||||
# Resolve title or ID
|
||||
from hermes_cli.main import _resolve_session_by_name_or_id
|
||||
resolved = _resolve_session_by_name_or_id(target)
|
||||
target_id = resolved or target
|
||||
|
||||
session_meta = self._session_db.get_session(target_id)
|
||||
if not session_meta:
|
||||
_cprint(f" Session not found: {target}")
|
||||
_cprint(" Use /history or `hermes sessions list` to see available sessions.")
|
||||
return
|
||||
|
||||
if target_id == self.session_id:
|
||||
_cprint(" Already on that session.")
|
||||
return
|
||||
|
||||
# End current session
|
||||
try:
|
||||
self._session_db.end_session(self.session_id, "resumed_other")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Switch to the target session
|
||||
self.session_id = target_id
|
||||
self._resumed = True
|
||||
self._pending_title = None
|
||||
|
||||
# Load conversation history
|
||||
restored = self._session_db.get_messages_as_conversation(target_id)
|
||||
self.conversation_history = restored or []
|
||||
|
||||
# Re-open the target session so it's not marked as ended
|
||||
try:
|
||||
self._session_db.reopen_session(target_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Sync the agent if already initialised
|
||||
if self.agent:
|
||||
self.agent.session_id = target_id
|
||||
self.agent.reset_session_state()
|
||||
if hasattr(self.agent, "_last_flushed_db_idx"):
|
||||
self.agent._last_flushed_db_idx = len(self.conversation_history)
|
||||
if hasattr(self.agent, "_todo_store"):
|
||||
try:
|
||||
from tools.todo_tool import TodoStore
|
||||
self.agent._todo_store = TodoStore()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(self.agent, "_invalidate_system_prompt"):
|
||||
self.agent._invalidate_system_prompt()
|
||||
|
||||
title_part = f" \"{session_meta['title']}\"" if session_meta.get("title") else ""
|
||||
msg_count = len([m for m in self.conversation_history if m.get("role") == "user"])
|
||||
if self.conversation_history:
|
||||
_cprint(
|
||||
f" ↻ Resumed session {target_id}{title_part}"
|
||||
f" ({msg_count} user message{'s' if msg_count != 1 else ''},"
|
||||
f" {len(self.conversation_history)} total)"
|
||||
)
|
||||
else:
|
||||
_cprint(f" ↻ Resumed session {target_id}{title_part} — no messages, starting fresh.")
|
||||
|
||||
def reset_conversation(self):
|
||||
"""Reset the conversation by starting a new session."""
|
||||
self.new_session()
|
||||
@@ -3647,6 +3744,8 @@ class HermesCLI:
|
||||
_cprint(" Session database not available.")
|
||||
elif canonical == "new":
|
||||
self.new_session()
|
||||
elif canonical == "resume":
|
||||
self._handle_resume_command(cmd_original)
|
||||
elif canonical == "provider":
|
||||
self._show_model_and_providers()
|
||||
elif canonical == "prompt":
|
||||
@@ -3722,17 +3821,17 @@ class HermesCLI:
|
||||
elif canonical == "background":
|
||||
self._handle_background_command(cmd_original)
|
||||
elif canonical == "queue":
|
||||
if not self._agent_running:
|
||||
_cprint(" /queue only works while Hermes is busy. Just type your message normally.")
|
||||
# Extract prompt after "/queue " or "/q "
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /queue <prompt>")
|
||||
else:
|
||||
# Extract prompt after "/queue " or "/q "
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /queue <prompt>")
|
||||
else:
|
||||
self._pending_input.put(payload)
|
||||
self._pending_input.put(payload)
|
||||
if self._agent_running:
|
||||
_cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
else:
|
||||
_cprint(f" Queued: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "skin":
|
||||
self._handle_skin_command(cmd_original)
|
||||
elif canonical == "voice":
|
||||
@@ -5436,6 +5535,10 @@ class HermesCLI:
|
||||
|
||||
# Reset streaming display state for this turn
|
||||
self._reset_stream_state()
|
||||
# Separate from _reset_stream_state because this must persist
|
||||
# across intermediate turn boundaries (tool-calling loops) — only
|
||||
# reset at the start of each user turn.
|
||||
self._reasoning_shown_this_turn = False
|
||||
|
||||
# --- Streaming TTS setup ---
|
||||
# When ElevenLabs is the TTS provider and sounddevice is available,
|
||||
@@ -5640,8 +5743,13 @@ class HermesCLI:
|
||||
response_previewed = result.get("response_previewed", False) if result else False
|
||||
|
||||
# Display reasoning (thinking) box if enabled and available.
|
||||
# Skip when streaming already showed reasoning live.
|
||||
if self.show_reasoning and result and not self._reasoning_stream_started:
|
||||
# Skip when streaming already showed reasoning live. Use the
|
||||
# turn-persistent flag (_reasoning_shown_this_turn) instead of
|
||||
# _reasoning_stream_started — the latter gets reset during
|
||||
# intermediate turn boundaries (tool-calling loops), which caused
|
||||
# the reasoning box to re-render after the final response.
|
||||
_reasoning_already_shown = getattr(self, '_reasoning_shown_this_turn', False)
|
||||
if self.show_reasoning and result and not _reasoning_already_shown:
|
||||
reasoning = result.get("last_reasoning")
|
||||
if reasoning:
|
||||
w = shutil.get_terminal_size().columns
|
||||
@@ -6112,16 +6220,22 @@ class HermesCLI:
|
||||
# Bundle text + images as a tuple when images are present
|
||||
payload = (text, images) if images else text
|
||||
if self._agent_running and not (text and text.startswith("/")):
|
||||
self._interrupt_queue.put(payload)
|
||||
# Debug: log to file when message enters interrupt queue
|
||||
try:
|
||||
_dbg = _hermes_home / "interrupt_debug.log"
|
||||
with open(_dbg, "a") as _f:
|
||||
import time as _t
|
||||
_f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, "
|
||||
f"agent_running={self._agent_running}\n")
|
||||
except Exception:
|
||||
pass
|
||||
if self.busy_input_mode == "queue":
|
||||
# Queue for the next turn instead of interrupting
|
||||
self._pending_input.put(payload)
|
||||
preview = text if text else f"[{len(images)} image{'s' if len(images) != 1 else ''} attached]"
|
||||
_cprint(f" Queued for the next turn: {preview[:80]}{'...' if len(preview) > 80 else ''}")
|
||||
else:
|
||||
self._interrupt_queue.put(payload)
|
||||
# Debug: log to file when message enters interrupt queue
|
||||
try:
|
||||
_dbg = _hermes_home / "interrupt_debug.log"
|
||||
with open(_dbg, "a") as _f:
|
||||
import time as _t
|
||||
_f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, "
|
||||
f"agent_running={self._agent_running}\n")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
self._pending_input.put(payload)
|
||||
event.app.current_buffer.reset(append_to_history=True)
|
||||
@@ -6894,6 +7008,15 @@ class HermesCLI:
|
||||
Window(
|
||||
content=FormattedTextControl(lambda: cli_ref._get_status_bar_fragments()),
|
||||
height=1,
|
||||
# Prevent fragments that overflow the terminal width from
|
||||
# wrapping onto a second line, which causes the status bar to
|
||||
# appear duplicated (one full + one partial row) during long
|
||||
# sessions, especially on SSH where shutil.get_terminal_size
|
||||
# may return stale values. _get_status_bar_fragments now reads
|
||||
# width from prompt_toolkit's own output object, so fragments
|
||||
# will always fit; wrap_lines=False is the belt-and-suspenders
|
||||
# guard against any future width mismatch.
|
||||
wrap_lines=False,
|
||||
),
|
||||
filter=Condition(lambda: cli_ref._status_bar_visible),
|
||||
)
|
||||
|
||||
@@ -598,6 +598,34 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
def advance_next_run(job_id: str) -> bool:
|
||||
"""Preemptively advance next_run_at for a recurring job before execution.
|
||||
|
||||
Call this BEFORE run_job() so that if the process crashes mid-execution,
|
||||
the job won't re-fire on the next gateway restart. This converts the
|
||||
scheduler from at-least-once to at-most-once for recurring jobs — missing
|
||||
one run is far better than firing dozens of times in a crash loop.
|
||||
|
||||
One-shot jobs are left unchanged so they can still retry on restart.
|
||||
|
||||
Returns True if next_run_at was advanced, False otherwise.
|
||||
"""
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
if job["id"] == job_id:
|
||||
kind = job.get("schedule", {}).get("kind")
|
||||
if kind not in ("cron", "interval"):
|
||||
return False
|
||||
now = _hermes_now().isoformat()
|
||||
new_next = compute_next_run(job["schedule"], now)
|
||||
if new_next and new_next != job.get("next_run_at"):
|
||||
job["next_run_at"] = new_next
|
||||
save_jobs(jobs)
|
||||
return True
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now.
|
||||
|
||||
|
||||
+7
-1
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
|
||||
# Sentinel: when a cron agent has nothing new to report, it can start its
|
||||
# response with this marker to suppress delivery. Output is still saved
|
||||
@@ -524,6 +524,12 @@ def tick(verbose: bool = True) -> int:
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
try:
|
||||
# For recurring jobs (cron/interval), advance next_run_at to the
|
||||
# next future occurrence BEFORE execution. This way, if the
|
||||
# process crashes mid-run, the job won't re-fire on restart.
|
||||
# One-shot jobs are left alone so they can retry on restart.
|
||||
advance_next_run(job["id"])
|
||||
|
||||
success, output, final_response, error = run_job(job)
|
||||
|
||||
output_file = save_job_output(job["id"], output)
|
||||
|
||||
@@ -601,6 +601,14 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].reply_to_mode = telegram_reply_mode
|
||||
|
||||
telegram_fallback_ips = os.getenv("TELEGRAM_FALLBACK_IPS", "")
|
||||
if telegram_fallback_ips:
|
||||
if Platform.TELEGRAM not in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].extra["fallback_ips"] = [
|
||||
ip.strip() for ip in telegram_fallback_ips.split(",") if ip.strip()
|
||||
]
|
||||
|
||||
telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL")
|
||||
if telegram_home and Platform.TELEGRAM in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM].home_channel = HomeChannel(
|
||||
|
||||
@@ -366,14 +366,20 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
Create an AIAgent instance using the gateway's runtime config.
|
||||
|
||||
Uses _resolve_runtime_agent_kwargs() to pick up model, api_key,
|
||||
base_url, etc. from config.yaml / env vars.
|
||||
base_url, etc. from config.yaml / env vars. Toolsets are resolved
|
||||
from config.yaml platform_toolsets.api_server (same as all other
|
||||
gateway platforms), falling back to the hermes-api-server default.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model, _load_gateway_config
|
||||
from hermes_cli.tools_config import _get_platform_tools
|
||||
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
model = _resolve_gateway_model()
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
enabled_toolsets = sorted(_get_platform_tools(user_config, "api_server"))
|
||||
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
|
||||
|
||||
agent = AIAgent(
|
||||
@@ -383,7 +389,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
quiet_mode=True,
|
||||
verbose_logging=False,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt or None,
|
||||
enabled_toolsets=["hermes-api-server"],
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
session_id=session_id,
|
||||
platform="api_server",
|
||||
stream_delta_callback=stream_delta_callback,
|
||||
|
||||
+136
-25
@@ -8,6 +8,7 @@ and implement the required methods.
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -71,31 +72,51 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
|
||||
return str(filepath)
|
||||
|
||||
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg") -> str:
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> str:
|
||||
"""
|
||||
Download an image from a URL and save it to the local cache.
|
||||
|
||||
Uses httpx for async download with a reasonable timeout.
|
||||
Retries on transient failures (timeouts, 429, 5xx) with exponential
|
||||
backoff so a single slow CDN response doesn't lose the media.
|
||||
|
||||
Args:
|
||||
url: The HTTP/HTTPS URL to download from.
|
||||
ext: File extension including the dot (e.g. ".jpg", ".png").
|
||||
retries: Number of retry attempts on transient failures.
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
"""
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging as _logging
|
||||
_log = _logging.getLogger(__name__)
|
||||
|
||||
last_exc = None
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < retries:
|
||||
wait = 1.5 * (attempt + 1)
|
||||
_log.debug("Media cache retry %d/%d for %s (%.1fs): %s",
|
||||
attempt + 1, retries, url[:80], wait, exc)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
|
||||
def cleanup_image_cache(max_age_hours: int = 24) -> int:
|
||||
@@ -329,6 +350,24 @@ class SendResult:
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
raw_response: Any = None
|
||||
retryable: bool = False # True for transient errors (network, timeout) — base will retry automatically
|
||||
|
||||
|
||||
# Error substrings that indicate a transient network failure worth retrying
|
||||
_RETRYABLE_ERROR_PATTERNS = (
|
||||
"connecterror",
|
||||
"connectionerror",
|
||||
"connectionreset",
|
||||
"connectionrefused",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"network",
|
||||
"broken pipe",
|
||||
"remotedisconnected",
|
||||
"eoferror",
|
||||
"readtimeout",
|
||||
"writetimeout",
|
||||
)
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
@@ -833,6 +872,91 @@ class BasePlatformAdapter(ABC):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_error(error: Optional[str]) -> bool:
|
||||
"""Return True if the error string looks like a transient network failure."""
|
||||
if not error:
|
||||
return False
|
||||
lowered = error.lower()
|
||||
return any(pat in lowered for pat in _RETRYABLE_ERROR_PATTERNS)
|
||||
|
||||
async def _send_with_retry(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Any = None,
|
||||
max_retries: int = 2,
|
||||
base_delay: float = 2.0,
|
||||
) -> "SendResult":
|
||||
"""
|
||||
Send a message with automatic retry for transient network errors.
|
||||
|
||||
On permanent failures (e.g. formatting / permission errors) falls back
|
||||
to a plain-text version before giving up. If all attempts fail due to
|
||||
network errors, sends the user a brief delivery-failure notice so they
|
||||
know to retry rather than waiting indefinitely.
|
||||
"""
|
||||
|
||||
result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if result.success:
|
||||
return result
|
||||
|
||||
error_str = result.error or ""
|
||||
is_network = result.retryable or self._is_retryable_error(error_str)
|
||||
|
||||
if is_network:
|
||||
# Retry with exponential backoff for transient errors
|
||||
for attempt in range(1, max_retries + 1):
|
||||
delay = base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1)
|
||||
logger.warning(
|
||||
"[%s] Send failed (attempt %d/%d, retrying in %.1fs): %s",
|
||||
self.name, attempt, max_retries, delay, error_str,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
if result.success:
|
||||
logger.info("[%s] Send succeeded on retry %d", self.name, attempt)
|
||||
return result
|
||||
error_str = result.error or ""
|
||||
if not (result.retryable or self._is_retryable_error(error_str)):
|
||||
break # error switched to non-transient — fall through to plain-text fallback
|
||||
else:
|
||||
# All retries exhausted (loop completed without break) — notify user
|
||||
logger.error("[%s] Failed to deliver response after %d retries: %s", self.name, max_retries, error_str)
|
||||
notice = (
|
||||
"\u26a0\ufe0f Message delivery failed after multiple attempts. "
|
||||
"Please try again \u2014 your request was processed but the response could not be sent."
|
||||
)
|
||||
try:
|
||||
await self.send(chat_id=chat_id, content=notice, reply_to=reply_to, metadata=metadata)
|
||||
except Exception as notify_err:
|
||||
logger.debug("[%s] Could not send delivery-failure notice: %s", self.name, notify_err)
|
||||
return result
|
||||
|
||||
# Non-network / post-retry formatting failure: try plain text as fallback
|
||||
logger.warning("[%s] Send failed: %s — trying plain-text fallback", self.name, error_str)
|
||||
fallback_result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{content[:3500]}",
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
logger.error("[%s] Fallback send also failed: %s", self.name, fallback_result.error)
|
||||
return fallback_result
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
@@ -982,26 +1106,13 @@ class BasePlatformAdapter(ABC):
|
||||
# Send the text portion
|
||||
if text_content:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
result = await self._send_with_retry(
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
|
||||
@@ -2096,6 +2096,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if pending_text_injection:
|
||||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||||
|
||||
# Defense-in-depth: prevent empty user messages from entering session
|
||||
# (can happen when user sends @mention-only with no other text)
|
||||
if not event_text or not event_text.strip():
|
||||
event_text = "(The user sent a message with no text content)"
|
||||
|
||||
event = MessageEvent(
|
||||
text=event_text,
|
||||
message_type=msg_type,
|
||||
|
||||
@@ -551,9 +551,20 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
"""Continuously sync with the homeserver."""
|
||||
import nio
|
||||
|
||||
while not self._closing:
|
||||
try:
|
||||
await self._client.sync(timeout=30000)
|
||||
resp = await self._client.sync(timeout=30000)
|
||||
if isinstance(resp, nio.SyncError):
|
||||
if self._closing:
|
||||
return
|
||||
logger.warning(
|
||||
"Matrix: sync returned %s: %s — retrying in 5s",
|
||||
type(resp).__name__,
|
||||
getattr(resp, "message", resp),
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
|
||||
@@ -407,18 +407,38 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
kind: str = "file",
|
||||
) -> SendResult:
|
||||
"""Download a URL and upload it as a file attachment."""
|
||||
import asyncio
|
||||
import aiohttp
|
||||
try:
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 400:
|
||||
# Fall back to sending the URL as text.
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
file_data = await resp.read()
|
||||
ct = resp.content_type or "application/octet-stream"
|
||||
# Derive filename from URL.
|
||||
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
|
||||
except Exception as exc:
|
||||
logger.warning("Mattermost: failed to download %s: %s", url, exc)
|
||||
|
||||
last_exc = None
|
||||
file_data = None
|
||||
ct = "application/octet-stream"
|
||||
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 500 or resp.status == 429:
|
||||
if attempt < 2:
|
||||
logger.debug("Mattermost download retry %d/2 for %s (status %d)",
|
||||
attempt + 1, url[:80], resp.status)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
if resp.status >= 400:
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
file_data = await resp.read()
|
||||
ct = resp.content_type or "application/octet-stream"
|
||||
break
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||
last_exc = exc
|
||||
if attempt < 2:
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
logger.warning("Mattermost: failed to download %s after %d attempts: %s", url, attempt + 1, exc)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
if file_data is None:
|
||||
logger.warning("Mattermost: download returned no data for %s", url)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
file_id = await self._upload_file(chat_id, file_data, fname, ct)
|
||||
|
||||
@@ -279,6 +279,12 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# SSE keepalive comments (":") prove the connection
|
||||
# is alive — update activity so the health monitor
|
||||
# doesn't report false idle warnings.
|
||||
if line.startswith(":"):
|
||||
self._last_sse_activity = time.time()
|
||||
continue
|
||||
# Parse SSE data lines
|
||||
if line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
|
||||
+52
-20
@@ -819,33 +819,65 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _download_slack_file(self, url: str, ext: str, audio: bool = False) -> str:
|
||||
"""Download a Slack file using the bot token for auth."""
|
||||
"""Download a Slack file using the bot token for auth, with retry."""
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
last_exc = None
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < 2:
|
||||
logger.debug("Slack file download retry %d/2 for %s: %s",
|
||||
attempt + 1, url[:80], exc)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str) -> bytes:
|
||||
"""Download a Slack file and return raw bytes."""
|
||||
"""Download a Slack file and return raw bytes, with retry."""
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
last_exc = None
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < 2:
|
||||
logger.debug("Slack file download retry %d/2 for %s: %s",
|
||||
attempt + 1, url[:80], exc)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
@@ -11,7 +11,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +25,7 @@ try:
|
||||
filters,
|
||||
)
|
||||
from telegram.constants import ParseMode, ChatType
|
||||
from telegram.request import HTTPXRequest
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
@@ -34,6 +35,7 @@ except ImportError:
|
||||
Application = Any
|
||||
CommandHandler = Any
|
||||
TelegramMessageHandler = Any
|
||||
HTTPXRequest = Any
|
||||
filters = None
|
||||
ParseMode = None
|
||||
ChatType = None
|
||||
@@ -59,6 +61,11 @@ from gateway.platforms.base import (
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
from gateway.platforms.telegram_network import (
|
||||
TelegramFallbackTransport,
|
||||
discover_fallback_ips,
|
||||
parse_fallback_ip_env,
|
||||
)
|
||||
|
||||
|
||||
def check_telegram_requirements() -> bool:
|
||||
@@ -138,6 +145,13 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# DM Topics config from extra.dm_topics
|
||||
self._dm_topics_config: List[Dict[str, Any]] = self.config.extra.get("dm_topics", [])
|
||||
|
||||
def _fallback_ips(self) -> list[str]:
|
||||
"""Return validated fallback IPs from config (populated by _apply_env_overrides)."""
|
||||
configured = self.config.extra.get("fallback_ips", []) if getattr(self.config, "extra", None) else []
|
||||
if isinstance(configured, str):
|
||||
configured = configured.split(",")
|
||||
return parse_fallback_ip_env(",".join(str(v) for v in configured) if configured else None)
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_polling_conflict(error: Exception) -> bool:
|
||||
text = str(error).lower()
|
||||
@@ -474,7 +488,26 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
# Build the application
|
||||
self._app = Application.builder().token(self.config.token).build()
|
||||
builder = Application.builder().token(self.config.token)
|
||||
fallback_ips = self._fallback_ips()
|
||||
if not fallback_ips:
|
||||
fallback_ips = await discover_fallback_ips()
|
||||
logger.info(
|
||||
"[%s] Auto-discovered Telegram fallback IPs: %s",
|
||||
self.name,
|
||||
", ".join(fallback_ips),
|
||||
)
|
||||
if fallback_ips:
|
||||
logger.warning(
|
||||
"[%s] Telegram fallback IPs active: %s",
|
||||
self.name,
|
||||
", ".join(fallback_ips),
|
||||
)
|
||||
transport = TelegramFallbackTransport(fallback_ips)
|
||||
request = HTTPXRequest(httpx_kwargs={"transport": transport})
|
||||
get_updates_request = HTTPXRequest(httpx_kwargs={"transport": transport})
|
||||
builder = builder.request(request).get_updates_request(get_updates_request)
|
||||
self._app = builder.build()
|
||||
self._bot = self._app.bot
|
||||
|
||||
# Register handlers
|
||||
@@ -674,9 +707,15 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except ImportError:
|
||||
_NetErr = OSError # type: ignore[misc,assignment]
|
||||
|
||||
try:
|
||||
from telegram.error import BadRequest as _BadReq
|
||||
except ImportError:
|
||||
_BadReq = None # type: ignore[assignment,misc]
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
should_thread = self._should_thread_reply(reply_to, i)
|
||||
reply_to_id = int(reply_to) if should_thread else None
|
||||
effective_thread_id = int(thread_id) if thread_id else None
|
||||
|
||||
msg = None
|
||||
for _send_attempt in range(3):
|
||||
@@ -688,7 +727,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
text=chunk,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_to_message_id=reply_to_id,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
message_thread_id=effective_thread_id,
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
@@ -700,12 +739,30 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
text=plain_chunk,
|
||||
parse_mode=None,
|
||||
reply_to_message_id=reply_to_id,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
message_thread_id=effective_thread_id,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
break # success
|
||||
except _NetErr as send_err:
|
||||
# BadRequest is a subclass of NetworkError in
|
||||
# python-telegram-bot but represents permanent errors
|
||||
# (not transient network issues). Detect and handle
|
||||
# specific cases instead of blindly retrying.
|
||||
if _BadReq and isinstance(send_err, _BadReq):
|
||||
err_lower = str(send_err).lower()
|
||||
if "thread not found" in err_lower and effective_thread_id is not None:
|
||||
# Thread doesn't exist — retry without
|
||||
# message_thread_id so the message still
|
||||
# reaches the chat.
|
||||
logger.warning(
|
||||
"[%s] Thread %s not found, retrying without message_thread_id",
|
||||
self.name, effective_thread_id,
|
||||
)
|
||||
effective_thread_id = None
|
||||
continue
|
||||
# Other BadRequest errors are permanent — don't retry
|
||||
raise
|
||||
if _send_attempt < 2:
|
||||
wait = 2 ** _send_attempt
|
||||
logger.warning("[%s] Network error on send (attempt %d/3), retrying in %ds: %s",
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
"""Telegram-specific network helpers.
|
||||
|
||||
Provides a hostname-preserving fallback transport for networks where
|
||||
api.telegram.org resolves to an endpoint that is unreachable from the current
|
||||
host. The transport keeps the logical request host and TLS SNI as
|
||||
api.telegram.org while retrying the TCP connection against one or more fallback
|
||||
IPv4 addresses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TELEGRAM_API_HOST = "api.telegram.org"
|
||||
|
||||
# DNS-over-HTTPS providers used to discover Telegram API IPs that may differ
|
||||
# from the (potentially unreachable) IP returned by the local system resolver.
|
||||
_DOH_TIMEOUT = 4.0 # seconds — bounded so connect() isn't noticeably delayed
|
||||
|
||||
_DOH_PROVIDERS: list[dict] = [
|
||||
{
|
||||
"url": "https://dns.google/resolve",
|
||||
"params": {"name": _TELEGRAM_API_HOST, "type": "A"},
|
||||
"headers": {},
|
||||
},
|
||||
{
|
||||
"url": "https://cloudflare-dns.com/dns-query",
|
||||
"params": {"name": _TELEGRAM_API_HOST, "type": "A"},
|
||||
"headers": {"Accept": "application/dns-json"},
|
||||
},
|
||||
]
|
||||
|
||||
# Last-resort IPs when DoH is also blocked. These are stable Telegram Bot API
|
||||
# endpoints in the 149.154.160.0/20 block (same seed used by OpenClaw).
|
||||
_SEED_FALLBACK_IPS: list[str] = ["149.154.167.220"]
|
||||
|
||||
|
||||
class TelegramFallbackTransport(httpx.AsyncBaseTransport):
|
||||
"""Retry Telegram Bot API requests via fallback IPs while preserving TLS/SNI.
|
||||
|
||||
Requests continue to target https://api.telegram.org/... logically, but on
|
||||
connect failures the underlying TCP connection is retried against a known
|
||||
reachable IP. This is effectively the programmatic equivalent of
|
||||
``curl --resolve api.telegram.org:443:<ip>``.
|
||||
"""
|
||||
|
||||
def __init__(self, fallback_ips: Iterable[str], **transport_kwargs):
|
||||
self._fallback_ips = [ip for ip in dict.fromkeys(_normalize_fallback_ips(fallback_ips))]
|
||||
self._primary = httpx.AsyncHTTPTransport(**transport_kwargs)
|
||||
self._fallbacks = {
|
||||
ip: httpx.AsyncHTTPTransport(**transport_kwargs) for ip in self._fallback_ips
|
||||
}
|
||||
self._sticky_ip: Optional[str] = None
|
||||
self._sticky_lock = asyncio.Lock()
|
||||
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
if request.url.host != _TELEGRAM_API_HOST or not self._fallback_ips:
|
||||
return await self._primary.handle_async_request(request)
|
||||
|
||||
sticky_ip = self._sticky_ip
|
||||
attempt_order: list[Optional[str]] = [sticky_ip] if sticky_ip else [None]
|
||||
for ip in self._fallback_ips:
|
||||
if ip != sticky_ip:
|
||||
attempt_order.append(ip)
|
||||
|
||||
last_error: Exception | None = None
|
||||
for ip in attempt_order:
|
||||
candidate = request if ip is None else _rewrite_request_for_ip(request, ip)
|
||||
transport = self._primary if ip is None else self._fallbacks[ip]
|
||||
try:
|
||||
response = await transport.handle_async_request(candidate)
|
||||
if ip is not None and self._sticky_ip != ip:
|
||||
async with self._sticky_lock:
|
||||
if self._sticky_ip != ip:
|
||||
self._sticky_ip = ip
|
||||
logger.warning(
|
||||
"[Telegram] Primary api.telegram.org path unreachable; using sticky fallback IP %s",
|
||||
ip,
|
||||
)
|
||||
return response
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if not _is_retryable_connect_error(exc):
|
||||
raise
|
||||
if ip is None:
|
||||
logger.warning(
|
||||
"[Telegram] Primary api.telegram.org connection failed (%s); trying fallback IPs %s",
|
||||
exc,
|
||||
", ".join(self._fallback_ips),
|
||||
)
|
||||
continue
|
||||
logger.warning("[Telegram] Fallback IP %s failed: %s", ip, exc)
|
||||
continue
|
||||
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._primary.aclose()
|
||||
for transport in self._fallbacks.values():
|
||||
await transport.aclose()
|
||||
|
||||
|
||||
def _normalize_fallback_ips(values: Iterable[str]) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
for value in values:
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
addr = ipaddress.ip_address(raw)
|
||||
except ValueError:
|
||||
logger.warning("Ignoring invalid Telegram fallback IP: %r", raw)
|
||||
continue
|
||||
if addr.version != 4:
|
||||
logger.warning("Ignoring non-IPv4 Telegram fallback IP: %s", raw)
|
||||
continue
|
||||
normalized.append(str(addr))
|
||||
return normalized
|
||||
|
||||
|
||||
def parse_fallback_ip_env(value: str | None) -> list[str]:
|
||||
if not value:
|
||||
return []
|
||||
parts = [part.strip() for part in value.split(",")]
|
||||
return _normalize_fallback_ips(parts)
|
||||
|
||||
|
||||
def _resolve_system_dns() -> set[str]:
|
||||
"""Return the IPv4 addresses that the OS resolver gives for api.telegram.org."""
|
||||
try:
|
||||
results = socket.getaddrinfo(_TELEGRAM_API_HOST, 443, socket.AF_INET)
|
||||
return {addr[4][0] for addr in results}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
async def _query_doh_provider(
|
||||
client: httpx.AsyncClient, provider: dict
|
||||
) -> list[str]:
|
||||
"""Query one DoH provider and return A-record IPs."""
|
||||
try:
|
||||
resp = await client.get(
|
||||
provider["url"], params=provider["params"], headers=provider["headers"]
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
ips: list[str] = []
|
||||
for answer in data.get("Answer", []):
|
||||
if answer.get("type") != 1: # A record
|
||||
continue
|
||||
raw = answer.get("data", "").strip()
|
||||
try:
|
||||
ipaddress.ip_address(raw)
|
||||
ips.append(raw)
|
||||
except ValueError:
|
||||
continue
|
||||
return ips
|
||||
except Exception as exc:
|
||||
logger.debug("DoH query to %s failed: %s", provider["url"], exc)
|
||||
return []
|
||||
|
||||
|
||||
async def discover_fallback_ips() -> list[str]:
|
||||
"""Auto-discover Telegram API IPs via DNS-over-HTTPS.
|
||||
|
||||
Resolves api.telegram.org through Google and Cloudflare DoH, collects all
|
||||
unique IPs, and excludes the system-DNS-resolved IP (which is presumably
|
||||
unreachable on this network). Falls back to a hardcoded seed list when DoH
|
||||
is also unavailable.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(_DOH_TIMEOUT)) as client:
|
||||
doh_tasks = [_query_doh_provider(client, p) for p in _DOH_PROVIDERS]
|
||||
system_dns_task = asyncio.to_thread(_resolve_system_dns)
|
||||
results = await asyncio.gather(system_dns_task, *doh_tasks, return_exceptions=True)
|
||||
|
||||
# results[0] = system DNS IPs (set), results[1:] = DoH IP lists
|
||||
system_ips: set[str] = results[0] if isinstance(results[0], set) else set()
|
||||
|
||||
doh_ips: list[str] = []
|
||||
for r in results[1:]:
|
||||
if isinstance(r, list):
|
||||
doh_ips.extend(r)
|
||||
|
||||
# Deduplicate preserving order, exclude system-DNS IPs
|
||||
seen: set[str] = set()
|
||||
candidates: list[str] = []
|
||||
for ip in doh_ips:
|
||||
if ip not in seen and ip not in system_ips:
|
||||
seen.add(ip)
|
||||
candidates.append(ip)
|
||||
|
||||
# Validate through existing normalization
|
||||
validated = _normalize_fallback_ips(candidates)
|
||||
|
||||
if validated:
|
||||
logger.debug("Discovered Telegram fallback IPs via DoH: %s", ", ".join(validated))
|
||||
return validated
|
||||
|
||||
logger.info(
|
||||
"DoH discovery yielded no new IPs (system DNS: %s); using seed fallback IPs %s",
|
||||
", ".join(system_ips) or "unknown",
|
||||
", ".join(_SEED_FALLBACK_IPS),
|
||||
)
|
||||
return list(_SEED_FALLBACK_IPS)
|
||||
|
||||
|
||||
def _rewrite_request_for_ip(request: httpx.Request, ip: str) -> httpx.Request:
|
||||
original_host = request.url.host or _TELEGRAM_API_HOST
|
||||
url = request.url.copy_with(host=ip)
|
||||
headers = request.headers.copy()
|
||||
headers["host"] = original_host
|
||||
extensions = dict(request.extensions)
|
||||
extensions["sni_hostname"] = original_host
|
||||
return httpx.Request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
stream=request.stream,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
|
||||
def _is_retryable_connect_error(exc: Exception) -> bool:
|
||||
return isinstance(exc, (httpx.ConnectTimeout, httpx.ConnectError))
|
||||
+139
-7
@@ -573,6 +573,10 @@ class GatewayRunner:
|
||||
session_id=old_session_id,
|
||||
honcho_session_key=honcho_session_key,
|
||||
)
|
||||
# Fully silence the flush agent — quiet_mode only suppresses init
|
||||
# messages; tool call output still leaks to the terminal through
|
||||
# _safe_print → _print_fn. Set a no-op to prevent that.
|
||||
tmp_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
# Build conversation history from transcript
|
||||
msgs = [
|
||||
@@ -954,12 +958,20 @@ class GatewayRunner:
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "EMAIL_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(
|
||||
os.getenv(v, "").lower() in ("true", "1", "yes")
|
||||
for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS",
|
||||
"WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS",
|
||||
"SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS",
|
||||
"SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS")
|
||||
)
|
||||
if not _any_allowlist and not _allow_all:
|
||||
logger.warning(
|
||||
"No user allowlists configured. All unauthorized users will be denied. "
|
||||
@@ -1970,6 +1982,12 @@ class GatewayRunner:
|
||||
f"Use /resume to browse and restore a previous session.\n"
|
||||
f"Adjust reset timing in config.yaml under session_reset."
|
||||
)
|
||||
try:
|
||||
session_info = self._format_session_info()
|
||||
if session_info:
|
||||
notice = f"{notice}\n\n{session_info}"
|
||||
except Exception:
|
||||
pass
|
||||
await adapter.send(
|
||||
source.chat_id, notice,
|
||||
metadata=getattr(event, 'metadata', None),
|
||||
@@ -2175,6 +2193,7 @@ class GatewayRunner:
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
_hyg_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
_compressed, _ = await loop.run_in_executor(
|
||||
@@ -2736,6 +2755,85 @@ class GatewayRunner:
|
||||
# Clear session env
|
||||
self._clear_session_env()
|
||||
|
||||
def _format_session_info(self) -> str:
|
||||
"""Resolve current model config and return a formatted info block.
|
||||
|
||||
Surfaces model, provider, context length, and endpoint so gateway
|
||||
users can immediately see if context detection went wrong (e.g.
|
||||
local models falling to the 128K default).
|
||||
"""
|
||||
from agent.model_metadata import get_model_context_length, DEFAULT_FALLBACK_CONTEXT
|
||||
|
||||
model = _resolve_gateway_model()
|
||||
config_context_length = None
|
||||
provider = None
|
||||
base_url = None
|
||||
api_key = None
|
||||
|
||||
try:
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
import yaml as _info_yaml
|
||||
with open(cfg_path, encoding="utf-8") as f:
|
||||
data = _info_yaml.safe_load(f) or {}
|
||||
model_cfg = data.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
raw_ctx = model_cfg.get("context_length")
|
||||
if raw_ctx is not None:
|
||||
try:
|
||||
config_context_length = int(raw_ctx)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
provider = model_cfg.get("provider") or None
|
||||
base_url = model_cfg.get("base_url") or None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Resolve runtime credentials for probing
|
||||
try:
|
||||
runtime = _resolve_runtime_agent_kwargs()
|
||||
provider = provider or runtime.get("provider")
|
||||
base_url = base_url or runtime.get("base_url")
|
||||
api_key = runtime.get("api_key")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
context_length = get_model_context_length(
|
||||
model,
|
||||
base_url=base_url or "",
|
||||
api_key=api_key or "",
|
||||
config_context_length=config_context_length,
|
||||
provider=provider or "",
|
||||
)
|
||||
|
||||
# Format context source hint
|
||||
if config_context_length is not None:
|
||||
ctx_source = "config"
|
||||
elif context_length == DEFAULT_FALLBACK_CONTEXT:
|
||||
ctx_source = "default — set model.context_length in config to override"
|
||||
else:
|
||||
ctx_source = "detected"
|
||||
|
||||
# Format context length for display
|
||||
if context_length >= 1_000_000:
|
||||
ctx_display = f"{context_length / 1_000_000:.1f}M"
|
||||
elif context_length >= 1_000:
|
||||
ctx_display = f"{context_length // 1_000}K"
|
||||
else:
|
||||
ctx_display = str(context_length)
|
||||
|
||||
lines = [
|
||||
f"◆ Model: `{model}`",
|
||||
f"◆ Provider: {provider or 'openrouter'}",
|
||||
f"◆ Context: {ctx_display} tokens ({ctx_source})",
|
||||
]
|
||||
|
||||
# Show endpoint for local/custom setups
|
||||
if base_url and ("localhost" in base_url or "127.0.0.1" in base_url or "0.0.0.0" in base_url):
|
||||
lines.append(f"◆ Endpoint: {base_url}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_reset_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /new or /reset command."""
|
||||
source = event.source
|
||||
@@ -2776,12 +2874,22 @@ class GatewayRunner:
|
||||
"session_key": session_key,
|
||||
})
|
||||
|
||||
# Resolve session config info to surface to the user
|
||||
try:
|
||||
session_info = self._format_session_info()
|
||||
except Exception:
|
||||
session_info = ""
|
||||
|
||||
if new_entry:
|
||||
return "✨ Session reset! I've started fresh with no memory of our previous conversation."
|
||||
header = "✨ Session reset! Starting fresh."
|
||||
else:
|
||||
# No existing session, just create one
|
||||
self.session_store.get_or_create_session(source, force_new=True)
|
||||
return "✨ New session started!"
|
||||
header = "✨ New session started!"
|
||||
|
||||
if session_info:
|
||||
return f"{header}\n\n{session_info}"
|
||||
return header
|
||||
|
||||
async def _handle_status_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /status command."""
|
||||
@@ -3885,6 +3993,7 @@ class GatewayRunner:
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
tmp_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
compressed, _ = await loop.run_in_executor(
|
||||
@@ -4799,9 +4908,14 @@ class GatewayRunner:
|
||||
enabled_toolsets = sorted(_get_platform_tools(user_config, platform_key))
|
||||
|
||||
# Tool progress mode from config.yaml: "all", "new", "verbose", "off"
|
||||
# Falls back to env vars for backward compatibility
|
||||
# Falls back to env vars for backward compatibility.
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise before
|
||||
# the `or` chain so it doesn't silently fall through to "all".
|
||||
_raw_tp = user_config.get("display", {}).get("tool_progress")
|
||||
if _raw_tp is False:
|
||||
_raw_tp = "off"
|
||||
progress_mode = (
|
||||
user_config.get("display", {}).get("tool_progress")
|
||||
_raw_tp
|
||||
or os.getenv("HERMES_TOOL_PROGRESS_MODE")
|
||||
or "all"
|
||||
)
|
||||
@@ -5128,7 +5242,25 @@ class GatewayRunner:
|
||||
agent.stream_delta_callback = _stream_delta_cb
|
||||
agent.status_callback = _status_callback_sync
|
||||
agent.reasoning_config = reasoning_config
|
||||
|
||||
|
||||
# Background review delivery — send "💾 Memory updated" etc. to user
|
||||
def _bg_review_send(message: str) -> None:
|
||||
if not _status_adapter:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
_status_adapter.send(
|
||||
_status_chat_id,
|
||||
message,
|
||||
metadata=_status_thread_metadata,
|
||||
),
|
||||
_loop_for_step,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("background_review_callback error: %s", _e)
|
||||
|
||||
agent.background_review_callback = _bg_review_send
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
agent_holder[0] = agent
|
||||
# Capture the full tool definitions for transcript logging
|
||||
|
||||
+14
-7
@@ -762,14 +762,16 @@ class SessionStore:
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = _now()
|
||||
entry.input_tokens += input_tokens
|
||||
entry.output_tokens += output_tokens
|
||||
entry.cache_read_tokens += cache_read_tokens
|
||||
entry.cache_write_tokens += cache_write_tokens
|
||||
# Direct assignment — the gateway receives cumulative totals
|
||||
# from the cached agent, not per-call deltas.
|
||||
entry.input_tokens = input_tokens
|
||||
entry.output_tokens = output_tokens
|
||||
entry.cache_read_tokens = cache_read_tokens
|
||||
entry.cache_write_tokens = cache_write_tokens
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
if estimated_cost_usd is not None:
|
||||
entry.estimated_cost_usd += estimated_cost_usd
|
||||
entry.estimated_cost_usd = estimated_cost_usd
|
||||
if cost_status:
|
||||
entry.cost_status = cost_status
|
||||
entry.total_tokens = (
|
||||
@@ -783,7 +785,7 @@ class SessionStore:
|
||||
|
||||
if self._db and db_session_id:
|
||||
try:
|
||||
self._db.update_token_counts(
|
||||
self._db.set_token_counts(
|
||||
db_session_id,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
@@ -795,6 +797,7 @@ class SessionStore:
|
||||
billing_provider=provider,
|
||||
billing_base_url=base_url,
|
||||
model=model,
|
||||
absolute=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
@@ -955,13 +958,17 @@ class SessionStore:
|
||||
try:
|
||||
self._db.clear_messages(session_id)
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
self._db.append_message(
|
||||
session_id=session_id,
|
||||
role=msg.get("role", "unknown"),
|
||||
role=role,
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
reasoning=msg.get("reasoning") if role == "assistant" else None,
|
||||
reasoning_details=msg.get("reasoning_details") if role == "assistant" else None,
|
||||
codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
|
||||
@@ -264,6 +264,7 @@ DEFAULT_CONFIG = {
|
||||
"compact": False,
|
||||
"personality": "kawaii",
|
||||
"resume_display": "full",
|
||||
"busy_input_mode": "interrupt",
|
||||
"bell_on_complete": False,
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
|
||||
+110
-7
@@ -2968,6 +2968,95 @@ def setup_tools(config: dict, first_install: bool = False):
|
||||
tools_command(first_install=first_install, config=config)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Post-Migration Section Skip Logic
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]:
|
||||
"""Return a short summary if a setup section is already configured, else None.
|
||||
|
||||
Used after OpenClaw migration to detect which sections can be skipped.
|
||||
``get_env_value`` is the module-level import from hermes_cli.config
|
||||
so that test patches on ``setup_mod.get_env_value`` take effect.
|
||||
"""
|
||||
if section_key == "model":
|
||||
has_key = bool(
|
||||
get_env_value("OPENROUTER_API_KEY")
|
||||
or get_env_value("OPENAI_API_KEY")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
)
|
||||
if not has_key:
|
||||
# Check for OAuth providers
|
||||
try:
|
||||
from hermes_cli.auth import get_active_provider
|
||||
if get_active_provider():
|
||||
has_key = True
|
||||
except Exception:
|
||||
pass
|
||||
if not has_key:
|
||||
return None
|
||||
model = config.get("model")
|
||||
if isinstance(model, str) and model.strip():
|
||||
return model.strip()
|
||||
if isinstance(model, dict):
|
||||
return str(model.get("default") or model.get("model") or "configured")
|
||||
return "configured"
|
||||
|
||||
elif section_key == "terminal":
|
||||
backend = config.get("terminal", {}).get("backend", "local")
|
||||
return f"backend: {backend}"
|
||||
|
||||
elif section_key == "agent":
|
||||
max_turns = config.get("agent", {}).get("max_turns", 90)
|
||||
return f"max turns: {max_turns}"
|
||||
|
||||
elif section_key == "gateway":
|
||||
platforms = []
|
||||
if get_env_value("TELEGRAM_BOT_TOKEN"):
|
||||
platforms.append("Telegram")
|
||||
if get_env_value("DISCORD_BOT_TOKEN"):
|
||||
platforms.append("Discord")
|
||||
if get_env_value("SLACK_BOT_TOKEN"):
|
||||
platforms.append("Slack")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("SIGNAL_ACCOUNT"):
|
||||
platforms.append("Signal")
|
||||
if platforms:
|
||||
return ", ".join(platforms)
|
||||
return None # No platforms configured — section must run
|
||||
|
||||
elif section_key == "tools":
|
||||
tools = []
|
||||
if get_env_value("ELEVENLABS_API_KEY"):
|
||||
tools.append("TTS/ElevenLabs")
|
||||
if get_env_value("BROWSERBASE_API_KEY"):
|
||||
tools.append("Browser")
|
||||
if get_env_value("FIRECRAWL_API_KEY"):
|
||||
tools.append("Firecrawl")
|
||||
if tools:
|
||||
return ", ".join(tools)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _skip_configured_section(
|
||||
config: dict, section_key: str, label: str
|
||||
) -> bool:
|
||||
"""Show an already-configured section summary and offer to skip.
|
||||
|
||||
Returns True if the user chose to skip, False if the section should run.
|
||||
"""
|
||||
summary = _get_section_config_summary(config, section_key)
|
||||
if not summary:
|
||||
return False
|
||||
print()
|
||||
print_success(f" {label}: {summary}")
|
||||
return not prompt_yes_no(f" Reconfigure {label.lower()}?", default=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenClaw Migration
|
||||
# =============================================================================
|
||||
@@ -3039,7 +3128,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=True,
|
||||
workspace_target=None,
|
||||
overwrite=False,
|
||||
overwrite=True,
|
||||
migrate_secrets=True,
|
||||
output_dir=None,
|
||||
selected_options=selected,
|
||||
@@ -3195,6 +3284,8 @@ def run_setup_wizard(args):
|
||||
)
|
||||
)
|
||||
|
||||
migration_ran = False
|
||||
|
||||
if is_existing:
|
||||
# ── Returning User Menu ──
|
||||
print()
|
||||
@@ -3264,7 +3355,8 @@ def run_setup_wizard(args):
|
||||
return
|
||||
|
||||
# Offer OpenClaw migration before configuration begins
|
||||
if _offer_openclaw_migration(hermes_home):
|
||||
migration_ran = _offer_openclaw_migration(hermes_home)
|
||||
if migration_ran:
|
||||
# Reload config in case migration wrote to it
|
||||
config = load_config()
|
||||
|
||||
@@ -3277,20 +3369,31 @@ def run_setup_wizard(args):
|
||||
print()
|
||||
print_info("You can edit these files directly or use 'hermes config edit'")
|
||||
|
||||
if migration_ran:
|
||||
print()
|
||||
print_info("Settings were imported from OpenClaw.")
|
||||
print_info("Each section below will show what was imported — press Enter to keep,")
|
||||
print_info("or choose to reconfigure if needed.")
|
||||
|
||||
# Section 1: Model & Provider
|
||||
setup_model_provider(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "model", "Model & Provider")):
|
||||
setup_model_provider(config)
|
||||
|
||||
# Section 2: Terminal Backend
|
||||
setup_terminal_backend(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "terminal", "Terminal Backend")):
|
||||
setup_terminal_backend(config)
|
||||
|
||||
# Section 3: Agent Settings
|
||||
setup_agent_settings(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "agent", "Agent Settings")):
|
||||
setup_agent_settings(config)
|
||||
|
||||
# Section 4: Messaging Platforms
|
||||
setup_gateway(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "gateway", "Messaging Platforms")):
|
||||
setup_gateway(config)
|
||||
|
||||
# Section 5: Tools
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
if not (migration_ran and _skip_configured_section(config, "tools", "Tools")):
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
|
||||
# Save and show summary
|
||||
save_config(config)
|
||||
|
||||
+304
-84
@@ -15,15 +15,20 @@ Key design decisions:
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
DEFAULT_DB_PATH = get_hermes_home() / "state.db"
|
||||
|
||||
@@ -116,18 +121,38 @@ class SessionDB:
|
||||
single writer via WAL mode). Each method opens its own cursor.
|
||||
"""
|
||||
|
||||
# ── Write-contention tuning ──
|
||||
# With multiple hermes processes (gateway + CLI sessions + worktree agents)
|
||||
# all sharing one state.db, WAL write-lock contention causes visible TUI
|
||||
# freezes. SQLite's built-in busy handler uses a deterministic sleep
|
||||
# schedule that causes convoy effects under high concurrency.
|
||||
#
|
||||
# Instead, we keep the SQLite timeout short (1s) and handle retries at the
|
||||
# application level with random jitter, which naturally staggers competing
|
||||
# writers and avoids the convoy.
|
||||
_WRITE_MAX_RETRIES = 15
|
||||
_WRITE_RETRY_MIN_S = 0.020 # 20ms
|
||||
_WRITE_RETRY_MAX_S = 0.150 # 150ms
|
||||
# Attempt a PASSIVE WAL checkpoint every N successful writes.
|
||||
_CHECKPOINT_EVERY_N_WRITES = 50
|
||||
|
||||
def __init__(self, db_path: Path = None):
|
||||
self.db_path = db_path or DEFAULT_DB_PATH
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._write_count = 0
|
||||
self._conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
check_same_thread=False,
|
||||
# 30s gives the WAL writer (CLI or gateway) time to finish a batch
|
||||
# flush before the concurrent reader/writer gives up. 10s was too
|
||||
# short when the CLI is doing frequent memory flushes.
|
||||
timeout=30.0,
|
||||
# Short timeout — application-level retry with random jitter
|
||||
# handles contention instead of sitting in SQLite's internal
|
||||
# busy handler for up to 30s.
|
||||
timeout=1.0,
|
||||
# Autocommit mode: Python's default isolation_level="" auto-starts
|
||||
# transactions on DML, which conflicts with our explicit
|
||||
# BEGIN IMMEDIATE. None = we manage transactions ourselves.
|
||||
isolation_level=None,
|
||||
)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
@@ -135,6 +160,96 @@ class SessionDB:
|
||||
|
||||
self._init_schema()
|
||||
|
||||
# ── Core write helper ──
|
||||
|
||||
def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
|
||||
"""Execute a write transaction with BEGIN IMMEDIATE and jitter retry.
|
||||
|
||||
*fn* receives the connection and should perform INSERT/UPDATE/DELETE
|
||||
statements. The caller must NOT call ``commit()`` — that's handled
|
||||
here after *fn* returns.
|
||||
|
||||
BEGIN IMMEDIATE acquires the WAL write lock at transaction start
|
||||
(not at commit time), so lock contention surfaces immediately.
|
||||
On ``database is locked``, we release the Python lock, sleep a
|
||||
random 20-150ms, and retry — breaking the convoy pattern that
|
||||
SQLite's built-in deterministic backoff creates.
|
||||
|
||||
Returns whatever *fn* returns.
|
||||
"""
|
||||
last_err: Optional[Exception] = None
|
||||
for attempt in range(self._WRITE_MAX_RETRIES):
|
||||
try:
|
||||
with self._lock:
|
||||
self._conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
result = fn(self._conn)
|
||||
self._conn.commit()
|
||||
except BaseException:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
# Success — periodic best-effort checkpoint.
|
||||
self._write_count += 1
|
||||
if self._write_count % self._CHECKPOINT_EVERY_N_WRITES == 0:
|
||||
self._try_wal_checkpoint()
|
||||
return result
|
||||
except sqlite3.OperationalError as exc:
|
||||
err_msg = str(exc).lower()
|
||||
if "locked" in err_msg or "busy" in err_msg:
|
||||
last_err = exc
|
||||
if attempt < self._WRITE_MAX_RETRIES - 1:
|
||||
jitter = random.uniform(
|
||||
self._WRITE_RETRY_MIN_S,
|
||||
self._WRITE_RETRY_MAX_S,
|
||||
)
|
||||
time.sleep(jitter)
|
||||
continue
|
||||
# Non-lock error or retries exhausted — propagate.
|
||||
raise
|
||||
# Retries exhausted (shouldn't normally reach here).
|
||||
raise last_err or sqlite3.OperationalError(
|
||||
"database is locked after max retries"
|
||||
)
|
||||
|
||||
def _try_wal_checkpoint(self) -> None:
|
||||
"""Best-effort PASSIVE WAL checkpoint. Never blocks, never raises.
|
||||
|
||||
Flushes committed WAL frames back into the main DB file for any
|
||||
frames that no other connection currently needs. Keeps the WAL
|
||||
from growing unbounded when many processes hold persistent
|
||||
connections.
|
||||
"""
|
||||
try:
|
||||
with self._lock:
|
||||
result = self._conn.execute(
|
||||
"PRAGMA wal_checkpoint(PASSIVE)"
|
||||
).fetchone()
|
||||
if result and result[1] > 0:
|
||||
logger.debug(
|
||||
"WAL checkpoint: %d/%d pages checkpointed",
|
||||
result[2], result[1],
|
||||
)
|
||||
except Exception:
|
||||
pass # Best effort — never fatal.
|
||||
|
||||
def close(self):
|
||||
"""Close the database connection.
|
||||
|
||||
Attempts a PASSIVE WAL checkpoint first so that exiting processes
|
||||
help keep the WAL file from growing unbounded.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._conn:
|
||||
try:
|
||||
self._conn.execute("PRAGMA wal_checkpoint(PASSIVE)")
|
||||
except Exception:
|
||||
pass
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
def _init_schema(self):
|
||||
"""Create tables and FTS if they don't exist, run migrations."""
|
||||
cursor = self._conn.cursor()
|
||||
@@ -256,8 +371,8 @@ class SessionDB:
|
||||
parent_session_id: str = None,
|
||||
) -> str:
|
||||
"""Create a new session record. Returns the session_id."""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""INSERT OR IGNORE INTO sessions (id, source, user_id, model, model_config,
|
||||
system_prompt, parent_session_id, started_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
@@ -272,26 +387,35 @@ class SessionDB:
|
||||
time.time(),
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._execute_write(_do)
|
||||
return session_id
|
||||
|
||||
def end_session(self, session_id: str, end_reason: str) -> None:
|
||||
"""Mark a session as ended."""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
|
||||
(time.time(), end_reason, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._execute_write(_do)
|
||||
|
||||
def reopen_session(self, session_id: str) -> None:
|
||||
"""Clear ended_at/end_reason so a session can be resumed."""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"UPDATE sessions SET ended_at = NULL, end_reason = NULL WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
||||
"""Store the full assembled system prompt snapshot."""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"UPDATE sessions SET system_prompt = ? WHERE id = ?",
|
||||
(system_prompt, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._execute_write(_do)
|
||||
|
||||
def update_token_counts(
|
||||
self,
|
||||
@@ -310,11 +434,39 @@ class SessionDB:
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
absolute: bool = False,
|
||||
) -> None:
|
||||
"""Increment token counters and backfill model if not already set."""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
"""Update token counters and backfill model if not already set.
|
||||
|
||||
When *absolute* is False (default), values are **incremented** — use
|
||||
this for per-API-call deltas (CLI path).
|
||||
|
||||
When *absolute* is True, values are **set directly** — use this when
|
||||
the caller already holds cumulative totals (gateway path, where the
|
||||
cached agent accumulates across messages).
|
||||
"""
|
||||
if absolute:
|
||||
sql = """UPDATE sessions SET
|
||||
input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = COALESCE(?, 0),
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?"""
|
||||
else:
|
||||
sql = """UPDATE sessions SET
|
||||
input_tokens = input_tokens + ?,
|
||||
output_tokens = output_tokens + ?,
|
||||
cache_read_tokens = cache_read_tokens + ?,
|
||||
@@ -332,6 +484,94 @@ class SessionDB:
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?"""
|
||||
params = (
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
actual_cost_usd,
|
||||
cost_status,
|
||||
cost_source,
|
||||
pricing_version,
|
||||
billing_provider,
|
||||
billing_base_url,
|
||||
billing_mode,
|
||||
model,
|
||||
session_id,
|
||||
)
|
||||
def _do(conn):
|
||||
conn.execute(sql, params)
|
||||
self._execute_write(_do)
|
||||
|
||||
def ensure_session(
|
||||
self,
|
||||
session_id: str,
|
||||
source: str = "unknown",
|
||||
model: str = None,
|
||||
) -> None:
|
||||
"""Ensure a session row exists, creating it with minimal metadata if absent.
|
||||
|
||||
Used by _flush_messages_to_session_db to recover from a failed
|
||||
create_session() call (e.g. transient SQLite lock at agent startup).
|
||||
INSERT OR IGNORE is safe to call even when the row already exists.
|
||||
"""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""INSERT OR IGNORE INTO sessions
|
||||
(id, source, model, started_at)
|
||||
VALUES (?, ?, ?, ?)""",
|
||||
(session_id, source, model, time.time()),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
def set_token_counts(
|
||||
self,
|
||||
session_id: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
model: str = None,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
actual_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
pricing_version: Optional[str] = None,
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Set token counters to absolute values (not increment).
|
||||
|
||||
Use this when the caller provides cumulative totals from a completed
|
||||
conversation run (e.g. the gateway, where the cached agent's
|
||||
session_prompt_tokens already reflects the running total).
|
||||
"""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = ?,
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?""",
|
||||
(
|
||||
input_tokens,
|
||||
@@ -352,28 +592,7 @@ class SessionDB:
|
||||
session_id,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def ensure_session(
|
||||
self,
|
||||
session_id: str,
|
||||
source: str = "unknown",
|
||||
model: str = None,
|
||||
) -> None:
|
||||
"""Ensure a session row exists, creating it with minimal metadata if absent.
|
||||
|
||||
Used by _flush_messages_to_session_db to recover from a failed
|
||||
create_session() call (e.g. transient SQLite lock at agent startup).
|
||||
INSERT OR IGNORE is safe to call even when the row already exists.
|
||||
"""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"""INSERT OR IGNORE INTO sessions
|
||||
(id, source, model, started_at)
|
||||
VALUES (?, ?, ?, ?)""",
|
||||
(session_id, source, model, time.time()),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._execute_write(_do)
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a session by ID."""
|
||||
@@ -467,10 +686,10 @@ class SessionDB:
|
||||
Empty/whitespace-only strings are normalized to None (clearing the title).
|
||||
"""
|
||||
title = self.sanitize_title(title)
|
||||
with self._lock:
|
||||
def _do(conn):
|
||||
if title:
|
||||
# Check uniqueness (allow the same session to keep its own title)
|
||||
cursor = self._conn.execute(
|
||||
cursor = conn.execute(
|
||||
"SELECT id FROM sessions WHERE title = ? AND id != ?",
|
||||
(title, session_id),
|
||||
)
|
||||
@@ -479,12 +698,12 @@ class SessionDB:
|
||||
raise ValueError(
|
||||
f"Title '{title}' is already in use by session {conflict['id']}"
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
cursor = conn.execute(
|
||||
"UPDATE sessions SET title = ? WHERE id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
rowcount = cursor.rowcount
|
||||
return cursor.rowcount
|
||||
rowcount = self._execute_write(_do)
|
||||
return rowcount > 0
|
||||
|
||||
def get_session_title(self, session_id: str) -> Optional[str]:
|
||||
@@ -656,17 +875,24 @@ class SessionDB:
|
||||
Also increments the session's message_count (and tool_call_count
|
||||
if role is 'tool' or tool_calls is present).
|
||||
"""
|
||||
with self._lock:
|
||||
# Serialize structured fields to JSON for storage
|
||||
reasoning_details_json = (
|
||||
json.dumps(reasoning_details)
|
||||
if reasoning_details else None
|
||||
)
|
||||
codex_items_json = (
|
||||
json.dumps(codex_reasoning_items)
|
||||
if codex_reasoning_items else None
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
# Serialize structured fields to JSON before entering the write txn
|
||||
reasoning_details_json = (
|
||||
json.dumps(reasoning_details)
|
||||
if reasoning_details else None
|
||||
)
|
||||
codex_items_json = (
|
||||
json.dumps(codex_reasoning_items)
|
||||
if codex_reasoning_items else None
|
||||
)
|
||||
tool_calls_json = json.dumps(tool_calls) if tool_calls else None
|
||||
|
||||
# Pre-compute tool call count
|
||||
num_tool_calls = 0
|
||||
if tool_calls is not None:
|
||||
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
|
||||
|
||||
def _do(conn):
|
||||
cursor = conn.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, tool_call_id,
|
||||
tool_calls, tool_name, timestamp, token_count, finish_reason,
|
||||
reasoning, reasoning_details, codex_reasoning_items)
|
||||
@@ -676,7 +902,7 @@ class SessionDB:
|
||||
role,
|
||||
content,
|
||||
tool_call_id,
|
||||
json.dumps(tool_calls) if tool_calls else None,
|
||||
tool_calls_json,
|
||||
tool_name,
|
||||
time.time(),
|
||||
token_count,
|
||||
@@ -689,25 +915,20 @@ class SessionDB:
|
||||
msg_id = cursor.lastrowid
|
||||
|
||||
# Update counters
|
||||
# Count actual tool calls from the tool_calls list (not from tool responses).
|
||||
# A single assistant message can contain multiple parallel tool calls.
|
||||
num_tool_calls = 0
|
||||
if tool_calls is not None:
|
||||
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
|
||||
if num_tool_calls > 0:
|
||||
self._conn.execute(
|
||||
conn.execute(
|
||||
"""UPDATE sessions SET message_count = message_count + 1,
|
||||
tool_call_count = tool_call_count + ? WHERE id = ?""",
|
||||
(num_tool_calls, session_id),
|
||||
)
|
||||
else:
|
||||
self._conn.execute(
|
||||
conn.execute(
|
||||
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
return msg_id
|
||||
|
||||
self._conn.commit()
|
||||
return msg_id
|
||||
return self._execute_write(_do)
|
||||
|
||||
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""Load all messages for a session, ordered by timestamp."""
|
||||
@@ -1001,54 +1222,53 @@ class SessionDB:
|
||||
|
||||
def clear_messages(self, session_id: str) -> None:
|
||||
"""Delete all messages for a session and reset its counters."""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
self._conn.execute(
|
||||
conn.execute(
|
||||
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._execute_write(_do)
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session and all its messages. Returns True if found."""
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
def _do(conn):
|
||||
cursor = conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
self._conn.commit()
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
return True
|
||||
return self._execute_write(_do)
|
||||
|
||||
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
|
||||
"""
|
||||
Delete sessions older than N days. Returns count of deleted sessions.
|
||||
Only prunes ended sessions (not active ones).
|
||||
"""
|
||||
import time as _time
|
||||
cutoff = _time.time() - (older_than_days * 86400)
|
||||
cutoff = time.time() - (older_than_days * 86400)
|
||||
|
||||
with self._lock:
|
||||
def _do(conn):
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
cursor = conn.execute(
|
||||
"""SELECT id FROM sessions
|
||||
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
cursor = conn.execute(
|
||||
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
|
||||
(cutoff,),
|
||||
)
|
||||
session_ids = [row["id"] for row in cursor.fetchall()]
|
||||
|
||||
for sid in session_ids:
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||
return len(session_ids)
|
||||
|
||||
self._conn.commit()
|
||||
return len(session_ids)
|
||||
return self._execute_write(_do)
|
||||
|
||||
+1
-1
@@ -55,7 +55,7 @@ honcho = ["honcho-ai>=2.0.1,<3"]
|
||||
mcp = ["mcp>=1.2.0,<2"]
|
||||
homeassistant = ["aiohttp>=3.9.0,<4"]
|
||||
sms = ["aiohttp>=3.9.0,<4"]
|
||||
acp = ["agent-client-protocol>=0.8.1,<1.0"]
|
||||
acp = ["agent-client-protocol>=0.8.1,<0.9"]
|
||||
dingtalk = ["dingtalk-stream>=0.1.0,<1"]
|
||||
rl = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git",
|
||||
|
||||
+99
-13
@@ -62,7 +62,12 @@ else:
|
||||
|
||||
|
||||
# Import our tool system
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
from model_tools import (
|
||||
get_tool_definitions,
|
||||
get_toolset_for_tool,
|
||||
handle_function_call,
|
||||
check_toolset_requirements,
|
||||
)
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
from tools.interrupt import set_interrupt as _set_interrupt
|
||||
from tools.browser_tool import cleanup_browser
|
||||
@@ -486,6 +491,7 @@ class AIAgent:
|
||||
# instead of going directly to stdout where patch_stdout's StdoutProxy
|
||||
# would mangle the escape sequences. None = use builtins.print.
|
||||
self._print_fn = None
|
||||
self.background_review_callback = None # Optional sync callback for gateway delivery
|
||||
self.skip_context_files = skip_context_files
|
||||
self.pass_session_id = pass_session_id
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
@@ -533,6 +539,7 @@ class AIAgent:
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self.thinking_callback = thinking_callback
|
||||
self.reasoning_callback = reasoning_callback
|
||||
self._reasoning_deltas_fired = False # Set by _fire_reasoning_delta, reset per API call
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
@@ -1525,6 +1532,12 @@ class AIAgent:
|
||||
if actions:
|
||||
summary = " · ".join(dict.fromkeys(actions))
|
||||
self._safe_print(f" 💾 {summary}")
|
||||
_bg_cb = self.background_review_callback
|
||||
if _bg_cb:
|
||||
try:
|
||||
_bg_cb(f"💾 {summary}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Background memory/skill review failed: %s", e)
|
||||
@@ -2048,6 +2061,23 @@ class AIAgent:
|
||||
msg["content"] = self._clean_session_content(msg["content"])
|
||||
cleaned.append(msg)
|
||||
|
||||
# Guard: never overwrite a larger session log with fewer messages.
|
||||
# This protects against data loss when --resume loads a session whose
|
||||
# messages weren't fully written to SQLite — the resumed agent starts
|
||||
# with partial history and would otherwise clobber the full JSON log.
|
||||
if self.session_log_file.exists():
|
||||
try:
|
||||
existing = json.loads(self.session_log_file.read_text(encoding="utf-8"))
|
||||
existing_count = existing.get("message_count", len(existing.get("messages", [])))
|
||||
if existing_count > len(cleaned):
|
||||
logging.debug(
|
||||
"Skipping session log overwrite: existing has %d messages, current has %d",
|
||||
existing_count, len(cleaned),
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass # corrupted existing file — allow the overwrite
|
||||
|
||||
entry = {
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
@@ -2496,7 +2526,13 @@ class AIAgent:
|
||||
|
||||
has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skill_view', 'skill_manage'])
|
||||
if has_skills_tools:
|
||||
avail_toolsets = {ts for ts, avail in check_toolset_requirements().items() if avail}
|
||||
avail_toolsets = {
|
||||
toolset
|
||||
for toolset in (
|
||||
get_toolset_for_tool(tool_name) for tool_name in self.valid_tool_names
|
||||
)
|
||||
if toolset
|
||||
}
|
||||
skills_prompt = build_skills_system_prompt(
|
||||
available_tools=self.valid_tool_names,
|
||||
available_toolsets=avail_toolsets,
|
||||
@@ -3380,6 +3416,7 @@ class AIAgent:
|
||||
max_stream_retries = 1
|
||||
has_tool_calls = False
|
||||
first_delta_fired = False
|
||||
self._reasoning_deltas_fired = False
|
||||
for attempt in range(max_stream_retries + 1):
|
||||
try:
|
||||
with active_client.responses.stream(**api_kwargs) as stream:
|
||||
@@ -3656,6 +3693,7 @@ class AIAgent:
|
||||
|
||||
def _fire_reasoning_delta(self, text: str) -> None:
|
||||
"""Fire reasoning callback if registered."""
|
||||
self._reasoning_deltas_fired = True
|
||||
cb = self.reasoning_callback
|
||||
if cb is not None:
|
||||
try:
|
||||
@@ -3750,6 +3788,9 @@ class AIAgent:
|
||||
request_client_holder["client"] = self._create_request_openai_client(
|
||||
reason="chat_completion_stream_request"
|
||||
)
|
||||
# Reset stale-stream timer so the detector measures from this
|
||||
# attempt's start, not a previous attempt's last chunk.
|
||||
last_chunk_time["t"] = time.time()
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts: list = []
|
||||
@@ -3760,6 +3801,9 @@ class AIAgent:
|
||||
role = "assistant"
|
||||
reasoning_parts: list = []
|
||||
usage_obj = None
|
||||
# Reset per-call reasoning tracking so _build_assistant_message
|
||||
# knows whether reasoning was already displayed during streaming.
|
||||
self._reasoning_deltas_fired = False
|
||||
|
||||
for chunk in stream:
|
||||
last_chunk_time["t"] = time.time()
|
||||
@@ -3879,7 +3923,10 @@ class AIAgent:
|
||||
works unchanged.
|
||||
"""
|
||||
has_tool_use = False
|
||||
self._reasoning_deltas_fired = False
|
||||
|
||||
# Reset stale-stream timer for this attempt
|
||||
last_chunk_time["t"] = time.time()
|
||||
# Use the Anthropic SDK's streaming context manager
|
||||
with self._anthropic_client.messages.stream(**api_kwargs) as stream:
|
||||
for event in stream:
|
||||
@@ -3992,6 +4039,10 @@ class AIAgent:
|
||||
)
|
||||
|
||||
try:
|
||||
# Reset stale timer — the non-streaming fallback
|
||||
# uses its own client; prevent the stale detector
|
||||
# from firing on stale timestamps from failed streams.
|
||||
last_chunk_time["t"] = time.time()
|
||||
result["response"] = self._interruptible_api_call(api_kwargs)
|
||||
except Exception as fallback_err:
|
||||
result["error"] = fallback_err
|
||||
@@ -4001,7 +4052,19 @@ class AIAgent:
|
||||
if request_client is not None:
|
||||
self._close_request_openai_client(request_client, reason="stream_request_complete")
|
||||
|
||||
_stream_stale_timeout = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 90.0))
|
||||
_stream_stale_timeout_base = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 180.0))
|
||||
# Scale the stale timeout for large contexts: slow models (like Opus)
|
||||
# can legitimately think for minutes before producing the first token
|
||||
# when the context is large. Without this, the stale detector kills
|
||||
# healthy connections during the model's thinking phase, producing
|
||||
# spurious RemoteProtocolError ("peer closed connection").
|
||||
_est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4
|
||||
if _est_tokens > 100_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 300.0)
|
||||
elif _est_tokens > 50_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 240.0)
|
||||
else:
|
||||
_stream_stale_timeout = _stream_stale_timeout_base
|
||||
|
||||
t = threading.Thread(target=_call, daemon=True)
|
||||
t.start()
|
||||
@@ -4127,6 +4190,25 @@ class AIAgent:
|
||||
or is_native_anthropic
|
||||
)
|
||||
|
||||
# Update context compressor limits for the fallback model.
|
||||
# Without this, compression decisions use the primary model's
|
||||
# context window (e.g. 200K) instead of the fallback's (e.g. 32K),
|
||||
# causing oversized sessions to overflow the fallback.
|
||||
if hasattr(self, 'context_compressor') and self.context_compressor:
|
||||
from agent.model_metadata import get_model_context_length
|
||||
fb_context_length = get_model_context_length(
|
||||
self.model, base_url=self.base_url,
|
||||
api_key=self.api_key, provider=self.provider,
|
||||
)
|
||||
self.context_compressor.model = self.model
|
||||
self.context_compressor.base_url = self.base_url
|
||||
self.context_compressor.api_key = self.api_key
|
||||
self.context_compressor.provider = self.provider
|
||||
self.context_compressor.context_length = fb_context_length
|
||||
self.context_compressor.threshold_tokens = int(
|
||||
fb_context_length * self.context_compressor.threshold_percent
|
||||
)
|
||||
|
||||
self._emit_status(
|
||||
f"🔄 Primary model failed — switching to fallback: "
|
||||
f"{fb_model} via {fb_provider}"
|
||||
@@ -4555,11 +4637,15 @@ class AIAgent:
|
||||
logging.debug(f"Captured reasoning ({len(reasoning_text)} chars): {reasoning_text}")
|
||||
|
||||
if reasoning_text and self.reasoning_callback:
|
||||
# Skip callback for <think>-extracted reasoning when streaming is active.
|
||||
# _stream_delta() already displayed <think> blocks during streaming;
|
||||
# firing the callback again would cause duplicate display.
|
||||
# Structured reasoning (from reasoning_content field) always fires.
|
||||
if _from_structured or not self.stream_delta_callback:
|
||||
# Skip callback when streaming is active — reasoning was already
|
||||
# displayed during the stream via one of two paths:
|
||||
# (a) _fire_reasoning_delta (structured reasoning_content deltas)
|
||||
# (b) _stream_delta tag extraction (<think>/<REASONING_SCRATCHPAD>)
|
||||
# When streaming is NOT active, always fire so non-streaming modes
|
||||
# (gateway, batch, quiet) still get reasoning.
|
||||
# Any reasoning that wasn't shown during streaming is caught by the
|
||||
# CLI post-response display fallback (cli.py _reasoning_shown_this_turn).
|
||||
if not self.stream_delta_callback:
|
||||
try:
|
||||
self.reasoning_callback(reasoning_text)
|
||||
except Exception:
|
||||
@@ -5080,7 +5166,7 @@ class AIAgent:
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots')
|
||||
spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
|
||||
try:
|
||||
@@ -5121,7 +5207,7 @@ class AIAgent:
|
||||
# Print cute message per tool
|
||||
if self.quiet_mode:
|
||||
cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result)
|
||||
print(f" {cute_msg}")
|
||||
self._safe_print(f" {cute_msg}")
|
||||
elif not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s")
|
||||
@@ -5306,7 +5392,7 @@ class AIAgent:
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots')
|
||||
spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
self._delegate_spinner = spinner
|
||||
_delegate_result = None
|
||||
@@ -5336,7 +5422,7 @@ class AIAgent:
|
||||
preview = _build_tool_preview(function_name, function_args) or function_name
|
||||
if len(preview) > 30:
|
||||
preview = preview[:27] + "..."
|
||||
spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots')
|
||||
spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
_spinner_result = None
|
||||
try:
|
||||
@@ -6019,7 +6105,7 @@ class AIAgent:
|
||||
# Raw KawaiiSpinner only when no streaming consumers
|
||||
# (would conflict with streamed token output)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type, print_fn=self._print_fn)
|
||||
thinking_spinner.start()
|
||||
|
||||
# Log request details if verbose
|
||||
|
||||
@@ -11,6 +11,7 @@ from agent.auxiliary_client import (
|
||||
get_text_auxiliary_client,
|
||||
get_vision_auxiliary_client,
|
||||
get_available_vision_backends,
|
||||
resolve_vision_provider_client,
|
||||
resolve_provider_client,
|
||||
auxiliary_max_tokens_param,
|
||||
_read_codex_access_token,
|
||||
@@ -638,6 +639,30 @@ class TestVisionClientFallback:
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
def test_selected_codex_provider_short_circuits_vision_auto(self, monkeypatch):
|
||||
def fake_load_config():
|
||||
return {"model": {"provider": "openai-codex", "default": "gpt-5.2-codex"}}
|
||||
|
||||
codex_client = MagicMock()
|
||||
with (
|
||||
patch("hermes_cli.config.load_config", fake_load_config),
|
||||
patch("agent.auxiliary_client._try_codex", return_value=(codex_client, "gpt-5.2-codex")) as mock_codex,
|
||||
patch("agent.auxiliary_client._try_openrouter") as mock_openrouter,
|
||||
patch("agent.auxiliary_client._try_nous") as mock_nous,
|
||||
patch("agent.auxiliary_client._try_anthropic") as mock_anthropic,
|
||||
patch("agent.auxiliary_client._try_custom_endpoint") as mock_custom,
|
||||
):
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "openai-codex"
|
||||
assert client is codex_client
|
||||
assert model == "gpt-5.2-codex"
|
||||
mock_codex.assert_called_once()
|
||||
mock_openrouter.assert_not_called()
|
||||
mock_nous.assert_not_called()
|
||||
mock_anthropic.assert_not_called()
|
||||
mock_custom.assert_not_called()
|
||||
|
||||
def test_vision_auto_includes_codex(self, codex_auth_dir):
|
||||
"""Codex supports vision (gpt-5.3-codex), so auto mode should use it."""
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
|
||||
@@ -20,6 +20,7 @@ from cron.jobs import (
|
||||
resume_job,
|
||||
remove_job,
|
||||
mark_job_run,
|
||||
advance_next_run,
|
||||
get_due_jobs,
|
||||
save_job_output,
|
||||
)
|
||||
@@ -339,6 +340,90 @@ class TestMarkJobRun:
|
||||
assert updated["last_error"] == "timeout"
|
||||
|
||||
|
||||
class TestAdvanceNextRun:
|
||||
"""Tests for advance_next_run() — crash-safety for recurring jobs."""
|
||||
|
||||
def test_advances_interval_job(self, tmp_cron_dir):
|
||||
"""Interval jobs should have next_run_at bumped to the next future occurrence."""
|
||||
job = create_job(prompt="Recurring check", schedule="every 1h")
|
||||
# Force next_run_at to 5 minutes ago (i.e. the job is due)
|
||||
jobs = load_jobs()
|
||||
old_next = (datetime.now() - timedelta(minutes=5)).isoformat()
|
||||
jobs[0]["next_run_at"] = old_next
|
||||
save_jobs(jobs)
|
||||
|
||||
result = advance_next_run(job["id"])
|
||||
assert result is True
|
||||
|
||||
updated = get_job(job["id"])
|
||||
from cron.jobs import _ensure_aware, _hermes_now
|
||||
new_next_dt = _ensure_aware(datetime.fromisoformat(updated["next_run_at"]))
|
||||
assert new_next_dt > _hermes_now(), "next_run_at should be in the future after advance"
|
||||
|
||||
def test_advances_cron_job(self, tmp_cron_dir):
|
||||
"""Cron-expression jobs should have next_run_at bumped to the next occurrence."""
|
||||
pytest.importorskip("croniter")
|
||||
job = create_job(prompt="Daily wakeup", schedule="15 6 * * *")
|
||||
# Force next_run_at to 30 minutes ago
|
||||
jobs = load_jobs()
|
||||
old_next = (datetime.now() - timedelta(minutes=30)).isoformat()
|
||||
jobs[0]["next_run_at"] = old_next
|
||||
save_jobs(jobs)
|
||||
|
||||
result = advance_next_run(job["id"])
|
||||
assert result is True
|
||||
|
||||
updated = get_job(job["id"])
|
||||
from cron.jobs import _ensure_aware, _hermes_now
|
||||
new_next_dt = _ensure_aware(datetime.fromisoformat(updated["next_run_at"]))
|
||||
assert new_next_dt > _hermes_now(), "next_run_at should be in the future after advance"
|
||||
|
||||
def test_skips_oneshot_job(self, tmp_cron_dir):
|
||||
"""One-shot jobs should NOT be advanced — they need to retry on restart."""
|
||||
job = create_job(prompt="Run once", schedule="30m")
|
||||
original_next = get_job(job["id"])["next_run_at"]
|
||||
|
||||
result = advance_next_run(job["id"])
|
||||
assert result is False
|
||||
|
||||
updated = get_job(job["id"])
|
||||
assert updated["next_run_at"] == original_next, "one-shot next_run_at should be unchanged"
|
||||
|
||||
def test_nonexistent_job_returns_false(self, tmp_cron_dir):
|
||||
result = advance_next_run("nonexistent-id")
|
||||
assert result is False
|
||||
|
||||
def test_already_future_stays_future(self, tmp_cron_dir):
|
||||
"""If next_run_at is already in the future, advance keeps it in the future (no harm)."""
|
||||
job = create_job(prompt="Future job", schedule="every 1h")
|
||||
# next_run_at is already set to ~1h from now by create_job
|
||||
advance_next_run(job["id"])
|
||||
# Regardless of return value, the job should still be in the future
|
||||
updated = get_job(job["id"])
|
||||
from cron.jobs import _ensure_aware, _hermes_now
|
||||
new_next_dt = _ensure_aware(datetime.fromisoformat(updated["next_run_at"]))
|
||||
assert new_next_dt > _hermes_now(), "next_run_at should remain in the future"
|
||||
|
||||
def test_crash_safety_scenario(self, tmp_cron_dir):
|
||||
"""Simulate the crash-loop scenario: after advance, the job should NOT be due."""
|
||||
job = create_job(prompt="Crash test", schedule="every 1h")
|
||||
# Force next_run_at to 5 minutes ago (job is due)
|
||||
jobs = load_jobs()
|
||||
jobs[0]["next_run_at"] = (datetime.now() - timedelta(minutes=5)).isoformat()
|
||||
save_jobs(jobs)
|
||||
|
||||
# Job should be due before advance
|
||||
due_before = get_due_jobs()
|
||||
assert len(due_before) == 1
|
||||
|
||||
# Advance (simulating what tick() does before run_job)
|
||||
advance_next_run(job["id"])
|
||||
|
||||
# Now the job should NOT be due (simulates restart after crash)
|
||||
due_after = get_due_jobs()
|
||||
assert len(due_after) == 0, "Job should not be due after advance_next_run"
|
||||
|
||||
|
||||
class TestGetDueJobs:
|
||||
def test_past_due_within_window_returned(self, tmp_cron_dir):
|
||||
"""Jobs within the dynamic grace window are still considered due (not stale).
|
||||
|
||||
@@ -687,3 +687,41 @@ class TestBuildJobPromptMissingSkill:
|
||||
result = _build_job_prompt({"skills": ["ghost-skill", "real-skill"], "prompt": "go"})
|
||||
assert "Real skill content." in result
|
||||
assert "go" in result
|
||||
|
||||
|
||||
class TestTickAdvanceBeforeRun:
|
||||
"""Verify that tick() calls advance_next_run before run_job for crash safety."""
|
||||
|
||||
def test_advance_called_before_run_job(self, tmp_path):
|
||||
"""advance_next_run must be called before run_job to prevent crash-loop re-fires."""
|
||||
call_order = []
|
||||
|
||||
def fake_advance(job_id):
|
||||
call_order.append(("advance", job_id))
|
||||
return True
|
||||
|
||||
def fake_run_job(job):
|
||||
call_order.append(("run", job["id"]))
|
||||
return True, "output", "response", None
|
||||
|
||||
fake_job = {
|
||||
"id": "test-advance",
|
||||
"name": "test",
|
||||
"prompt": "hello",
|
||||
"enabled": True,
|
||||
"schedule": {"kind": "cron", "expr": "15 6 * * *"},
|
||||
}
|
||||
|
||||
with patch("cron.scheduler.get_due_jobs", return_value=[fake_job]), \
|
||||
patch("cron.scheduler.advance_next_run", side_effect=fake_advance) as adv_mock, \
|
||||
patch("cron.scheduler.run_job", side_effect=fake_run_job), \
|
||||
patch("cron.scheduler.save_job_output", return_value=tmp_path / "out.md"), \
|
||||
patch("cron.scheduler.mark_job_run"), \
|
||||
patch("cron.scheduler._deliver_result"):
|
||||
from cron.scheduler import tick
|
||||
executed = tick(verbose=False)
|
||||
|
||||
assert executed == 1
|
||||
adv_mock.assert_called_once_with("test-advance")
|
||||
# advance must happen before run
|
||||
assert call_order == [("advance", "test-advance"), ("run", "test-advance")]
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Tests for the startup allowlist warning check in gateway/run.py."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _would_warn():
|
||||
"""Replicate the startup allowlist warning logic. Returns True if warning fires."""
|
||||
_any_allowlist = any(
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(
|
||||
os.getenv(v, "").lower() in ("true", "1", "yes")
|
||||
for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS",
|
||||
"WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS",
|
||||
"SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS",
|
||||
"SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS")
|
||||
)
|
||||
return not _any_allowlist and not _allow_all
|
||||
|
||||
|
||||
class TestAllowlistStartupCheck:
|
||||
|
||||
def test_no_config_emits_warning(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
assert _would_warn() is True
|
||||
|
||||
def test_signal_group_allowed_users_suppresses_warning(self):
|
||||
with patch.dict(os.environ, {"SIGNAL_GROUP_ALLOWED_USERS": "user1"}, clear=True):
|
||||
assert _would_warn() is False
|
||||
|
||||
def test_telegram_allow_all_users_suppresses_warning(self):
|
||||
with patch.dict(os.environ, {"TELEGRAM_ALLOW_ALL_USERS": "true"}, clear=True):
|
||||
assert _would_warn() is False
|
||||
|
||||
def test_gateway_allow_all_users_suppresses_warning(self):
|
||||
with patch.dict(os.environ, {"GATEWAY_ALLOW_ALL_USERS": "yes"}, clear=True):
|
||||
assert _would_warn() is False
|
||||
@@ -69,7 +69,8 @@ class TestApiServerPlatformConfig:
|
||||
|
||||
class TestApiServerAdapterToolset:
|
||||
@patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True)
|
||||
def test_create_agent_uses_api_server_toolset(self):
|
||||
def test_create_agent_reads_config_toolsets(self):
|
||||
"""API server resolves toolsets from config like all other platforms."""
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
@@ -77,17 +78,52 @@ class TestApiServerAdapterToolset:
|
||||
|
||||
with patch("gateway.run._resolve_runtime_agent_kwargs") as mock_kwargs, \
|
||||
patch("gateway.run._resolve_gateway_model") as mock_model, \
|
||||
patch("gateway.run._load_gateway_config") as mock_config, \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
|
||||
mock_kwargs.return_value = {"api_key": "test-key", "base_url": None,
|
||||
"provider": None, "api_mode": None,
|
||||
"command": None, "args": []}
|
||||
mock_model.return_value = "test/model"
|
||||
# No platform_toolsets override — should fall back to hermes-api-server default
|
||||
mock_config.return_value = {}
|
||||
mock_agent_cls.return_value = MagicMock()
|
||||
|
||||
adapter._create_agent()
|
||||
|
||||
mock_agent_cls.assert_called_once()
|
||||
call_kwargs = mock_agent_cls.call_args
|
||||
assert call_kwargs.kwargs.get("enabled_toolsets") == ["hermes-api-server"]
|
||||
toolsets = call_kwargs.kwargs.get("enabled_toolsets")
|
||||
assert isinstance(toolsets, list)
|
||||
assert len(toolsets) > 0
|
||||
assert call_kwargs.kwargs.get("platform") == "api_server"
|
||||
|
||||
@patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True)
|
||||
def test_create_agent_respects_config_override(self):
|
||||
"""User can override API server toolsets via platform_toolsets in config.yaml."""
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
adapter = APIServerAdapter(PlatformConfig())
|
||||
|
||||
with patch("gateway.run._resolve_runtime_agent_kwargs") as mock_kwargs, \
|
||||
patch("gateway.run._resolve_gateway_model") as mock_model, \
|
||||
patch("gateway.run._load_gateway_config") as mock_config, \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
|
||||
mock_kwargs.return_value = {"api_key": "test-key", "base_url": None,
|
||||
"provider": None, "api_mode": None,
|
||||
"command": None, "args": []}
|
||||
mock_model.return_value = "test/model"
|
||||
# User overrides with just web and terminal
|
||||
mock_config.return_value = {
|
||||
"platform_toolsets": {"api_server": ["web", "terminal"]}
|
||||
}
|
||||
mock_agent_cls.return_value = MagicMock()
|
||||
|
||||
adapter._create_agent()
|
||||
|
||||
mock_agent_cls.assert_called_once()
|
||||
call_kwargs = mock_agent_cls.call_args
|
||||
toolsets = call_kwargs.kwargs.get("enabled_toolsets")
|
||||
assert sorted(toolsets) == ["terminal", "web"]
|
||||
|
||||
@@ -32,7 +32,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
|
||||
@@ -7,11 +7,21 @@ Verifies that:
|
||||
3. The flush still works normally when memory files don't exist
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_dotenv(monkeypatch):
|
||||
"""gateway.run imports dotenv at module level; stub it so tests run without the package."""
|
||||
fake = types.ModuleType("dotenv")
|
||||
fake.load_dotenv = lambda *a, **kw: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
@@ -57,105 +67,151 @@ class TestCronSessionBypass:
|
||||
runner.session_store.load_transcript.assert_called_once_with("session_abc123")
|
||||
|
||||
|
||||
def _make_flush_context(monkeypatch, memory_dir=None):
|
||||
"""Return (runner, tmp_agent, fake_run_agent) with run_agent mocked in sys.modules."""
|
||||
tmp_agent = MagicMock()
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = MagicMock(return_value=tmp_agent)
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
return runner, tmp_agent, memory_dir
|
||||
|
||||
|
||||
class TestMemoryInjection:
|
||||
"""The flush prompt should include current memory state from disk."""
|
||||
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path):
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path, monkeypatch):
|
||||
"""When memory files exist, their content appears in the flush prompt."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
|
||||
(memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch, memory_dir)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Intercept `from tools.memory_tool import MEMORY_DIR` inside the function
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_123")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
call_kwargs = tmp_agent.run_conversation.call_args.kwargs
|
||||
flush_prompt = call_kwargs.get("user_message", "")
|
||||
|
||||
# Verify both memory sections appear in the prompt
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
|
||||
assert "Agent knows Python" in flush_prompt
|
||||
assert "User prefers dark mode" in flush_prompt
|
||||
assert "Name: Alice" in flush_prompt
|
||||
assert "Timezone: PST" in flush_prompt
|
||||
# Verify the stale-overwrite warning is present
|
||||
assert "Do NOT overwrite or remove entries" in flush_prompt
|
||||
assert "current live state of memory" in flush_prompt
|
||||
|
||||
def test_flush_works_without_memory_files(self, tmp_path):
|
||||
def test_flush_works_without_memory_files(self, tmp_path, monkeypatch):
|
||||
"""When no memory files exist, flush still runs without the guard."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
empty_dir = tmp_path / "no_memories"
|
||||
empty_dir.mkdir()
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=empty_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_456")
|
||||
|
||||
# Should still run, just without the memory guard section
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "Do NOT overwrite or remove entries" not in flush_prompt
|
||||
assert "Review the conversation above" in flush_prompt
|
||||
|
||||
def test_empty_memory_files_no_injection(self, tmp_path):
|
||||
def test_empty_memory_files_no_injection(self, tmp_path, monkeypatch):
|
||||
"""Empty memory files should not trigger the guard section."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("")
|
||||
(memory_dir / "USER.md").write_text(" \n ") # whitespace only
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_789")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
# No memory content → no guard section
|
||||
assert "current live state of memory" not in flush_prompt
|
||||
|
||||
|
||||
class TestFlushAgentSilenced:
|
||||
"""The flush agent must not produce any terminal output."""
|
||||
|
||||
def test_print_fn_set_to_noop(self, tmp_path, monkeypatch):
|
||||
"""_print_fn on the flush agent must be a no-op so tool output never leaks."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
captured_agent = {}
|
||||
|
||||
def _fake_ai_agent(*args, **kwargs):
|
||||
agent = MagicMock()
|
||||
captured_agent["instance"] = agent
|
||||
return agent
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _fake_ai_agent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=tmp_path)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_silent")
|
||||
|
||||
agent = captured_agent["instance"]
|
||||
assert agent._print_fn is not None, "_print_fn should be overridden to suppress output"
|
||||
# Confirm it is callable and produces no output (no exception)
|
||||
agent._print_fn("should be silenced")
|
||||
|
||||
def test_kawaii_spinner_respects_print_fn(self):
|
||||
"""KawaiiSpinner must route all output through print_fn when supplied."""
|
||||
from agent.display import KawaiiSpinner
|
||||
|
||||
written = []
|
||||
spinner = KawaiiSpinner("test", print_fn=lambda *a, **kw: written.append(a))
|
||||
spinner._write("hello")
|
||||
assert written == [("hello",)], "spinner should route through print_fn"
|
||||
|
||||
# A no-op print_fn must produce no output to stdout
|
||||
import io, sys
|
||||
buf = io.StringIO()
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = buf
|
||||
try:
|
||||
silent_spinner = KawaiiSpinner("silent", print_fn=lambda *a, **kw: None)
|
||||
silent_spinner._write("should not appear")
|
||||
silent_spinner.stop("done")
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
assert buf.getvalue() == "", "no-op print_fn spinner must not write to stdout"
|
||||
|
||||
|
||||
class TestFlushPromptStructure:
|
||||
"""Verify the flush prompt retains its core instructions."""
|
||||
|
||||
def test_core_instructions_present(self):
|
||||
def test_core_instructions_present(self, monkeypatch):
|
||||
"""The flush prompt should still contain the original guidance."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Make the import fail gracefully so we test without memory files
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_struct")
|
||||
|
||||
@@ -0,0 +1,558 @@
|
||||
"""
|
||||
Tests for media download retry logic added in PR #2982.
|
||||
|
||||
Covers:
|
||||
- gateway/platforms/base.py: cache_image_from_url
|
||||
- gateway/platforms/slack.py: SlackAdapter._download_slack_file
|
||||
SlackAdapter._download_slack_file_bytes
|
||||
- gateway/platforms/mattermost.py: MattermostAdapter._send_url_as_file
|
||||
|
||||
All async tests use asyncio.run() directly — pytest-asyncio is not installed
|
||||
in this environment.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for building httpx exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_http_status_error(status_code: int) -> httpx.HTTPStatusError:
|
||||
request = httpx.Request("GET", "http://example.com/img.jpg")
|
||||
response = httpx.Response(status_code=status_code, request=request)
|
||||
return httpx.HTTPStatusError(
|
||||
f"HTTP {status_code}", request=request, response=response
|
||||
)
|
||||
|
||||
|
||||
def _make_timeout_error() -> httpx.TimeoutException:
|
||||
return httpx.TimeoutException("timed out")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cache_image_from_url (base.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheImageFromUrl:
|
||||
"""Tests for gateway.platforms.base.cache_image_from_url"""
|
||||
|
||||
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
|
||||
"""A clean 200 response caches the image and returns a path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"\xff\xd8\xff fake jpeg"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A timeout on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"image data"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_timeout_error(), fake_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A 429 response on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"image data"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_http_status_error(429), ok_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt raises after all retries are consumed."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
asyncio.run(run())
|
||||
|
||||
# 3 total calls: initial + 2 retries
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch):
|
||||
"""A 404 (non-retryable) is raised immediately without any retry."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_http_status_error(404))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
asyncio.run(run())
|
||||
|
||||
# Only 1 attempt, no sleep
|
||||
assert mock_client.get.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slack mock setup (mirrors existing test_slack.py approach)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_slack_mock():
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return
|
||||
slack_bolt = MagicMock()
|
||||
slack_bolt.async_app.AsyncApp = MagicMock
|
||||
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
|
||||
slack_sdk = MagicMock()
|
||||
slack_sdk.web.async_client.AsyncWebClient = MagicMock
|
||||
for name, mod in [
|
||||
("slack_bolt", slack_bolt),
|
||||
("slack_bolt.async_app", slack_bolt.async_app),
|
||||
("slack_bolt.adapter", slack_bolt.adapter),
|
||||
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
|
||||
("slack_bolt.adapter.socket_mode.async_handler",
|
||||
slack_bolt.adapter.socket_mode.async_handler),
|
||||
("slack_sdk", slack_sdk),
|
||||
("slack_sdk.web", slack_sdk.web),
|
||||
("slack_sdk.web.async_client", slack_sdk.web.async_client),
|
||||
]:
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_ensure_slack_mock()
|
||||
|
||||
import gateway.platforms.slack as _slack_mod # noqa: E402
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
from gateway.config import Platform, PlatformConfig # noqa: E402
|
||||
|
||||
|
||||
def _make_slack_adapter():
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake-token")
|
||||
adapter = SlackAdapter(config)
|
||||
adapter._app = MagicMock()
|
||||
adapter._app.client = AsyncMock()
|
||||
adapter._bot_user_id = "U_BOT"
|
||||
adapter._running = True
|
||||
return adapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SlackAdapter._download_slack_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSlackDownloadSlackFile:
|
||||
"""Tests for SlackAdapter._download_slack_file"""
|
||||
|
||||
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
|
||||
"""Successful download on first try returns a cached file path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"fake image bytes"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
return await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""Timeout on first attempt triggers retry; success on second."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"image bytes"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_timeout_error(), fake_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
return await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_raises_after_max_retries(self, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt eventually raises after 3 total tries."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
asyncio.run(run())
|
||||
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
def test_non_retryable_403_raises_immediately(self, tmp_path, monkeypatch):
|
||||
"""A 403 is not retried; it raises immediately."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_http_status_error(403))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
asyncio.run(run())
|
||||
|
||||
assert mock_client.get.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SlackAdapter._download_slack_file_bytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSlackDownloadSlackFileBytes:
|
||||
"""Tests for SlackAdapter._download_slack_file_bytes"""
|
||||
|
||||
def test_success_returns_bytes(self):
|
||||
"""Successful download returns raw bytes."""
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"raw bytes here"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
return await adapter._download_slack_file_bytes(
|
||||
"https://files.slack.com/file.bin"
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result == b"raw bytes here"
|
||||
|
||||
def test_retries_on_429_then_succeeds(self):
|
||||
"""429 on first attempt is retried; raw bytes returned on second."""
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"final bytes"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_http_status_error(429), ok_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._download_slack_file_bytes(
|
||||
"https://files.slack.com/file.bin"
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result == b"final bytes"
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries(self):
|
||||
"""Persistent timeouts raise after all 3 attempts are exhausted."""
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter._download_slack_file_bytes(
|
||||
"https://files.slack.com/file.bin"
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
asyncio.run(run())
|
||||
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MattermostAdapter._send_url_as_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_mm_adapter():
|
||||
"""Build a minimal MattermostAdapter with mocked internals."""
|
||||
from gateway.platforms.mattermost import MattermostAdapter
|
||||
config = PlatformConfig(
|
||||
enabled=True, token="mm-token-fake",
|
||||
extra={"url": "https://mm.example.com"},
|
||||
)
|
||||
adapter = MattermostAdapter(config)
|
||||
adapter._session = MagicMock()
|
||||
adapter._upload_file = AsyncMock(return_value="file-id-123")
|
||||
adapter._api_post = AsyncMock(return_value={"id": "post-id-abc"})
|
||||
adapter.send = AsyncMock(return_value=MagicMock(success=True))
|
||||
return adapter
|
||||
|
||||
|
||||
def _make_aiohttp_resp(status: int, content: bytes = b"file bytes",
|
||||
content_type: str = "image/jpeg"):
|
||||
"""Build a context-manager mock for an aiohttp response."""
|
||||
resp = MagicMock()
|
||||
resp.status = status
|
||||
resp.content_type = content_type
|
||||
resp.read = AsyncMock(return_value=content)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return resp
|
||||
|
||||
|
||||
class TestMattermostSendUrlAsFile:
|
||||
"""Tests for MattermostAdapter._send_url_as_file"""
|
||||
|
||||
def test_success_on_first_attempt(self):
|
||||
"""200 on first attempt → file uploaded and post created."""
|
||||
adapter = _make_mm_adapter()
|
||||
resp = _make_aiohttp_resp(200)
|
||||
adapter._session.get = MagicMock(return_value=resp)
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", "caption", None
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result.success
|
||||
adapter._upload_file.assert_called_once()
|
||||
adapter._api_post.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self):
|
||||
"""429 on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_429 = _make_aiohttp_resp(429)
|
||||
resp_200 = _make_aiohttp_resp(200)
|
||||
adapter._session.get = MagicMock(side_effect=[resp_429, resp_200])
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", mock_sleep):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result.success
|
||||
assert adapter._session.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_500_then_succeeds(self):
|
||||
"""5xx on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_500 = _make_aiohttp_resp(500)
|
||||
resp_200 = _make_aiohttp_resp(200)
|
||||
adapter._session.get = MagicMock(side_effect=[resp_500, resp_200])
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result.success
|
||||
assert adapter._session.get.call_count == 2
|
||||
|
||||
def test_falls_back_to_text_after_max_retries_on_5xx(self):
|
||||
"""Three consecutive 500s exhaust retries; falls back to send() with URL text."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_500 = _make_aiohttp_resp(500)
|
||||
adapter._session.get = MagicMock(return_value=resp_500)
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", "my caption", None
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_falls_back_on_client_error(self):
|
||||
"""aiohttp.ClientError on every attempt falls back to send() with URL."""
|
||||
import aiohttp
|
||||
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
error_resp = MagicMock()
|
||||
error_resp.__aenter__ = AsyncMock(
|
||||
side_effect=aiohttp.ClientConnectionError("connection refused")
|
||||
)
|
||||
error_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
adapter._session.get = MagicMock(return_value=error_resp)
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_non_retryable_404_falls_back_immediately(self):
|
||||
"""404 is non-retryable (< 500, != 429); send() is called right away."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_404 = _make_aiohttp_resp(404)
|
||||
adapter._session.get = MagicMock(return_value=resp_404)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", mock_sleep):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
# No sleep — fell back on first attempt
|
||||
mock_sleep.assert_not_called()
|
||||
assert adapter._session.get.call_count == 1
|
||||
@@ -76,7 +76,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
Tests for BasePlatformAdapter._send_with_retry and _is_retryable_error.
|
||||
|
||||
Verifies that:
|
||||
- Transient network errors trigger retry with backoff
|
||||
- Permanent errors fall back to plain-text immediately (no retry)
|
||||
- User receives a delivery-failure notice when all retries are exhausted
|
||||
- Successful sends on retry return success
|
||||
- SendResult.retryable flag is respected
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult, _RETRYABLE_ERROR_PATTERNS
|
||||
from gateway.platforms.base import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal concrete adapter for testing (no real network)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
cfg = PlatformConfig()
|
||||
super().__init__(cfg, Platform.TELEGRAM)
|
||||
self._send_results = [] # queue of SendResult to return per call
|
||||
self._send_calls = [] # record of (chat_id, content) sent
|
||||
|
||||
def _next_result(self) -> SendResult:
|
||||
if self._send_results:
|
||||
return self._send_results.pop(0)
|
||||
return SendResult(success=True, message_id="ok")
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None, **kwargs) -> SendResult:
|
||||
self._send_calls.append((chat_id, content))
|
||||
return self._next_result()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
pass
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None) -> None:
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"name": "test", "type": "direct", "chat_id": chat_id}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_retryable_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsRetryableError:
|
||||
def test_none_is_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error(None)
|
||||
|
||||
def test_empty_string_is_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error("")
|
||||
|
||||
@pytest.mark.parametrize("pattern", _RETRYABLE_ERROR_PATTERNS)
|
||||
def test_known_pattern_is_retryable(self, pattern):
|
||||
assert _StubAdapter._is_retryable_error(f"httpx.{pattern.title()}: connection dropped")
|
||||
|
||||
def test_permission_error_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error("Forbidden: bot was blocked by the user")
|
||||
|
||||
def test_bad_request_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error("Bad Request: can't parse entities")
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _StubAdapter._is_retryable_error("CONNECTERROR: host unreachable")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — success on first attempt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetrySuccess:
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_first_attempt(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [SendResult(success=True, message_id="123")]
|
||||
result = await adapter._send_with_retry("chat1", "hello")
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_message_id(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [SendResult(success=True, message_id="abc")]
|
||||
result = await adapter._send_with_retry("chat1", "hi")
|
||||
assert result.message_id == "abc"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — network error with successful retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryNetworkRetry:
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_connect_error_and_succeeds(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="httpx.ConnectError: connection refused"),
|
||||
SendResult(success=True, message_id="ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 2 # initial + 1 retry
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_timeout_and_succeeds(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="ReadTimeout: request timed out"),
|
||||
SendResult(success=False, error="ReadTimeout: request timed out"),
|
||||
SendResult(success=True, message_id="ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=3, base_delay=0)
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_flag_respected(self):
|
||||
"""SendResult.retryable=True should trigger retry even if error string doesn't match."""
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="internal platform error", retryable=True),
|
||||
SendResult(success=True, message_id="ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_to_nonnetwork_transition_falls_back_to_plaintext(self):
|
||||
"""If error switches from network to formatting mid-retry, fall through to plain-text fallback."""
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="httpx.ConnectError: host unreachable"),
|
||||
SendResult(success=False, error="Bad Request: can't parse entities"),
|
||||
SendResult(success=True, message_id="fallback_ok"), # plain-text fallback
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "**bold**", max_retries=2, base_delay=0)
|
||||
assert result.success
|
||||
# 3 calls: initial (network) + 1 retry (non-network, breaks loop) + plain-text fallback
|
||||
assert len(adapter._send_calls) == 3
|
||||
assert "plain text" in adapter._send_calls[-1][1].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — all retries exhausted → user notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryExhausted:
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_user_notice_after_exhaustion(self):
|
||||
adapter = _StubAdapter()
|
||||
network_err = SendResult(success=False, error="httpx.ConnectError: host unreachable")
|
||||
# initial + 2 retries + notice attempt
|
||||
adapter._send_results = [network_err, network_err, network_err, SendResult(success=True)]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
# Result is the last failed one (before notice)
|
||||
assert not result.success
|
||||
# 4 total calls: 1 initial + 2 retries + 1 notice
|
||||
assert len(adapter._send_calls) == 4
|
||||
# The notice content should mention delivery failure
|
||||
notice_content = adapter._send_calls[-1][1]
|
||||
assert "delivery failed" in notice_content.lower() or "Message delivery failed" in notice_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notice_send_exception_doesnt_propagate(self):
|
||||
"""If the notice itself throws, _send_with_retry should not raise."""
|
||||
adapter = _StubAdapter()
|
||||
network_err = SendResult(success=False, error="ConnectError")
|
||||
adapter._send_results = [network_err, network_err, network_err]
|
||||
|
||||
original_send = adapter.send
|
||||
call_count = [0]
|
||||
|
||||
async def send_with_notice_failure(chat_id, content, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] > 3:
|
||||
raise RuntimeError("notice send also failed")
|
||||
return network_err
|
||||
|
||||
adapter.send = send_with_notice_failure
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
assert not result.success # still failed, but no exception raised
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — non-network failure → plain-text fallback (no retry)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryFallback:
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_network_error_falls_back_immediately(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="Bad Request: can't parse entities"),
|
||||
SendResult(success=True, message_id="fallback_ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
result = await adapter._send_with_retry("chat1", "**bold**", max_retries=2, base_delay=0)
|
||||
# No sleep — no retry loop for non-network errors
|
||||
mock_sleep.assert_not_called()
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 2
|
||||
# Fallback content should be plain-text notice
|
||||
assert "plain text" in adapter._send_calls[1][1].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_failure_logged_but_not_raised(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="Forbidden: bot blocked"),
|
||||
SendResult(success=False, error="Forbidden: bot blocked"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2)
|
||||
assert not result.success
|
||||
assert len(adapter._send_calls) == 2 # original + fallback only
|
||||
@@ -846,7 +846,7 @@ class TestLastPromptTokens:
|
||||
|
||||
store.update_session("k1", model="openai/gpt-5.4")
|
||||
|
||||
store._db.update_token_counts.assert_called_once_with(
|
||||
store._db.set_token_counts.assert_called_once_with(
|
||||
"s1",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
@@ -858,4 +858,48 @@ class TestLastPromptTokens:
|
||||
billing_provider=None,
|
||||
billing_base_url=None,
|
||||
model="openai/gpt-5.4",
|
||||
absolute=True,
|
||||
)
|
||||
|
||||
|
||||
class TestRewriteTranscriptPreservesReasoning:
|
||||
"""rewrite_transcript must not drop reasoning fields from SQLite."""
|
||||
|
||||
def test_reasoning_survives_rewrite(self, tmp_path):
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "test.db")
|
||||
session_id = "reasoning-test"
|
||||
db.create_session(session_id=session_id, source="cli")
|
||||
|
||||
# Insert a message WITH all three reasoning fields
|
||||
db.append_message(
|
||||
session_id=session_id,
|
||||
role="assistant",
|
||||
content="The answer is 42.",
|
||||
reasoning="I need to think step by step.",
|
||||
reasoning_details=[{"type": "summary", "text": "step by step"}],
|
||||
codex_reasoning_items=[{"id": "r1", "type": "reasoning"}],
|
||||
)
|
||||
|
||||
# Verify all three were stored
|
||||
before = db.get_messages_as_conversation(session_id)
|
||||
assert before[0].get("reasoning") == "I need to think step by step."
|
||||
assert before[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
|
||||
assert before[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
|
||||
|
||||
# Now simulate /retry: build the SessionStore and call rewrite_transcript
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._db = db
|
||||
store._loaded = True
|
||||
|
||||
# rewrite_transcript receives the messages that load_transcript returned
|
||||
store.rewrite_transcript(session_id, before)
|
||||
|
||||
# Load again — all three reasoning fields must survive
|
||||
after = db.get_messages_as_conversation(session_id)
|
||||
assert after[0].get("reasoning") == "I need to think step by step."
|
||||
assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
|
||||
assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Tests for GatewayRunner._format_session_info — session config surfacing."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from pathlib import Path
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def runner():
|
||||
"""Create a bare GatewayRunner without __init__."""
|
||||
return GatewayRunner.__new__(GatewayRunner)
|
||||
|
||||
|
||||
def _patch_info(tmp_path, config_yaml, model, runtime):
|
||||
"""Return a context-manager stack that patches _format_session_info deps."""
|
||||
cfg_path = tmp_path / "config.yaml"
|
||||
if config_yaml is not None:
|
||||
cfg_path.write_text(config_yaml)
|
||||
return (
|
||||
patch("gateway.run._hermes_home", tmp_path),
|
||||
patch("gateway.run._resolve_gateway_model", return_value=model),
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value=runtime),
|
||||
)
|
||||
|
||||
|
||||
class TestFormatSessionInfo:
|
||||
|
||||
def test_includes_model_name(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: anthropic/claude-opus-4.6\n provider: openrouter\n",
|
||||
"anthropic/claude-opus-4.6",
|
||||
{"provider": "openrouter", "base_url": "https://openrouter.ai/api/v1", "api_key": "k"})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "claude-opus-4.6" in info
|
||||
|
||||
def test_includes_provider(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n provider: openrouter\n",
|
||||
"test-model",
|
||||
{"provider": "openrouter", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "openrouter" in info
|
||||
|
||||
def test_config_context_length(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n context_length: 32768\n",
|
||||
"test-model",
|
||||
{"provider": "custom", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "32K" in info
|
||||
assert "config" in info
|
||||
|
||||
def test_default_fallback_hint(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: unknown-model-xyz\n",
|
||||
"unknown-model-xyz",
|
||||
{"provider": "", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "128K" in info
|
||||
assert "model.context_length" in info
|
||||
|
||||
def test_local_endpoint_shown(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(
|
||||
tmp_path,
|
||||
"model:\n default: qwen3:8b\n provider: custom\n base_url: http://localhost:11434/v1\n context_length: 8192\n",
|
||||
"qwen3:8b",
|
||||
{"provider": "custom", "base_url": "http://localhost:11434/v1", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "localhost:11434" in info
|
||||
assert "8K" in info
|
||||
|
||||
def test_cloud_endpoint_hidden(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n provider: openrouter\n",
|
||||
"test-model",
|
||||
{"provider": "openrouter", "base_url": "https://openrouter.ai/api/v1", "api_key": "k"})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "Endpoint" not in info
|
||||
|
||||
def test_million_context_format(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n context_length: 1000000\n",
|
||||
"test-model",
|
||||
{"provider": "", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "1.0M" in info
|
||||
|
||||
def test_missing_config(self, runner, tmp_path):
|
||||
"""No config.yaml should not crash."""
|
||||
p1, p2, p3 = _patch_info(tmp_path, None, # don't create config
|
||||
"anthropic/claude-sonnet-4.6",
|
||||
{"provider": "openrouter", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "Model" in info
|
||||
assert "Context" in info
|
||||
|
||||
def test_runtime_resolution_failure_doesnt_crash(self, runner, tmp_path):
|
||||
"""If runtime resolution raises, should still produce output."""
|
||||
cfg_path = tmp_path / "config.yaml"
|
||||
cfg_path.write_text("model:\n default: test-model\n context_length: 4096\n")
|
||||
with patch("gateway.run._hermes_home", tmp_path), \
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"), \
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", side_effect=RuntimeError("no creds")):
|
||||
info = runner._format_session_info()
|
||||
assert "4K" in info
|
||||
assert "config" in info
|
||||
@@ -20,7 +20,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
@@ -29,6 +29,14 @@ _ensure_telegram_mock()
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_auto_discovery(monkeypatch):
|
||||
"""Disable DoH auto-discovery so connect() uses the plain builder chain."""
|
||||
async def _noop():
|
||||
return []
|
||||
monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_rejects_same_host_token_lock(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="secret-token"))
|
||||
|
||||
@@ -45,7 +45,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ def _ensure_telegram_mock():
|
||||
mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
mod.constants.ChatType.CHANNEL = "channel"
|
||||
mod.constants.ChatType.PRIVATE = "private"
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,626 @@
|
||||
"""Tests for gateway.platforms.telegram_network – fallback transport layer.
|
||||
|
||||
Background
|
||||
----------
|
||||
api.telegram.org resolves to an IP (e.g. 149.154.166.110) that is unreachable
|
||||
from some networks. The workaround: route TCP through a different IP in the
|
||||
same Telegram-owned 149.154.160.0/20 block (e.g. 149.154.167.220) while
|
||||
keeping TLS SNI and the Host header as api.telegram.org so Telegram's edge
|
||||
servers still accept the request. This is the programmatic equivalent of:
|
||||
|
||||
curl --resolve api.telegram.org:443:149.154.167.220 https://api.telegram.org/bot<token>/getMe
|
||||
|
||||
The TelegramFallbackTransport implements this: try the primary (DNS-resolved)
|
||||
path first, and on ConnectTimeout / ConnectError fall through to configured
|
||||
fallback IPs in order, then "stick" to whichever IP works.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from gateway.platforms import telegram_network as tnet
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class FakeTransport(httpx.AsyncBaseTransport):
|
||||
"""Records calls and raises / returns based on a host→action mapping."""
|
||||
|
||||
def __init__(self, calls, behavior):
|
||||
self.calls = calls
|
||||
self.behavior = behavior
|
||||
self.closed = False
|
||||
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
self.calls.append(
|
||||
{
|
||||
"url_host": request.url.host,
|
||||
"host_header": request.headers.get("host"),
|
||||
"sni_hostname": request.extensions.get("sni_hostname"),
|
||||
"path": request.url.path,
|
||||
}
|
||||
)
|
||||
action = self.behavior.get(request.url.host, "ok")
|
||||
if action == "timeout":
|
||||
raise httpx.ConnectTimeout("timed out")
|
||||
if action == "connect_error":
|
||||
raise httpx.ConnectError("connect error")
|
||||
if isinstance(action, Exception):
|
||||
raise action
|
||||
return httpx.Response(200, request=request, text="ok")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
def _fake_transport_factory(calls, behavior):
|
||||
"""Returns a factory that creates FakeTransport instances."""
|
||||
instances = []
|
||||
|
||||
def factory(**kwargs):
|
||||
t = FakeTransport(calls, behavior)
|
||||
instances.append(t)
|
||||
return t
|
||||
|
||||
factory.instances = instances
|
||||
return factory
|
||||
|
||||
|
||||
def _telegram_request(path="/botTOKEN/getMe"):
|
||||
return httpx.Request("GET", f"https://api.telegram.org{path}")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# IP parsing & validation
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestParseFallbackIpEnv:
|
||||
def test_filters_invalid_and_ipv6(self, caplog):
|
||||
ips = tnet.parse_fallback_ip_env("149.154.167.220, bad, 2001:67c:4e8:f004::9,149.154.167.220")
|
||||
assert ips == ["149.154.167.220", "149.154.167.220"]
|
||||
assert "Ignoring invalid Telegram fallback IP" in caplog.text
|
||||
assert "Ignoring non-IPv4 Telegram fallback IP" in caplog.text
|
||||
|
||||
def test_none_returns_empty(self):
|
||||
assert tnet.parse_fallback_ip_env(None) == []
|
||||
|
||||
def test_empty_string_returns_empty(self):
|
||||
assert tnet.parse_fallback_ip_env("") == []
|
||||
|
||||
def test_whitespace_only_returns_empty(self):
|
||||
assert tnet.parse_fallback_ip_env(" , , ") == []
|
||||
|
||||
def test_single_valid_ip(self):
|
||||
assert tnet.parse_fallback_ip_env("149.154.167.220") == ["149.154.167.220"]
|
||||
|
||||
def test_multiple_valid_ips(self):
|
||||
ips = tnet.parse_fallback_ip_env("149.154.167.220, 149.154.167.221")
|
||||
assert ips == ["149.154.167.220", "149.154.167.221"]
|
||||
|
||||
def test_rejects_leading_zeros(self, caplog):
|
||||
"""Leading zeros are ambiguous (octal?) so ipaddress rejects them."""
|
||||
ips = tnet.parse_fallback_ip_env("149.154.167.010")
|
||||
assert ips == []
|
||||
assert "Ignoring invalid" in caplog.text
|
||||
|
||||
|
||||
class TestNormalizeFallbackIps:
|
||||
def test_deduplication_happens_at_transport_level(self):
|
||||
"""_normalize does not dedup; TelegramFallbackTransport.__init__ does."""
|
||||
raw = ["149.154.167.220", "149.154.167.220"]
|
||||
assert tnet._normalize_fallback_ips(raw) == ["149.154.167.220", "149.154.167.220"]
|
||||
|
||||
def test_empty_strings_skipped(self):
|
||||
assert tnet._normalize_fallback_ips(["", " ", "149.154.167.220"]) == ["149.154.167.220"]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Request rewriting
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestRewriteRequestForIp:
|
||||
def test_preserves_host_and_sni(self):
|
||||
request = _telegram_request()
|
||||
rewritten = tnet._rewrite_request_for_ip(request, "149.154.167.220")
|
||||
|
||||
assert rewritten.url.host == "149.154.167.220"
|
||||
assert rewritten.headers["host"] == "api.telegram.org"
|
||||
assert rewritten.extensions["sni_hostname"] == "api.telegram.org"
|
||||
assert rewritten.url.path == "/botTOKEN/getMe"
|
||||
|
||||
def test_preserves_method_and_path(self):
|
||||
request = httpx.Request("POST", "https://api.telegram.org/botTOKEN/sendMessage")
|
||||
rewritten = tnet._rewrite_request_for_ip(request, "149.154.167.220")
|
||||
|
||||
assert rewritten.method == "POST"
|
||||
assert rewritten.url.path == "/botTOKEN/sendMessage"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Fallback transport – core behavior
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestFallbackTransport:
|
||||
"""Primary path fails → try fallback IPs → stick to whichever works."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_connect_timeout_and_becomes_sticky(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": "timeout", "149.154.167.220": "ok"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert transport._sticky_ip == "149.154.167.220"
|
||||
# First attempt was primary (api.telegram.org), second was fallback
|
||||
assert calls[0]["url_host"] == "api.telegram.org"
|
||||
assert calls[1]["url_host"] == "149.154.167.220"
|
||||
assert calls[1]["host_header"] == "api.telegram.org"
|
||||
assert calls[1]["sni_hostname"] == "api.telegram.org"
|
||||
|
||||
# Second request goes straight to sticky IP
|
||||
calls.clear()
|
||||
resp2 = await transport.handle_async_request(_telegram_request())
|
||||
assert resp2.status_code == 200
|
||||
assert calls[0]["url_host"] == "149.154.167.220"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_connect_error(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": "connect_error", "149.154.167.220": "ok"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert transport._sticky_ip == "149.154.167.220"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_fallback_on_non_connect_error(self, monkeypatch):
|
||||
"""Errors like ReadTimeout are not connection issues — don't retry."""
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": httpx.ReadTimeout("read timeout"), "149.154.167.220": "ok"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
|
||||
with pytest.raises(httpx.ReadTimeout):
|
||||
await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert [c["url_host"] for c in calls] == ["api.telegram.org"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_ips_fail_raises_last_error(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": "timeout", "149.154.167.220": "timeout"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
|
||||
with pytest.raises(httpx.ConnectTimeout):
|
||||
await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert [c["url_host"] for c in calls] == ["api.telegram.org", "149.154.167.220"]
|
||||
assert transport._sticky_ip is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_fallback_ips_tried_in_order(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {
|
||||
"api.telegram.org": "timeout",
|
||||
"149.154.167.220": "timeout",
|
||||
"149.154.167.221": "ok",
|
||||
}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.221"])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert transport._sticky_ip == "149.154.167.221"
|
||||
assert [c["url_host"] for c in calls] == [
|
||||
"api.telegram.org",
|
||||
"149.154.167.220",
|
||||
"149.154.167.221",
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sticky_ip_tried_first_but_falls_through_if_stale(self, monkeypatch):
|
||||
"""If the sticky IP stops working, the transport retries others."""
|
||||
calls = []
|
||||
behavior = {
|
||||
"api.telegram.org": "timeout",
|
||||
"149.154.167.220": "ok",
|
||||
"149.154.167.221": "ok",
|
||||
}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.221"])
|
||||
|
||||
# First request: primary fails → .220 works → becomes sticky
|
||||
await transport.handle_async_request(_telegram_request())
|
||||
assert transport._sticky_ip == "149.154.167.220"
|
||||
|
||||
# Now .220 goes bad too
|
||||
calls.clear()
|
||||
behavior["149.154.167.220"] = "timeout"
|
||||
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
assert resp.status_code == 200
|
||||
# Tried sticky (.220) first, then fell through to .221
|
||||
assert [c["url_host"] for c in calls] == ["149.154.167.220", "149.154.167.221"]
|
||||
assert transport._sticky_ip == "149.154.167.221"
|
||||
|
||||
|
||||
class TestFallbackTransportPassthrough:
|
||||
"""Requests that don't need fallback behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_telegram_host_bypasses_fallback(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
request = httpx.Request("GET", "https://example.com/path")
|
||||
resp = await transport.handle_async_request(request)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert calls[0]["url_host"] == "example.com"
|
||||
assert transport._sticky_ip is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_fallback_list_uses_primary_only(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport([])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert calls[0]["url_host"] == "api.telegram.org"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_primary_succeeds_no_fallback_needed(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": "ok"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert transport._sticky_ip is None
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
class TestFallbackTransportInit:
|
||||
def test_deduplicates_fallback_ips(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
tnet.httpx, "AsyncHTTPTransport", lambda **kw: FakeTransport([], {})
|
||||
)
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.220"])
|
||||
assert transport._fallback_ips == ["149.154.167.220"]
|
||||
|
||||
def test_filters_invalid_ips_at_init(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
tnet.httpx, "AsyncHTTPTransport", lambda **kw: FakeTransport([], {})
|
||||
)
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "not-an-ip"])
|
||||
assert transport._fallback_ips == ["149.154.167.220"]
|
||||
|
||||
|
||||
class TestFallbackTransportClose:
|
||||
@pytest.mark.asyncio
|
||||
async def test_aclose_closes_all_transports(self, monkeypatch):
|
||||
factory = _fake_transport_factory([], {})
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", factory)
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.221"])
|
||||
await transport.aclose()
|
||||
|
||||
# 1 primary + 2 fallback transports
|
||||
assert len(factory.instances) == 3
|
||||
assert all(t.closed for t in factory.instances)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Config layer – TELEGRAM_FALLBACK_IPS env → config.extra
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestConfigFallbackIps:
|
||||
def test_env_var_populates_config_extra(self, monkeypatch):
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides
|
||||
|
||||
monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", "149.154.167.220,149.154.167.221")
|
||||
config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")})
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert config.platforms[Platform.TELEGRAM].extra["fallback_ips"] == [
|
||||
"149.154.167.220", "149.154.167.221",
|
||||
]
|
||||
|
||||
def test_env_var_creates_platform_if_missing(self, monkeypatch):
|
||||
from gateway.config import GatewayConfig, Platform, _apply_env_overrides
|
||||
|
||||
monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", "149.154.167.220")
|
||||
config = GatewayConfig(platforms={})
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.TELEGRAM in config.platforms
|
||||
assert config.platforms[Platform.TELEGRAM].extra["fallback_ips"] == ["149.154.167.220"]
|
||||
|
||||
def test_env_var_strips_whitespace(self, monkeypatch):
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides
|
||||
|
||||
monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", " 149.154.167.220 , 149.154.167.221 ")
|
||||
config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")})
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert config.platforms[Platform.TELEGRAM].extra["fallback_ips"] == [
|
||||
"149.154.167.220", "149.154.167.221",
|
||||
]
|
||||
|
||||
def test_empty_env_var_does_not_populate(self, monkeypatch):
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides
|
||||
|
||||
monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", "")
|
||||
config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")})
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert "fallback_ips" not in config.platforms[Platform.TELEGRAM].extra
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Adapter layer – _fallback_ips() reads config correctly
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestAdapterFallbackIps:
|
||||
def _make_adapter(self, extra=None):
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Ensure telegram mock is in place
|
||||
if "telegram" not in sys.modules or not hasattr(sys.modules["telegram"], "__file__"):
|
||||
mod = MagicMock()
|
||||
mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
mod.constants.ChatType.GROUP = "group"
|
||||
mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
mod.constants.ChatType.CHANNEL = "channel"
|
||||
mod.constants.ChatType.PRIVATE = "private"
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
config = PlatformConfig(enabled=True, token="test-token")
|
||||
if extra:
|
||||
config.extra.update(extra)
|
||||
return TelegramAdapter(config)
|
||||
|
||||
def test_list_in_extra(self):
|
||||
adapter = self._make_adapter(extra={"fallback_ips": ["149.154.167.220"]})
|
||||
assert adapter._fallback_ips() == ["149.154.167.220"]
|
||||
|
||||
def test_csv_string_in_extra(self):
|
||||
adapter = self._make_adapter(extra={"fallback_ips": "149.154.167.220,149.154.167.221"})
|
||||
assert adapter._fallback_ips() == ["149.154.167.220", "149.154.167.221"]
|
||||
|
||||
def test_empty_extra(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter._fallback_ips() == []
|
||||
|
||||
def test_no_extra_attr(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter.config.extra = None
|
||||
assert adapter._fallback_ips() == []
|
||||
|
||||
def test_invalid_ips_filtered(self):
|
||||
adapter = self._make_adapter(extra={"fallback_ips": ["149.154.167.220", "not-valid"]})
|
||||
assert adapter._fallback_ips() == ["149.154.167.220"]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# DoH auto-discovery
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _doh_answer(*ips: str) -> dict:
|
||||
"""Build a minimal DoH JSON response with A records."""
|
||||
return {"Answer": [{"type": 1, "data": ip} for ip in ips]}
|
||||
|
||||
|
||||
class FakeDoHClient:
|
||||
"""Mock httpx.AsyncClient for DoH queries."""
|
||||
|
||||
def __init__(self, responses: dict):
|
||||
# responses: URL prefix → (status, json_body) | Exception
|
||||
self._responses = responses
|
||||
self.requests_made: list[dict] = []
|
||||
|
||||
@staticmethod
|
||||
def _make_response(status, body, url):
|
||||
"""Build an httpx.Response with a request attached (needed for raise_for_status)."""
|
||||
request = httpx.Request("GET", url)
|
||||
return httpx.Response(status, json=body, request=request)
|
||||
|
||||
async def get(self, url, *, params=None, headers=None, **kwargs):
|
||||
self.requests_made.append({"url": url, "params": params, "headers": headers})
|
||||
for prefix, action in self._responses.items():
|
||||
if url.startswith(prefix):
|
||||
if isinstance(action, Exception):
|
||||
raise action
|
||||
status, body = action
|
||||
return self._make_response(status, body, url)
|
||||
return self._make_response(200, {}, url)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
class TestDiscoverFallbackIps:
|
||||
"""Tests for discover_fallback_ips() — DoH-based auto-discovery."""
|
||||
|
||||
def _patch_doh(self, monkeypatch, responses, system_dns_ips=None):
|
||||
"""Wire up fake DoH client and system DNS."""
|
||||
client = FakeDoHClient(responses)
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncClient", lambda **kw: client)
|
||||
|
||||
if system_dns_ips is not None:
|
||||
addrs = [(None, None, None, None, (ip, 443)) for ip in system_dns_ips]
|
||||
monkeypatch.setattr(tnet.socket, "getaddrinfo", lambda *a, **kw: addrs)
|
||||
else:
|
||||
def _fail(*a, **kw):
|
||||
raise OSError("dns failed")
|
||||
monkeypatch.setattr(tnet.socket, "getaddrinfo", _fail)
|
||||
return client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_and_cloudflare_ips_collected(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.167.221")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert "149.154.167.220" in ips
|
||||
assert "149.154.167.221" in ips
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_dns_ip_excluded(self, monkeypatch):
|
||||
"""The IP from system DNS is the one that doesn't work — exclude it."""
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.166.110", "149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.166.110")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doh_results_deduplicated(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.167.220")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doh_timeout_falls_back_to_seed(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": httpx.TimeoutException("timeout"),
|
||||
"https://cloudflare-dns.com": httpx.TimeoutException("timeout"),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == tnet._SEED_FALLBACK_IPS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doh_connect_error_falls_back_to_seed(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": httpx.ConnectError("refused"),
|
||||
"https://cloudflare-dns.com": httpx.ConnectError("refused"),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == tnet._SEED_FALLBACK_IPS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doh_malformed_json_falls_back_to_seed(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, {"Status": 0}), # no Answer key
|
||||
"https://cloudflare-dns.com": (200, {"garbage": True}),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == tnet._SEED_FALLBACK_IPS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_one_provider_fails_other_succeeds(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": httpx.TimeoutException("timeout"),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.167.220")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_dns_failure_keeps_all_doh_ips(self, monkeypatch):
|
||||
"""If system DNS fails, nothing gets excluded — all DoH IPs kept."""
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.166.110", "149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer()),
|
||||
}, system_dns_ips=None) # triggers OSError
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert "149.154.166.110" in ips
|
||||
assert "149.154.167.220" in ips
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_doh_ips_same_as_system_dns_uses_seed(self, monkeypatch):
|
||||
"""DoH returns only the same blocked IP — seed list is the fallback."""
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.166.110")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.166.110")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == tnet._SEED_FALLBACK_IPS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cloudflare_gets_accept_header(self, monkeypatch):
|
||||
client = self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.167.221")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
await tnet.discover_fallback_ips()
|
||||
|
||||
cf_reqs = [r for r in client.requests_made if "cloudflare" in r["url"]]
|
||||
assert cf_reqs
|
||||
assert cf_reqs[0]["headers"]["Accept"] == "application/dns-json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_a_records_ignored(self, monkeypatch):
|
||||
"""AAAA records (type 28) and CNAME (type 5) should be skipped."""
|
||||
answer = {
|
||||
"Answer": [
|
||||
{"type": 5, "data": "telegram.org"}, # CNAME
|
||||
{"type": 28, "data": "2001:67c:4e8:f004::9"}, # AAAA
|
||||
{"type": 1, "data": "149.154.167.220"}, # A ✓
|
||||
]
|
||||
}
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, answer),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer()),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_ip_in_doh_response_skipped(self, monkeypatch):
|
||||
answer = {"Answer": [
|
||||
{"type": 1, "data": "not-an-ip"},
|
||||
{"type": 1, "data": "149.154.167.220"},
|
||||
]}
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, answer),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer()),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
@@ -27,7 +27,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
@@ -36,6 +36,14 @@ _ensure_telegram_mock()
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_auto_discovery(monkeypatch):
|
||||
"""Disable DoH auto-discovery so connect() uses the plain builder chain."""
|
||||
async def _noop():
|
||||
return []
|
||||
monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop)
|
||||
|
||||
|
||||
def _make_adapter() -> TelegramAdapter:
|
||||
return TelegramAdapter(PlatformConfig(enabled=True, token="test-token"))
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def _ensure_telegram_mock():
|
||||
mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
mod.constants.ChatType.CHANNEL = "channel"
|
||||
mod.constants.ChatType.PRIVATE = "private"
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
"""Tests for Telegram send() thread_id fallback.
|
||||
|
||||
When message_thread_id points to a non-existent thread, Telegram returns
|
||||
BadRequest('Message thread not found'). Since BadRequest is a subclass of
|
||||
NetworkError in python-telegram-bot, the old retry loop treated this as a
|
||||
transient error and retried 3 times before silently failing — killing all
|
||||
tool progress messages, streaming responses, and typing indicators.
|
||||
|
||||
The fix detects "thread not found" BadRequest errors and retries the send
|
||||
WITHOUT message_thread_id so the message still reaches the chat.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
from gateway.platforms.base import SendResult
|
||||
|
||||
|
||||
# ── Fake telegram.error hierarchy ──────────────────────────────────────
|
||||
# Mirrors the real python-telegram-bot hierarchy:
|
||||
# BadRequest → NetworkError → TelegramError → Exception
|
||||
|
||||
|
||||
class FakeNetworkError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FakeBadRequest(FakeNetworkError):
|
||||
pass
|
||||
|
||||
|
||||
# Build a fake telegram module tree so the adapter's internal imports work
|
||||
_fake_telegram = types.ModuleType("telegram")
|
||||
_fake_telegram_error = types.ModuleType("telegram.error")
|
||||
_fake_telegram_error.NetworkError = FakeNetworkError
|
||||
_fake_telegram_error.BadRequest = FakeBadRequest
|
||||
_fake_telegram.error = _fake_telegram_error
|
||||
_fake_telegram_constants = types.ModuleType("telegram.constants")
|
||||
_fake_telegram_constants.ParseMode = SimpleNamespace(MARKDOWN_V2="MarkdownV2")
|
||||
_fake_telegram.constants = _fake_telegram_constants
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _inject_fake_telegram(monkeypatch):
|
||||
"""Inject fake telegram modules so the adapter can import from them."""
|
||||
monkeypatch.setitem(sys.modules, "telegram", _fake_telegram)
|
||||
monkeypatch.setitem(sys.modules, "telegram.error", _fake_telegram_error)
|
||||
monkeypatch.setitem(sys.modules, "telegram.constants", _fake_telegram_constants)
|
||||
|
||||
|
||||
def _make_adapter():
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
adapter = object.__new__(TelegramAdapter)
|
||||
adapter._config = config
|
||||
adapter._platform = Platform.TELEGRAM
|
||||
adapter._connected = True
|
||||
adapter._dm_topics = {}
|
||||
adapter._dm_topics_config = []
|
||||
adapter._reply_to_mode = "first"
|
||||
adapter._fallback_ips = []
|
||||
adapter._polling_conflict_count = 0
|
||||
adapter._polling_network_error_count = 0
|
||||
adapter._polling_error_callback_ref = None
|
||||
adapter.platform = Platform.TELEGRAM
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_without_thread_on_thread_not_found():
|
||||
"""When message_thread_id causes 'thread not found', retry without it."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
call_log.append(dict(kwargs))
|
||||
tid = kwargs.get("message_thread_id")
|
||||
if tid is not None:
|
||||
raise FakeBadRequest("Message thread not found")
|
||||
return SimpleNamespace(message_id=42)
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content="test message",
|
||||
metadata={"thread_id": "99999"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "42"
|
||||
# First call has thread_id, second call retries without
|
||||
assert len(call_log) == 2
|
||||
assert call_log[0]["message_thread_id"] == 99999
|
||||
assert call_log[1]["message_thread_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_raises_on_other_bad_request():
|
||||
"""Non-thread BadRequest errors should NOT be retried — they fail immediately."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
raise FakeBadRequest("Chat not found")
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content="test message",
|
||||
metadata={"thread_id": "99999"},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "Chat not found" in result.error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_without_thread_id_unaffected():
|
||||
"""Normal sends without thread_id should work as before."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
call_log.append(dict(kwargs))
|
||||
return SimpleNamespace(message_id=100)
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content="test message",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert len(call_log) == 1
|
||||
assert call_log[0]["message_thread_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_network_errors_normally():
|
||||
"""Real transient network errors (not BadRequest) should still be retried."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
attempt = [0]
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
attempt[0] += 1
|
||||
if attempt[0] < 3:
|
||||
raise FakeNetworkError("Connection reset")
|
||||
return SimpleNamespace(message_id=200)
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content="test message",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert attempt[0] == 3 # Two retries then success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_fallback_only_fires_once():
|
||||
"""After clearing thread_id, subsequent chunks should also use None."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
call_log.append(dict(kwargs))
|
||||
tid = kwargs.get("message_thread_id")
|
||||
if tid is not None:
|
||||
raise FakeBadRequest("Message thread not found")
|
||||
return SimpleNamespace(message_id=42)
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
# Send a long message that gets split into chunks
|
||||
long_msg = "A" * 5000 # Exceeds Telegram's 4096 limit
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content=long_msg,
|
||||
metadata={"thread_id": "99999"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
# First chunk: attempt with thread → fail → retry without → succeed
|
||||
# Second chunk: should use thread_id=None directly (effective_thread_id
|
||||
# was cleared per-chunk but the metadata doesn't change between chunks)
|
||||
# The key point: the message was delivered despite the invalid thread
|
||||
@@ -94,7 +94,7 @@ class TestOfferOpenclawMigration:
|
||||
fake_mod.Migrator.assert_called_once()
|
||||
call_kwargs = fake_mod.Migrator.call_args[1]
|
||||
assert call_kwargs["execute"] is True
|
||||
assert call_kwargs["overwrite"] is False
|
||||
assert call_kwargs["overwrite"] is True
|
||||
assert call_kwargs["migrate_secrets"] is True
|
||||
assert call_kwargs["preset_name"] == "full"
|
||||
fake_migrator.migrate.assert_called_once()
|
||||
@@ -285,3 +285,182 @@ class TestSetupWizardOpenclawIntegration:
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
mock_migration.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_section_config_summary / _skip_configured_section — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSectionConfigSummary:
|
||||
"""Test the _get_section_config_summary helper."""
|
||||
|
||||
def test_model_returns_none_without_api_key(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary({}, "model")
|
||||
assert result is None
|
||||
|
||||
def test_model_returns_summary_with_api_key(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENROUTER_API_KEY" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"model": "openai/gpt-4"}, "model"
|
||||
)
|
||||
assert result == "openai/gpt-4"
|
||||
|
||||
def test_model_returns_dict_default_key(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENAI_API_KEY" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"model": {"default": "claude-opus-4", "provider": "anthropic"}},
|
||||
"model",
|
||||
)
|
||||
assert result == "claude-opus-4"
|
||||
|
||||
def test_terminal_always_returns(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"terminal": {"backend": "docker"}}, "terminal"
|
||||
)
|
||||
assert result == "backend: docker"
|
||||
|
||||
def test_agent_always_returns(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"agent": {"max_turns": 120}}, "agent"
|
||||
)
|
||||
assert result == "max turns: 120"
|
||||
|
||||
def test_gateway_returns_none_without_tokens(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
assert result is None
|
||||
|
||||
def test_gateway_lists_platforms(self):
|
||||
def env_side(key):
|
||||
if key == "TELEGRAM_BOT_TOKEN":
|
||||
return "tok123"
|
||||
if key == "DISCORD_BOT_TOKEN":
|
||||
return "disc456"
|
||||
return ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
assert "Telegram" in result
|
||||
assert "Discord" in result
|
||||
|
||||
def test_tools_returns_none_without_keys(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary({}, "tools")
|
||||
assert result is None
|
||||
|
||||
def test_tools_lists_configured(self):
|
||||
def env_side(key):
|
||||
return "key" if key == "BROWSERBASE_API_KEY" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "tools")
|
||||
assert "Browser" in result
|
||||
|
||||
|
||||
class TestSkipConfiguredSection:
|
||||
"""Test the _skip_configured_section helper."""
|
||||
|
||||
def test_returns_false_when_not_configured(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._skip_configured_section({}, "model", "Model")
|
||||
assert result is False
|
||||
|
||||
def test_returns_true_when_user_skips(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENROUTER_API_KEY" else ""
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "get_env_value", side_effect=env_side),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=False),
|
||||
):
|
||||
result = setup_mod._skip_configured_section(
|
||||
{"model": "openai/gpt-4"}, "model", "Model"
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_returns_false_when_user_wants_reconfig(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENROUTER_API_KEY" else ""
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "get_env_value", side_effect=env_side),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=True),
|
||||
):
|
||||
result = setup_mod._skip_configured_section(
|
||||
{"model": "openai/gpt-4"}, "model", "Model"
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestSetupWizardSkipsConfiguredSections:
|
||||
"""After migration, already-configured sections should offer skip."""
|
||||
|
||||
def test_sections_skipped_when_migration_imported_settings(self, tmp_path):
|
||||
"""When migration ran and API key exists, model section should be skippable.
|
||||
|
||||
Simulates the real flow: get_env_value returns "" during the is_existing
|
||||
check (before migration), then returns a key after migration imported it.
|
||||
"""
|
||||
args = _first_time_args()
|
||||
|
||||
# Track whether migration has "run" — after it does, API key is available
|
||||
migration_done = {"value": False}
|
||||
|
||||
def env_side(key):
|
||||
if migration_done["value"] and key == "OPENROUTER_API_KEY":
|
||||
return "sk-xxx"
|
||||
return ""
|
||||
|
||||
def fake_migration(hermes_home):
|
||||
migration_done["value"] = True
|
||||
return True
|
||||
|
||||
reloaded_config = {"model": "openai/gpt-4"}
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "ensure_hermes_home"),
|
||||
patch.object(
|
||||
setup_mod, "load_config",
|
||||
side_effect=[{}, reloaded_config],
|
||||
),
|
||||
patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "get_env_value", side_effect=env_side),
|
||||
patch.object(setup_mod, "is_interactive_stdin", return_value=True),
|
||||
patch("hermes_cli.auth.get_active_provider", return_value=None),
|
||||
patch("builtins.input", return_value=""),
|
||||
# Migration succeeds and flips the env_side flag
|
||||
patch.object(
|
||||
setup_mod, "_offer_openclaw_migration",
|
||||
side_effect=fake_migration,
|
||||
),
|
||||
# User says No to all reconfig prompts
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=False),
|
||||
patch.object(setup_mod, "setup_model_provider") as mock_model,
|
||||
patch.object(setup_mod, "setup_terminal_backend") as mock_terminal,
|
||||
patch.object(setup_mod, "setup_agent_settings") as mock_agent,
|
||||
patch.object(setup_mod, "setup_gateway") as mock_gateway,
|
||||
patch.object(setup_mod, "setup_tools") as mock_tools,
|
||||
patch.object(setup_mod, "save_config"),
|
||||
patch.object(setup_mod, "_print_setup_summary"),
|
||||
):
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
# Model has API key → skip offered, user said No → section NOT called
|
||||
mock_model.assert_not_called()
|
||||
# Terminal/agent always have a summary → skip offered, user said No
|
||||
mock_terminal.assert_not_called()
|
||||
mock_agent.assert_not_called()
|
||||
# Gateway has no tokens (env_side returns "" for gateway keys) → section runs
|
||||
mock_gateway.assert_called_once()
|
||||
# Tools have no keys → section runs
|
||||
mock_tools.assert_called_once()
|
||||
|
||||
@@ -801,6 +801,48 @@ class TestConvertMessages:
|
||||
assert all(not (b.get("type") == "text" and b.get("text") == "") for b in assistant_blocks)
|
||||
assert any(b.get("type") == "tool_use" for b in assistant_blocks)
|
||||
|
||||
def test_empty_user_message_string_gets_placeholder(self):
|
||||
"""Empty user message strings should get '(empty message)' placeholder.
|
||||
|
||||
Anthropic rejects requests with empty user message content.
|
||||
Regression test for #3143 — Discord @mention-only messages.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "user", "content": ""},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert result[0]["content"] == "(empty message)"
|
||||
|
||||
def test_whitespace_only_user_message_gets_placeholder(self):
|
||||
"""Whitespace-only user messages should also get placeholder."""
|
||||
messages = [
|
||||
{"role": "user", "content": " \n\t "},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["content"] == "(empty message)"
|
||||
|
||||
def test_empty_user_message_list_gets_placeholder(self):
|
||||
"""Empty content list for user messages should get placeholder block."""
|
||||
messages = [
|
||||
{"role": "user", "content": []},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert isinstance(result[0]["content"], list)
|
||||
assert len(result[0]["content"]) == 1
|
||||
assert result[0]["content"][0] == {"type": "text", "text": "(empty message)"}
|
||||
|
||||
def test_user_message_with_empty_text_blocks_gets_placeholder(self):
|
||||
"""User message with only empty text blocks should get placeholder."""
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": ""}, {"type": "text", "text": " "}]},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert isinstance(result[0]["content"], list)
|
||||
assert result[0]["content"] == [{"type": "text", "text": "(empty message)"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build kwargs
|
||||
|
||||
@@ -217,10 +217,17 @@ def test_529_overloaded_is_retried_and_recovers(monkeypatch):
|
||||
|
||||
|
||||
def test_429_exhausts_all_retries_before_raising(monkeypatch):
|
||||
"""429 must retry max_retries times, not abort on first attempt."""
|
||||
"""429 must retry max_retries times, then return a failed result.
|
||||
|
||||
The agent no longer re-raises after exhausting retries — it returns a
|
||||
result dict with the error in final_response. This changed when the
|
||||
fallback-provider feature was added (the agent tries a fallback before
|
||||
giving up, and returns a result dict either way).
|
||||
"""
|
||||
agent_cls = _make_agent_cls(_RateLimitError) # always fails
|
||||
with pytest.raises(_RateLimitError):
|
||||
_run_with_agent(monkeypatch, agent_cls)
|
||||
result = _run_with_agent(monkeypatch, agent_cls)
|
||||
resp = str(result.get("final_response", ""))
|
||||
assert "429" in resp or "retries" in resp.lower()
|
||||
|
||||
|
||||
def test_400_bad_request_is_non_retryable(monkeypatch):
|
||||
|
||||
@@ -96,6 +96,59 @@ class TestVerboseAndToolProgress:
|
||||
assert cli.tool_progress_mode in ("off", "new", "all", "verbose")
|
||||
|
||||
|
||||
class TestBusyInputMode:
|
||||
def test_default_busy_input_mode_is_interrupt(self):
|
||||
cli = _make_cli()
|
||||
assert cli.busy_input_mode == "interrupt"
|
||||
|
||||
def test_busy_input_mode_queue_is_honored(self):
|
||||
cli = _make_cli(config_overrides={"display": {"busy_input_mode": "queue"}})
|
||||
assert cli.busy_input_mode == "queue"
|
||||
|
||||
def test_unknown_busy_input_mode_falls_back_to_interrupt(self):
|
||||
cli = _make_cli(config_overrides={"display": {"busy_input_mode": "bogus"}})
|
||||
assert cli.busy_input_mode == "interrupt"
|
||||
|
||||
def test_queue_command_works_while_busy(self):
|
||||
"""When agent is running, /queue should still put the prompt in _pending_input."""
|
||||
cli = _make_cli()
|
||||
cli._agent_running = True
|
||||
cli.process_command("/queue follow up")
|
||||
assert cli._pending_input.get_nowait() == "follow up"
|
||||
|
||||
def test_queue_command_works_while_idle(self):
|
||||
"""When agent is idle, /queue should still queue (not reject)."""
|
||||
cli = _make_cli()
|
||||
cli._agent_running = False
|
||||
cli.process_command("/queue follow up")
|
||||
assert cli._pending_input.get_nowait() == "follow up"
|
||||
|
||||
def test_queue_mode_routes_busy_enter_to_pending(self):
|
||||
"""In queue mode, Enter while busy should go to _pending_input, not _interrupt_queue."""
|
||||
cli = _make_cli(config_overrides={"display": {"busy_input_mode": "queue"}})
|
||||
cli._agent_running = True
|
||||
# Simulate what handle_enter does for non-command input while busy
|
||||
text = "follow up"
|
||||
if cli.busy_input_mode == "queue":
|
||||
cli._pending_input.put(text)
|
||||
else:
|
||||
cli._interrupt_queue.put(text)
|
||||
assert cli._pending_input.get_nowait() == "follow up"
|
||||
assert cli._interrupt_queue.empty()
|
||||
|
||||
def test_interrupt_mode_routes_busy_enter_to_interrupt(self):
|
||||
"""In interrupt mode (default), Enter while busy goes to _interrupt_queue."""
|
||||
cli = _make_cli()
|
||||
cli._agent_running = True
|
||||
text = "redirect"
|
||||
if cli.busy_input_mode == "queue":
|
||||
cli._pending_input.put(text)
|
||||
else:
|
||||
cli._interrupt_queue.put(text)
|
||||
assert cli._interrupt_queue.get_nowait() == "redirect"
|
||||
assert cli._pending_input.empty()
|
||||
|
||||
|
||||
class TestSingleQueryState:
|
||||
def test_voice_and_interrupt_state_initialized_before_run(self):
|
||||
"""Single-query mode calls chat() without going through run()."""
|
||||
|
||||
@@ -182,3 +182,94 @@ class TestCLIUsageReport:
|
||||
assert "Total cost:" in output
|
||||
assert "n/a" in output
|
||||
assert "Pricing unknown for glm-5" in output
|
||||
|
||||
|
||||
class TestStatusBarWidthSource:
|
||||
"""Ensure status bar fragments don't overflow the terminal width."""
|
||||
|
||||
def _make_wide_cli(self):
|
||||
from datetime import datetime, timedelta
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=100_000,
|
||||
completion_tokens=5_000,
|
||||
total_tokens=105_000,
|
||||
api_calls=20,
|
||||
context_tokens=100_000,
|
||||
context_length=200_000,
|
||||
)
|
||||
cli_obj._status_bar_visible = True
|
||||
return cli_obj
|
||||
|
||||
def test_fragments_fit_within_announced_width(self):
|
||||
"""Total fragment text length must not exceed the width used to build them."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
for width in (40, 52, 76, 80, 120, 200):
|
||||
mock_app = MagicMock()
|
||||
mock_app.output.get_size.return_value = MagicMock(columns=width)
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", return_value=mock_app):
|
||||
frags = cli_obj._get_status_bar_fragments()
|
||||
|
||||
total_text = "".join(text for _, text in frags)
|
||||
assert len(total_text) <= width + 4, ( # +4 for minor padding chars
|
||||
f"At width={width}, fragment total {len(total_text)} chars overflows "
|
||||
f"({total_text!r})"
|
||||
)
|
||||
|
||||
def test_fragments_use_pt_width_over_shutil(self):
|
||||
"""When prompt_toolkit reports a width, shutil.get_terminal_size must not be used."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.output.get_size.return_value = MagicMock(columns=120)
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", return_value=mock_app) as mock_get_app, \
|
||||
patch("shutil.get_terminal_size") as mock_shutil:
|
||||
cli_obj._get_status_bar_fragments()
|
||||
|
||||
mock_shutil.assert_not_called()
|
||||
|
||||
def test_fragments_fall_back_to_shutil_when_no_app(self):
|
||||
"""Outside a TUI context (no running app), shutil must be used as fallback."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", side_effect=Exception("no app")), \
|
||||
patch("shutil.get_terminal_size", return_value=MagicMock(columns=100)) as mock_shutil:
|
||||
frags = cli_obj._get_status_bar_fragments()
|
||||
|
||||
mock_shutil.assert_called()
|
||||
assert len(frags) > 0
|
||||
|
||||
def test_build_status_bar_text_uses_pt_width(self):
|
||||
"""_build_status_bar_text() must also prefer prompt_toolkit width."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.output.get_size.return_value = MagicMock(columns=80)
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", return_value=mock_app), \
|
||||
patch("shutil.get_terminal_size") as mock_shutil:
|
||||
text = cli_obj._build_status_bar_text() # no explicit width
|
||||
|
||||
mock_shutil.assert_not_called()
|
||||
assert isinstance(text, str)
|
||||
assert len(text) > 0
|
||||
|
||||
def test_explicit_width_skips_pt_lookup(self):
|
||||
"""An explicit width= argument must bypass both PT and shutil lookups."""
|
||||
from unittest.mock import patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
with patch("prompt_toolkit.application.get_app") as mock_get_app, \
|
||||
patch("shutil.get_terminal_size") as mock_shutil:
|
||||
text = cli_obj._build_status_bar_text(width=100)
|
||||
|
||||
mock_get_app.assert_not_called()
|
||||
mock_shutil.assert_not_called()
|
||||
assert len(text) > 0
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Tests that _try_activate_fallback updates the context compressor."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
|
||||
def _make_agent_with_compressor() -> AIAgent:
|
||||
"""Build a minimal AIAgent with a context_compressor, skipping __init__."""
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
|
||||
# Primary model settings
|
||||
agent.model = "primary-model"
|
||||
agent.provider = "openrouter"
|
||||
agent.base_url = "https://openrouter.ai/api/v1"
|
||||
agent.api_key = "sk-primary"
|
||||
agent.api_mode = "chat_completions"
|
||||
agent.client = MagicMock()
|
||||
agent.quiet_mode = True
|
||||
|
||||
# Fallback config
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
|
||||
# Context compressor with primary model values
|
||||
compressor = ContextCompressor(
|
||||
model="primary-model",
|
||||
threshold_percent=0.50,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key="sk-primary",
|
||||
provider="openrouter",
|
||||
quiet_mode=True,
|
||||
)
|
||||
agent.context_compressor = compressor
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.resolve_provider_client")
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
|
||||
def test_compressor_updated_on_fallback(mock_ctx_len, mock_resolve):
|
||||
"""After fallback activation, the compressor must reflect the fallback model."""
|
||||
agent = _make_agent_with_compressor()
|
||||
|
||||
assert agent.context_compressor.model == "primary-model"
|
||||
|
||||
fb_client = MagicMock()
|
||||
fb_client.base_url = "https://api.openai.com/v1"
|
||||
fb_client.api_key = "sk-fallback"
|
||||
mock_resolve.return_value = (fb_client, None)
|
||||
|
||||
agent._is_direct_openai_url = lambda url: "api.openai.com" in url
|
||||
agent._emit_status = lambda msg: None
|
||||
|
||||
result = agent._try_activate_fallback()
|
||||
|
||||
assert result is True
|
||||
assert agent._fallback_activated is True
|
||||
|
||||
c = agent.context_compressor
|
||||
assert c.model == "gpt-4o"
|
||||
assert c.base_url == "https://api.openai.com/v1"
|
||||
assert c.api_key == "sk-fallback"
|
||||
assert c.provider == "openai"
|
||||
assert c.context_length == 128_000
|
||||
assert c.threshold_tokens == int(128_000 * c.threshold_percent)
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.resolve_provider_client")
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
|
||||
def test_compressor_not_present_does_not_crash(mock_ctx_len, mock_resolve):
|
||||
"""If the agent has no compressor, fallback should still succeed."""
|
||||
agent = _make_agent_with_compressor()
|
||||
agent.context_compressor = None
|
||||
|
||||
fb_client = MagicMock()
|
||||
fb_client.base_url = "https://api.openai.com/v1"
|
||||
fb_client.api_key = "sk-fallback"
|
||||
mock_resolve.return_value = (fb_client, None)
|
||||
|
||||
agent._is_direct_openai_url = lambda url: "api.openai.com" in url
|
||||
agent._emit_status = lambda msg: None
|
||||
|
||||
result = agent._try_activate_fallback()
|
||||
assert result is True
|
||||
@@ -472,6 +472,7 @@ class TestInlineThinkBlockExtraction(unittest.TestCase):
|
||||
agent._extract_reasoning = AIAgent._extract_reasoning.__get__(agent)
|
||||
agent.verbose_logging = False
|
||||
agent.reasoning_callback = None
|
||||
agent.stream_delta_callback = None # non-streaming by default
|
||||
return agent
|
||||
|
||||
def test_single_think_block_extracted(self):
|
||||
@@ -605,5 +606,159 @@ class TestEndToEndPipeline(unittest.TestCase):
|
||||
self.assertIsNone(result["last_reasoning"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Duplicate reasoning box prevention (Bug fix: 3 boxes for 1 reasoning)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReasoningDeltasFiredFlag(unittest.TestCase):
|
||||
"""_build_assistant_message should not re-fire reasoning_callback when
|
||||
reasoning was already streamed via _fire_reasoning_delta."""
|
||||
|
||||
def _make_agent(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.reasoning_callback = None
|
||||
agent.stream_delta_callback = None
|
||||
agent._reasoning_deltas_fired = False
|
||||
agent.verbose_logging = False
|
||||
return agent
|
||||
|
||||
def test_fire_reasoning_delta_sets_flag(self):
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
self.assertFalse(agent._reasoning_deltas_fired)
|
||||
agent._fire_reasoning_delta("thinking...")
|
||||
self.assertTrue(agent._reasoning_deltas_fired)
|
||||
self.assertEqual(captured, ["thinking..."])
|
||||
|
||||
def test_build_assistant_message_skips_callback_when_already_streamed(self):
|
||||
"""When streaming already fired reasoning deltas, the post-stream
|
||||
_build_assistant_message should NOT re-fire the callback."""
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
agent.stream_delta_callback = lambda t: None # streaming is active
|
||||
|
||||
# Simulate streaming having fired reasoning
|
||||
agent._reasoning_deltas_fired = True
|
||||
|
||||
msg = SimpleNamespace(
|
||||
content="I'll merge that.",
|
||||
tool_calls=None,
|
||||
reasoning_content="Let me merge the PR.",
|
||||
reasoning=None,
|
||||
reasoning_details=None,
|
||||
)
|
||||
agent._build_assistant_message(msg, "stop")
|
||||
|
||||
# Callback should NOT have been fired again
|
||||
self.assertEqual(captured, [])
|
||||
|
||||
def test_build_assistant_message_skips_callback_when_streaming_active(self):
|
||||
"""When streaming is active, callback should NEVER fire from
|
||||
_build_assistant_message — reasoning was already displayed during the
|
||||
stream (either via reasoning_content deltas or content tag extraction).
|
||||
Any missed reasoning is caught by the CLI post-response fallback."""
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
agent.stream_delta_callback = lambda t: None # streaming active
|
||||
|
||||
# Even though _reasoning_deltas_fired is False (reasoning came through
|
||||
# content tags, not reasoning_content deltas), callback should not fire
|
||||
agent._reasoning_deltas_fired = False
|
||||
|
||||
msg = SimpleNamespace(
|
||||
content="I'll merge that.",
|
||||
tool_calls=None,
|
||||
reasoning_content="Let me merge the PR.",
|
||||
reasoning=None,
|
||||
reasoning_details=None,
|
||||
)
|
||||
agent._build_assistant_message(msg, "stop")
|
||||
|
||||
# Callback should NOT fire — streaming is active
|
||||
self.assertEqual(captured, [])
|
||||
|
||||
def test_build_assistant_message_fires_callback_without_streaming(self):
|
||||
"""When no streaming is active, callback always fires for structured
|
||||
reasoning."""
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
# No streaming
|
||||
agent.stream_delta_callback = None
|
||||
agent._reasoning_deltas_fired = False
|
||||
|
||||
msg = SimpleNamespace(
|
||||
content="I'll merge that.",
|
||||
tool_calls=None,
|
||||
reasoning_content="Let me merge the PR.",
|
||||
reasoning=None,
|
||||
reasoning_details=None,
|
||||
)
|
||||
agent._build_assistant_message(msg, "stop")
|
||||
|
||||
self.assertEqual(captured, ["Let me merge the PR."])
|
||||
|
||||
|
||||
class TestReasoningShownThisTurnFlag(unittest.TestCase):
|
||||
"""Post-response reasoning display should be suppressed when reasoning
|
||||
was already shown during streaming in a tool-calling loop."""
|
||||
|
||||
def _make_cli(self):
|
||||
from cli import HermesCLI
|
||||
cli = HermesCLI.__new__(HermesCLI)
|
||||
cli.show_reasoning = True
|
||||
cli.streaming_enabled = True
|
||||
cli._stream_box_opened = False
|
||||
cli._reasoning_box_opened = False
|
||||
cli._reasoning_stream_started = False
|
||||
cli._reasoning_shown_this_turn = False
|
||||
cli._reasoning_buf = ""
|
||||
cli._stream_buf = ""
|
||||
cli._stream_started = False
|
||||
cli._stream_text_ansi = ""
|
||||
cli._stream_prefilt = ""
|
||||
cli._in_reasoning_block = False
|
||||
cli._reasoning_preview_buf = ""
|
||||
return cli
|
||||
|
||||
@patch("cli._cprint")
|
||||
def test_streaming_reasoning_sets_turn_flag(self, mock_cprint):
|
||||
cli = self._make_cli()
|
||||
self.assertFalse(cli._reasoning_shown_this_turn)
|
||||
cli._stream_reasoning_delta("Thinking about it...")
|
||||
self.assertTrue(cli._reasoning_shown_this_turn)
|
||||
|
||||
@patch("cli._cprint")
|
||||
def test_turn_flag_survives_reset_stream_state(self, mock_cprint):
|
||||
"""_reasoning_shown_this_turn must NOT be cleared by
|
||||
_reset_stream_state (called at intermediate turn boundaries)."""
|
||||
cli = self._make_cli()
|
||||
cli._stream_reasoning_delta("Thinking...")
|
||||
self.assertTrue(cli._reasoning_shown_this_turn)
|
||||
|
||||
# Simulate intermediate turn boundary (tool call)
|
||||
cli._reset_stream_state()
|
||||
|
||||
# Flag must persist
|
||||
self.assertTrue(cli._reasoning_shown_this_turn)
|
||||
|
||||
@patch("cli._cprint")
|
||||
def test_turn_flag_cleared_before_new_turn(self, mock_cprint):
|
||||
"""The turn flag should be reset at the start of a new user turn.
|
||||
This happens outside _reset_stream_state, at the call site."""
|
||||
cli = self._make_cli()
|
||||
cli._reasoning_shown_this_turn = True
|
||||
|
||||
# Simulate new user turn setup
|
||||
cli._reset_stream_state()
|
||||
cli._reasoning_shown_this_turn = False # done by process_input
|
||||
|
||||
self.assertFalse(cli._reasoning_shown_this_turn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -584,6 +584,38 @@ class TestBuildSystemPrompt:
|
||||
# Should contain current date info like "Conversation started:"
|
||||
assert "Conversation started:" in prompt
|
||||
|
||||
def test_skills_prompt_derives_available_toolsets_from_loaded_tools(self):
|
||||
tools = _make_tool_defs("web_search", "skills_list", "skill_view", "skill_manage")
|
||||
toolset_map = {
|
||||
"web_search": "web",
|
||||
"skills_list": "skills",
|
||||
"skill_view": "skills",
|
||||
"skill_manage": "skills",
|
||||
}
|
||||
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=tools),
|
||||
patch(
|
||||
"run_agent.check_toolset_requirements",
|
||||
side_effect=AssertionError("should not re-check toolset requirements"),
|
||||
),
|
||||
patch("run_agent.get_toolset_for_tool", create=True, side_effect=toolset_map.get),
|
||||
patch("run_agent.build_skills_system_prompt", return_value="SKILLS_PROMPT") as mock_skills,
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-k...7890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
prompt = agent._build_system_prompt()
|
||||
|
||||
assert "SKILLS_PROMPT" in prompt
|
||||
assert mock_skills.call_args.kwargs["available_tools"] == set(toolset_map)
|
||||
assert mock_skills.call_args.kwargs["available_toolsets"] == {"web", "skills"}
|
||||
|
||||
|
||||
class TestInvalidateSystemPrompt:
|
||||
def test_clears_cache(self, agent):
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Tests for config.get() null-coalescing in tool configuration.
|
||||
|
||||
YAML ``null`` values (or ``~``) for a present key make ``dict.get(key, default)``
|
||||
return ``None`` instead of the default — calling ``.lower()`` on that raises
|
||||
``AttributeError``. These tests verify the ``or`` coalescing guards.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
|
||||
# ── TTS tool ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestTTSProviderNullGuard:
|
||||
"""tools/tts_tool.py — _get_provider()"""
|
||||
|
||||
def test_explicit_null_provider_returns_default(self):
|
||||
"""YAML ``tts: {provider: null}`` should fall back to default."""
|
||||
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
|
||||
|
||||
result = _get_provider({"provider": None})
|
||||
assert result == DEFAULT_PROVIDER.lower().strip()
|
||||
|
||||
def test_missing_provider_returns_default(self):
|
||||
"""No ``provider`` key at all should also return default."""
|
||||
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
|
||||
|
||||
result = _get_provider({})
|
||||
assert result == DEFAULT_PROVIDER.lower().strip()
|
||||
|
||||
def test_valid_provider_passed_through(self):
|
||||
from tools.tts_tool import _get_provider
|
||||
|
||||
result = _get_provider({"provider": "OPENAI"})
|
||||
assert result == "openai"
|
||||
|
||||
|
||||
# ── Web tools ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestWebBackendNullGuard:
|
||||
"""tools/web_tools.py — _get_backend()"""
|
||||
|
||||
@patch("tools.web_tools._load_web_config", return_value={"backend": None})
|
||||
def test_explicit_null_backend_does_not_crash(self, _cfg):
|
||||
"""YAML ``web: {backend: null}`` should not raise AttributeError."""
|
||||
from tools.web_tools import _get_backend
|
||||
|
||||
# Should not raise — the exact return depends on env key fallback
|
||||
result = _get_backend()
|
||||
assert isinstance(result, str)
|
||||
|
||||
@patch("tools.web_tools._load_web_config", return_value={})
|
||||
def test_missing_backend_does_not_crash(self, _cfg):
|
||||
from tools.web_tools import _get_backend
|
||||
|
||||
result = _get_backend()
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
# ── MCP tool ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestMCPAuthNullGuard:
|
||||
"""tools/mcp_tool.py — MCPServerTask.__init__() auth config line"""
|
||||
|
||||
def test_explicit_null_auth_does_not_crash(self):
|
||||
"""YAML ``auth: null`` in MCP server config should not raise."""
|
||||
# Test the expression directly — MCPServerTask.__init__ has many deps
|
||||
config = {"auth": None, "timeout": 30}
|
||||
auth_type = (config.get("auth") or "").lower().strip()
|
||||
assert auth_type == ""
|
||||
|
||||
def test_missing_auth_defaults_to_empty(self):
|
||||
config = {"timeout": 30}
|
||||
auth_type = (config.get("auth") or "").lower().strip()
|
||||
assert auth_type == ""
|
||||
|
||||
def test_valid_auth_passed_through(self):
|
||||
config = {"auth": "OAUTH", "timeout": 30}
|
||||
auth_type = (config.get("auth") or "").lower().strip()
|
||||
assert auth_type == "oauth"
|
||||
|
||||
|
||||
# ── Trajectory compressor ─────────────────────────────────────────────────
|
||||
|
||||
class TestTrajectoryCompressorNullGuard:
|
||||
"""trajectory_compressor.py — _detect_provider() and config loading"""
|
||||
|
||||
def test_null_base_url_does_not_crash(self):
|
||||
"""base_url=None should not crash _detect_provider()."""
|
||||
from trajectory_compressor import CompressionConfig, TrajectoryCompressor
|
||||
|
||||
config = CompressionConfig()
|
||||
config.base_url = None
|
||||
|
||||
compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
|
||||
compressor.config = config
|
||||
|
||||
# Should not raise AttributeError; returns empty string (no match)
|
||||
result = compressor._detect_provider()
|
||||
assert result == ""
|
||||
|
||||
def test_config_loading_null_base_url_keeps_default(self):
|
||||
"""YAML ``summarization: {base_url: null}`` should keep default."""
|
||||
from trajectory_compressor import CompressionConfig
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
config = CompressionConfig()
|
||||
data = {"summarization": {"base_url": None}}
|
||||
|
||||
config.base_url = data["summarization"].get("base_url") or config.base_url
|
||||
assert config.base_url == OPENROUTER_BASE_URL
|
||||
@@ -185,3 +185,71 @@ class TestApplyUpdate:
|
||||
' result = 1\n'
|
||||
' return result + 1'
|
||||
)
|
||||
|
||||
|
||||
class TestAdditionOnlyHunks:
|
||||
"""Regression tests for #3081 — addition-only hunks were silently dropped."""
|
||||
|
||||
def test_addition_only_hunk_with_context_hint(self):
|
||||
"""A hunk with only + lines should insert at the context hint location."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: src/app.py
|
||||
@@ def main @@
|
||||
+def helper():
|
||||
+ return 42
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
assert len(ops[0].hunks) == 1
|
||||
|
||||
hunk = ops[0].hunks[0]
|
||||
# All lines should be additions
|
||||
assert all(l.prefix == '+' for l in hunk.lines)
|
||||
|
||||
# Apply to a file that contains the context hint
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
return SimpleNamespace(
|
||||
content="def main():\n pass\n",
|
||||
error=None,
|
||||
)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
assert "def helper():" in file_ops.written
|
||||
assert "return 42" in file_ops.written
|
||||
|
||||
def test_addition_only_hunk_without_context_hint(self):
|
||||
"""A hunk with only + lines and no context hint appends at end of file."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: src/app.py
|
||||
+def new_func():
|
||||
+ return True
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
return SimpleNamespace(
|
||||
content="existing = True\n",
|
||||
error=None,
|
||||
)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
assert file_ops.written.endswith("def new_func():\n return True\n")
|
||||
assert "existing = True" in file_ops.written
|
||||
|
||||
@@ -81,6 +81,33 @@ class TestGetDefinitions:
|
||||
assert len(defs) == 1
|
||||
assert defs[0]["function"]["name"] == "available"
|
||||
|
||||
def test_reuses_shared_check_fn_once_per_call(self):
|
||||
reg = ToolRegistry()
|
||||
calls = {"count": 0}
|
||||
|
||||
def shared_check():
|
||||
calls["count"] += 1
|
||||
return True
|
||||
|
||||
reg.register(
|
||||
name="first",
|
||||
toolset="shared",
|
||||
schema=_make_schema("first"),
|
||||
handler=_dummy_handler,
|
||||
check_fn=shared_check,
|
||||
)
|
||||
reg.register(
|
||||
name="second",
|
||||
toolset="shared",
|
||||
schema=_make_schema("second"),
|
||||
handler=_dummy_handler,
|
||||
check_fn=shared_check,
|
||||
)
|
||||
|
||||
defs = reg.get_definitions({"first", "second"})
|
||||
assert len(defs) == 2
|
||||
assert calls["count"] == 1
|
||||
|
||||
|
||||
class TestUnknownToolDispatch:
|
||||
def test_returns_error_json(self):
|
||||
|
||||
+1
-1
@@ -797,7 +797,7 @@ class MCPServerTask:
|
||||
"""
|
||||
self._config = config
|
||||
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
||||
self._auth_type = config.get("auth", "").lower().strip()
|
||||
self._auth_type = (config.get("auth") or "").lower().strip()
|
||||
|
||||
# Set up sampling handler if enabled and SDK types are available
|
||||
sampling_config = config.get("sampling", {})
|
||||
|
||||
@@ -419,6 +419,23 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
|
||||
if error:
|
||||
return False, f"Could not apply hunk: {error}"
|
||||
else:
|
||||
# Addition-only hunk (no context or removed lines).
|
||||
# Insert at the location indicated by the context hint, or at end of file.
|
||||
insert_text = '\n'.join(replace_lines)
|
||||
if hunk.context_hint:
|
||||
hint_pos = new_content.find(hunk.context_hint)
|
||||
if hint_pos != -1:
|
||||
# Insert after the line containing the context hint
|
||||
eol = new_content.find('\n', hint_pos)
|
||||
if eol != -1:
|
||||
new_content = new_content[:eol + 1] + insert_text + '\n' + new_content[eol + 1:]
|
||||
else:
|
||||
new_content = new_content + '\n' + insert_text
|
||||
else:
|
||||
new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n'
|
||||
else:
|
||||
new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n'
|
||||
|
||||
# Write new content
|
||||
write_result = file_ops.write_file(op.file_path, new_content)
|
||||
|
||||
+9
-6
@@ -98,19 +98,22 @@ class ToolRegistry:
|
||||
are included.
|
||||
"""
|
||||
result = []
|
||||
check_results: Dict[Callable, bool] = {}
|
||||
for name in sorted(tool_names):
|
||||
entry = self._tools.get(name)
|
||||
if not entry:
|
||||
continue
|
||||
if entry.check_fn:
|
||||
try:
|
||||
if not entry.check_fn():
|
||||
if entry.check_fn not in check_results:
|
||||
try:
|
||||
check_results[entry.check_fn] = bool(entry.check_fn())
|
||||
except Exception:
|
||||
check_results[entry.check_fn] = False
|
||||
if not quiet:
|
||||
logger.debug("Tool %s unavailable (check failed)", name)
|
||||
continue
|
||||
except Exception:
|
||||
logger.debug("Tool %s check raised; skipping", name)
|
||||
if not check_results[entry.check_fn]:
|
||||
if not quiet:
|
||||
logger.debug("Tool %s check raised; skipping", name)
|
||||
logger.debug("Tool %s unavailable (check failed)", name)
|
||||
continue
|
||||
result.append({"type": "function", "function": entry.schema})
|
||||
return result
|
||||
|
||||
+1
-1
@@ -102,7 +102,7 @@ def _load_tts_config() -> Dict[str, Any]:
|
||||
|
||||
def _get_provider(tts_config: Dict[str, Any]) -> str:
|
||||
"""Get the configured TTS provider name."""
|
||||
return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip()
|
||||
return (tts_config.get("provider") or DEFAULT_PROVIDER).lower().strip()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
|
||||
+1
-1
@@ -73,7 +73,7 @@ def _get_backend() -> str:
|
||||
Falls back to whichever API key is present for users who configured
|
||||
keys manually without running setup.
|
||||
"""
|
||||
configured = _load_web_config().get("backend", "").lower().strip()
|
||||
configured = (_load_web_config().get("backend") or "").lower().strip()
|
||||
if configured in ("parallel", "firecrawl", "tavily"):
|
||||
return configured
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ class CompressionConfig:
|
||||
# Summarization
|
||||
if 'summarization' in data:
|
||||
config.summarization_model = data['summarization'].get('model', config.summarization_model)
|
||||
config.base_url = data['summarization'].get('base_url', config.base_url)
|
||||
config.base_url = data['summarization'].get('base_url') or config.base_url
|
||||
config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env)
|
||||
config.temperature = data['summarization'].get('temperature', config.temperature)
|
||||
config.max_retries = data['summarization'].get('max_retries', config.max_retries)
|
||||
@@ -386,7 +386,7 @@ class TrajectoryCompressor:
|
||||
|
||||
def _detect_provider(self) -> str:
|
||||
"""Detect the provider name from the configured base_url."""
|
||||
url = self.config.base_url.lower()
|
||||
url = (self.config.base_url or "").lower()
|
||||
if "openrouter" in url:
|
||||
return "openrouter"
|
||||
if "nousresearch.com" in url:
|
||||
|
||||
Reference in New Issue
Block a user