Compare commits
23 Commits
fix/api-se
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b2d9e4dd1f | ||
|
|
b8b1f24fd7 | ||
|
|
a2847ea7f0 | ||
|
|
58ca875e19 | ||
|
|
3f95e741a7 | ||
|
|
03396627a6 | ||
|
|
22cfad157b | ||
|
|
867eefdd9f | ||
|
|
a8df7f9964 | ||
|
|
1519c4d477 | ||
|
|
005786c55d | ||
|
|
ad764d3513 | ||
|
|
f008ee1019 | ||
|
|
60fdb58ce4 | ||
|
|
18d28c63a7 | ||
|
|
3c57eaf744 | ||
|
|
2d232c9991 | ||
|
|
0375b2a0d7 | ||
|
|
08fa326bb0 | ||
|
|
bde45f5a2a | ||
|
|
716e616d28 | ||
|
|
bdccdd67a1 | ||
|
|
148f46620f |
@@ -706,14 +706,21 @@ def convert_messages_to_anthropic(
|
||||
result.append({"role": "user", "content": [tool_result]})
|
||||
continue
|
||||
|
||||
# Regular user message
|
||||
# Regular user message — validate non-empty content (Anthropic rejects empty)
|
||||
if isinstance(content, list):
|
||||
converted_blocks = _convert_content_to_anthropic(content)
|
||||
result.append({
|
||||
"role": "user",
|
||||
"content": converted_blocks or [{"type": "text", "text": ""}],
|
||||
})
|
||||
# Check if all text blocks are empty
|
||||
if not converted_blocks or all(
|
||||
b.get("text", "").strip() == ""
|
||||
for b in converted_blocks
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
):
|
||||
converted_blocks = [{"type": "text", "text": "(empty message)"}]
|
||||
result.append({"role": "user", "content": converted_blocks})
|
||||
else:
|
||||
# Validate string content is non-empty
|
||||
if not content or (isinstance(content, str) and not content.strip()):
|
||||
content = "(empty message)"
|
||||
result.append({"role": "user", "content": content})
|
||||
|
||||
# Strip orphaned tool_use blocks (no matching tool_result follows)
|
||||
|
||||
@@ -693,7 +693,13 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
is_oauth = _is_oauth_token(token)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get("anthropic", "claude-haiku-4-5-20251001")
|
||||
logger.debug("Auxiliary client: Anthropic native (%s) at %s (oauth=%s)", model, base_url, is_oauth)
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
try:
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
except ImportError:
|
||||
# The anthropic_adapter module imports fine but the SDK itself is
|
||||
# missing — build_anthropic_client raises ImportError at call time
|
||||
# when _anthropic_sdk is None. Treat as unavailable.
|
||||
return None, None
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
|
||||
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ class KawaiiSpinner:
|
||||
"analyzing", "computing", "synthesizing", "formulating", "brainstorming",
|
||||
]
|
||||
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots'):
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots', print_fn=None):
|
||||
self.message = message
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
|
||||
self.running = False
|
||||
@@ -239,12 +239,26 @@ class KawaiiSpinner:
|
||||
self.frame_idx = 0
|
||||
self.start_time = None
|
||||
self.last_line_len = 0
|
||||
# Optional callable to route all output through (e.g. a no-op for silent
|
||||
# background agents). When set, bypasses self._out entirely so that
|
||||
# agents with _print_fn overridden remain fully silent.
|
||||
self._print_fn = print_fn
|
||||
# Capture stdout NOW, before any redirect_stdout(devnull) from
|
||||
# child agents can replace sys.stdout with a black hole.
|
||||
self._out = sys.stdout
|
||||
|
||||
def _write(self, text: str, end: str = '\n', flush: bool = False):
|
||||
"""Write to the stdout captured at spinner creation time."""
|
||||
"""Write to the stdout captured at spinner creation time.
|
||||
|
||||
If a print_fn was supplied at construction, all output is routed through
|
||||
it instead — allowing callers to silence the spinner with a no-op lambda.
|
||||
"""
|
||||
if self._print_fn is not None:
|
||||
try:
|
||||
self._print_fn(text)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
try:
|
||||
self._out.write(text + end)
|
||||
if flush:
|
||||
|
||||
@@ -688,6 +688,12 @@ display:
|
||||
# Toggle at runtime with /verbose in the CLI
|
||||
tool_progress: all
|
||||
|
||||
# What Enter does when Hermes is already busy in the CLI.
|
||||
# interrupt: Interrupt the current run and redirect Hermes (default)
|
||||
# queue: Queue your message for the next turn
|
||||
# Ctrl+C always interrupts regardless of this setting.
|
||||
busy_input_mode: interrupt
|
||||
|
||||
# Background process notifications (gateway/messaging only).
|
||||
# Controls how chatty the process watcher is when you use
|
||||
# terminal(background=true, check_interval=...) from Telegram/Discord/etc.
|
||||
|
||||
157
cli.py
157
cli.py
@@ -205,6 +205,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"resume_display": "full",
|
||||
"show_reasoning": False,
|
||||
"streaming": True,
|
||||
"busy_input_mode": "interrupt",
|
||||
|
||||
"skin": "default",
|
||||
},
|
||||
@@ -1035,13 +1036,18 @@ class HermesCLI:
|
||||
self.config = CLI_CONFIG
|
||||
self.compact = compact if compact is not None else CLI_CONFIG["display"].get("compact", False)
|
||||
# tool_progress: "off", "new", "all", "verbose" (from config.yaml display section)
|
||||
self.tool_progress_mode = CLI_CONFIG["display"].get("tool_progress", "all")
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise to string.
|
||||
_raw_tp = CLI_CONFIG["display"].get("tool_progress", "all")
|
||||
self.tool_progress_mode = "off" if _raw_tp is False else str(_raw_tp)
|
||||
# resume_display: "full" (show history) | "minimal" (one-liner only)
|
||||
self.resume_display = CLI_CONFIG["display"].get("resume_display", "full")
|
||||
# bell_on_complete: play terminal bell (\a) when agent finishes a response
|
||||
self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False)
|
||||
# show_reasoning: display model thinking/reasoning before the response
|
||||
self.show_reasoning = CLI_CONFIG["display"].get("show_reasoning", False)
|
||||
# busy_input_mode: "interrupt" (Enter interrupts current run) or "queue" (Enter queues for next turn)
|
||||
_bim = CLI_CONFIG["display"].get("busy_input_mode", "interrupt")
|
||||
self.busy_input_mode = "queue" if str(_bim).strip().lower() == "queue" else "interrupt"
|
||||
|
||||
self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose")
|
||||
|
||||
@@ -1329,7 +1335,12 @@ class HermesCLI:
|
||||
def _build_status_bar_text(self, width: Optional[int] = None) -> str:
|
||||
try:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
width = width or shutil.get_terminal_size((80, 24)).columns
|
||||
if width is None:
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
width = get_app().output.get_size().columns
|
||||
except Exception:
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
percent = snapshot["context_percent"]
|
||||
percent_label = f"{percent}%" if percent is not None else "--"
|
||||
duration_label = snapshot["duration"]
|
||||
@@ -1359,7 +1370,16 @@ class HermesCLI:
|
||||
return []
|
||||
try:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
# Use prompt_toolkit's own terminal width when running inside the
|
||||
# TUI — shutil.get_terminal_size() can return stale or fallback
|
||||
# values (especially on SSH) that differ from what prompt_toolkit
|
||||
# actually renders, causing the fragments to overflow to a second
|
||||
# line and produce duplicated status bar rows over long sessions.
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
width = get_app().output.get_size().columns
|
||||
except Exception:
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
duration_label = snapshot["duration"]
|
||||
|
||||
if width < 52:
|
||||
@@ -2929,6 +2949,82 @@ class HermesCLI:
|
||||
if not silent:
|
||||
print("(^_^)v New session started!")
|
||||
|
||||
def _handle_resume_command(self, cmd_original: str) -> None:
|
||||
"""Handle /resume <session_id_or_title> — switch to a previous session mid-conversation."""
|
||||
parts = cmd_original.split(None, 1)
|
||||
target = parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
if not target:
|
||||
_cprint(" Usage: /resume <session_id_or_title>")
|
||||
_cprint(" Tip: Use /history or `hermes sessions list` to find sessions.")
|
||||
return
|
||||
|
||||
if not self._session_db:
|
||||
_cprint(" Session database not available.")
|
||||
return
|
||||
|
||||
# Resolve title or ID
|
||||
from hermes_cli.main import _resolve_session_by_name_or_id
|
||||
resolved = _resolve_session_by_name_or_id(target)
|
||||
target_id = resolved or target
|
||||
|
||||
session_meta = self._session_db.get_session(target_id)
|
||||
if not session_meta:
|
||||
_cprint(f" Session not found: {target}")
|
||||
_cprint(" Use /history or `hermes sessions list` to see available sessions.")
|
||||
return
|
||||
|
||||
if target_id == self.session_id:
|
||||
_cprint(" Already on that session.")
|
||||
return
|
||||
|
||||
# End current session
|
||||
try:
|
||||
self._session_db.end_session(self.session_id, "resumed_other")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Switch to the target session
|
||||
self.session_id = target_id
|
||||
self._resumed = True
|
||||
self._pending_title = None
|
||||
|
||||
# Load conversation history
|
||||
restored = self._session_db.get_messages_as_conversation(target_id)
|
||||
self.conversation_history = restored or []
|
||||
|
||||
# Re-open the target session so it's not marked as ended
|
||||
try:
|
||||
self._session_db.reopen_session(target_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Sync the agent if already initialised
|
||||
if self.agent:
|
||||
self.agent.session_id = target_id
|
||||
self.agent.reset_session_state()
|
||||
if hasattr(self.agent, "_last_flushed_db_idx"):
|
||||
self.agent._last_flushed_db_idx = len(self.conversation_history)
|
||||
if hasattr(self.agent, "_todo_store"):
|
||||
try:
|
||||
from tools.todo_tool import TodoStore
|
||||
self.agent._todo_store = TodoStore()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(self.agent, "_invalidate_system_prompt"):
|
||||
self.agent._invalidate_system_prompt()
|
||||
|
||||
title_part = f" \"{session_meta['title']}\"" if session_meta.get("title") else ""
|
||||
msg_count = len([m for m in self.conversation_history if m.get("role") == "user"])
|
||||
if self.conversation_history:
|
||||
_cprint(
|
||||
f" ↻ Resumed session {target_id}{title_part}"
|
||||
f" ({msg_count} user message{'s' if msg_count != 1 else ''},"
|
||||
f" {len(self.conversation_history)} total)"
|
||||
)
|
||||
else:
|
||||
_cprint(f" ↻ Resumed session {target_id}{title_part} — no messages, starting fresh.")
|
||||
|
||||
def reset_conversation(self):
|
||||
"""Reset the conversation by starting a new session."""
|
||||
self.new_session()
|
||||
@@ -3647,6 +3743,8 @@ class HermesCLI:
|
||||
_cprint(" Session database not available.")
|
||||
elif canonical == "new":
|
||||
self.new_session()
|
||||
elif canonical == "resume":
|
||||
self._handle_resume_command(cmd_original)
|
||||
elif canonical == "provider":
|
||||
self._show_model_and_providers()
|
||||
elif canonical == "prompt":
|
||||
@@ -3722,17 +3820,17 @@ class HermesCLI:
|
||||
elif canonical == "background":
|
||||
self._handle_background_command(cmd_original)
|
||||
elif canonical == "queue":
|
||||
if not self._agent_running:
|
||||
_cprint(" /queue only works while Hermes is busy. Just type your message normally.")
|
||||
# Extract prompt after "/queue " or "/q "
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /queue <prompt>")
|
||||
else:
|
||||
# Extract prompt after "/queue " or "/q "
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /queue <prompt>")
|
||||
else:
|
||||
self._pending_input.put(payload)
|
||||
self._pending_input.put(payload)
|
||||
if self._agent_running:
|
||||
_cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
else:
|
||||
_cprint(f" Queued: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "skin":
|
||||
self._handle_skin_command(cmd_original)
|
||||
elif canonical == "voice":
|
||||
@@ -6112,16 +6210,22 @@ class HermesCLI:
|
||||
# Bundle text + images as a tuple when images are present
|
||||
payload = (text, images) if images else text
|
||||
if self._agent_running and not (text and text.startswith("/")):
|
||||
self._interrupt_queue.put(payload)
|
||||
# Debug: log to file when message enters interrupt queue
|
||||
try:
|
||||
_dbg = _hermes_home / "interrupt_debug.log"
|
||||
with open(_dbg, "a") as _f:
|
||||
import time as _t
|
||||
_f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, "
|
||||
f"agent_running={self._agent_running}\n")
|
||||
except Exception:
|
||||
pass
|
||||
if self.busy_input_mode == "queue":
|
||||
# Queue for the next turn instead of interrupting
|
||||
self._pending_input.put(payload)
|
||||
preview = text if text else f"[{len(images)} image{'s' if len(images) != 1 else ''} attached]"
|
||||
_cprint(f" Queued for the next turn: {preview[:80]}{'...' if len(preview) > 80 else ''}")
|
||||
else:
|
||||
self._interrupt_queue.put(payload)
|
||||
# Debug: log to file when message enters interrupt queue
|
||||
try:
|
||||
_dbg = _hermes_home / "interrupt_debug.log"
|
||||
with open(_dbg, "a") as _f:
|
||||
import time as _t
|
||||
_f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, "
|
||||
f"agent_running={self._agent_running}\n")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
self._pending_input.put(payload)
|
||||
event.app.current_buffer.reset(append_to_history=True)
|
||||
@@ -6894,6 +6998,15 @@ class HermesCLI:
|
||||
Window(
|
||||
content=FormattedTextControl(lambda: cli_ref._get_status_bar_fragments()),
|
||||
height=1,
|
||||
# Prevent fragments that overflow the terminal width from
|
||||
# wrapping onto a second line, which causes the status bar to
|
||||
# appear duplicated (one full + one partial row) during long
|
||||
# sessions, especially on SSH where shutil.get_terminal_size
|
||||
# may return stale values. _get_status_bar_fragments now reads
|
||||
# width from prompt_toolkit's own output object, so fragments
|
||||
# will always fit; wrap_lines=False is the belt-and-suspenders
|
||||
# guard against any future width mismatch.
|
||||
wrap_lines=False,
|
||||
),
|
||||
filter=Condition(lambda: cli_ref._status_bar_visible),
|
||||
)
|
||||
|
||||
@@ -366,14 +366,20 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
Create an AIAgent instance using the gateway's runtime config.
|
||||
|
||||
Uses _resolve_runtime_agent_kwargs() to pick up model, api_key,
|
||||
base_url, etc. from config.yaml / env vars.
|
||||
base_url, etc. from config.yaml / env vars. Toolsets are resolved
|
||||
from config.yaml platform_toolsets.api_server (same as all other
|
||||
gateway platforms), falling back to the hermes-api-server default.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model, _load_gateway_config
|
||||
from hermes_cli.tools_config import _get_platform_tools
|
||||
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
model = _resolve_gateway_model()
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
enabled_toolsets = sorted(_get_platform_tools(user_config, "api_server"))
|
||||
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
|
||||
|
||||
agent = AIAgent(
|
||||
@@ -383,7 +389,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
quiet_mode=True,
|
||||
verbose_logging=False,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt or None,
|
||||
enabled_toolsets=["hermes-api-server"],
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
session_id=session_id,
|
||||
platform="api_server",
|
||||
stream_delta_callback=stream_delta_callback,
|
||||
|
||||
@@ -8,6 +8,7 @@ and implement the required methods.
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -71,31 +72,51 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
|
||||
return str(filepath)
|
||||
|
||||
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg") -> str:
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> str:
|
||||
"""
|
||||
Download an image from a URL and save it to the local cache.
|
||||
|
||||
Uses httpx for async download with a reasonable timeout.
|
||||
Retries on transient failures (timeouts, 429, 5xx) with exponential
|
||||
backoff so a single slow CDN response doesn't lose the media.
|
||||
|
||||
Args:
|
||||
url: The HTTP/HTTPS URL to download from.
|
||||
ext: File extension including the dot (e.g. ".jpg", ".png").
|
||||
retries: Number of retry attempts on transient failures.
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
"""
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging as _logging
|
||||
_log = _logging.getLogger(__name__)
|
||||
|
||||
last_exc = None
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < retries:
|
||||
wait = 1.5 * (attempt + 1)
|
||||
_log.debug("Media cache retry %d/%d for %s (%.1fs): %s",
|
||||
attempt + 1, retries, url[:80], wait, exc)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
|
||||
def cleanup_image_cache(max_age_hours: int = 24) -> int:
|
||||
@@ -329,6 +350,24 @@ class SendResult:
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
raw_response: Any = None
|
||||
retryable: bool = False # True for transient errors (network, timeout) — base will retry automatically
|
||||
|
||||
|
||||
# Error substrings that indicate a transient network failure worth retrying
|
||||
_RETRYABLE_ERROR_PATTERNS = (
|
||||
"connecterror",
|
||||
"connectionerror",
|
||||
"connectionreset",
|
||||
"connectionrefused",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"network",
|
||||
"broken pipe",
|
||||
"remotedisconnected",
|
||||
"eoferror",
|
||||
"readtimeout",
|
||||
"writetimeout",
|
||||
)
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
@@ -833,6 +872,91 @@ class BasePlatformAdapter(ABC):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_error(error: Optional[str]) -> bool:
|
||||
"""Return True if the error string looks like a transient network failure."""
|
||||
if not error:
|
||||
return False
|
||||
lowered = error.lower()
|
||||
return any(pat in lowered for pat in _RETRYABLE_ERROR_PATTERNS)
|
||||
|
||||
async def _send_with_retry(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Any = None,
|
||||
max_retries: int = 2,
|
||||
base_delay: float = 2.0,
|
||||
) -> "SendResult":
|
||||
"""
|
||||
Send a message with automatic retry for transient network errors.
|
||||
|
||||
On permanent failures (e.g. formatting / permission errors) falls back
|
||||
to a plain-text version before giving up. If all attempts fail due to
|
||||
network errors, sends the user a brief delivery-failure notice so they
|
||||
know to retry rather than waiting indefinitely.
|
||||
"""
|
||||
|
||||
result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if result.success:
|
||||
return result
|
||||
|
||||
error_str = result.error or ""
|
||||
is_network = result.retryable or self._is_retryable_error(error_str)
|
||||
|
||||
if is_network:
|
||||
# Retry with exponential backoff for transient errors
|
||||
for attempt in range(1, max_retries + 1):
|
||||
delay = base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1)
|
||||
logger.warning(
|
||||
"[%s] Send failed (attempt %d/%d, retrying in %.1fs): %s",
|
||||
self.name, attempt, max_retries, delay, error_str,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
if result.success:
|
||||
logger.info("[%s] Send succeeded on retry %d", self.name, attempt)
|
||||
return result
|
||||
error_str = result.error or ""
|
||||
if not (result.retryable or self._is_retryable_error(error_str)):
|
||||
break # error switched to non-transient — fall through to plain-text fallback
|
||||
else:
|
||||
# All retries exhausted (loop completed without break) — notify user
|
||||
logger.error("[%s] Failed to deliver response after %d retries: %s", self.name, max_retries, error_str)
|
||||
notice = (
|
||||
"\u26a0\ufe0f Message delivery failed after multiple attempts. "
|
||||
"Please try again \u2014 your request was processed but the response could not be sent."
|
||||
)
|
||||
try:
|
||||
await self.send(chat_id=chat_id, content=notice, reply_to=reply_to, metadata=metadata)
|
||||
except Exception as notify_err:
|
||||
logger.debug("[%s] Could not send delivery-failure notice: %s", self.name, notify_err)
|
||||
return result
|
||||
|
||||
# Non-network / post-retry formatting failure: try plain text as fallback
|
||||
logger.warning("[%s] Send failed: %s — trying plain-text fallback", self.name, error_str)
|
||||
fallback_result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{content[:3500]}",
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
logger.error("[%s] Fallback send also failed: %s", self.name, fallback_result.error)
|
||||
return fallback_result
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
@@ -982,26 +1106,13 @@ class BasePlatformAdapter(ABC):
|
||||
# Send the text portion
|
||||
if text_content:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
result = await self._send_with_retry(
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
|
||||
@@ -2096,6 +2096,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if pending_text_injection:
|
||||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||||
|
||||
# Defense-in-depth: prevent empty user messages from entering session
|
||||
# (can happen when user sends @mention-only with no other text)
|
||||
if not event_text or not event_text.strip():
|
||||
event_text = "(The user sent a message with no text content)"
|
||||
|
||||
event = MessageEvent(
|
||||
text=event_text,
|
||||
message_type=msg_type,
|
||||
|
||||
@@ -551,9 +551,20 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
"""Continuously sync with the homeserver."""
|
||||
import nio
|
||||
|
||||
while not self._closing:
|
||||
try:
|
||||
await self._client.sync(timeout=30000)
|
||||
resp = await self._client.sync(timeout=30000)
|
||||
if isinstance(resp, nio.SyncError):
|
||||
if self._closing:
|
||||
return
|
||||
logger.warning(
|
||||
"Matrix: sync returned %s: %s — retrying in 5s",
|
||||
type(resp).__name__,
|
||||
getattr(resp, "message", resp),
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
|
||||
@@ -407,18 +407,38 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
kind: str = "file",
|
||||
) -> SendResult:
|
||||
"""Download a URL and upload it as a file attachment."""
|
||||
import asyncio
|
||||
import aiohttp
|
||||
try:
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 400:
|
||||
# Fall back to sending the URL as text.
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
file_data = await resp.read()
|
||||
ct = resp.content_type or "application/octet-stream"
|
||||
# Derive filename from URL.
|
||||
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
|
||||
except Exception as exc:
|
||||
logger.warning("Mattermost: failed to download %s: %s", url, exc)
|
||||
|
||||
last_exc = None
|
||||
file_data = None
|
||||
ct = "application/octet-stream"
|
||||
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 500 or resp.status == 429:
|
||||
if attempt < 2:
|
||||
logger.debug("Mattermost download retry %d/2 for %s (status %d)",
|
||||
attempt + 1, url[:80], resp.status)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
if resp.status >= 400:
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
file_data = await resp.read()
|
||||
ct = resp.content_type or "application/octet-stream"
|
||||
break
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||
last_exc = exc
|
||||
if attempt < 2:
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
logger.warning("Mattermost: failed to download %s after %d attempts: %s", url, attempt + 1, exc)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
if file_data is None:
|
||||
logger.warning("Mattermost: download returned no data for %s", url)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
file_id = await self._upload_file(chat_id, file_data, fname, ct)
|
||||
|
||||
@@ -279,6 +279,12 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# SSE keepalive comments (":") prove the connection
|
||||
# is alive — update activity so the health monitor
|
||||
# doesn't report false idle warnings.
|
||||
if line.startswith(":"):
|
||||
self._last_sse_activity = time.time()
|
||||
continue
|
||||
# Parse SSE data lines
|
||||
if line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
|
||||
@@ -819,33 +819,65 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _download_slack_file(self, url: str, ext: str, audio: bool = False) -> str:
|
||||
"""Download a Slack file using the bot token for auth."""
|
||||
"""Download a Slack file using the bot token for auth, with retry."""
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
last_exc = None
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < 2:
|
||||
logger.debug("Slack file download retry %d/2 for %s: %s",
|
||||
attempt + 1, url[:80], exc)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str) -> bytes:
|
||||
"""Download a Slack file and return raw bytes."""
|
||||
"""Download a Slack file and return raw bytes, with retry."""
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
last_exc = None
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < 2:
|
||||
logger.debug("Slack file download retry %d/2 for %s: %s",
|
||||
attempt + 1, url[:80], exc)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
|
||||
146
gateway/run.py
146
gateway/run.py
@@ -573,6 +573,10 @@ class GatewayRunner:
|
||||
session_id=old_session_id,
|
||||
honcho_session_key=honcho_session_key,
|
||||
)
|
||||
# Fully silence the flush agent — quiet_mode only suppresses init
|
||||
# messages; tool call output still leaks to the terminal through
|
||||
# _safe_print → _print_fn. Set a no-op to prevent that.
|
||||
tmp_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
# Build conversation history from transcript
|
||||
msgs = [
|
||||
@@ -954,12 +958,20 @@ class GatewayRunner:
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "EMAIL_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(
|
||||
os.getenv(v, "").lower() in ("true", "1", "yes")
|
||||
for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS",
|
||||
"WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS",
|
||||
"SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS",
|
||||
"SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS")
|
||||
)
|
||||
if not _any_allowlist and not _allow_all:
|
||||
logger.warning(
|
||||
"No user allowlists configured. All unauthorized users will be denied. "
|
||||
@@ -1970,6 +1982,12 @@ class GatewayRunner:
|
||||
f"Use /resume to browse and restore a previous session.\n"
|
||||
f"Adjust reset timing in config.yaml under session_reset."
|
||||
)
|
||||
try:
|
||||
session_info = self._format_session_info()
|
||||
if session_info:
|
||||
notice = f"{notice}\n\n{session_info}"
|
||||
except Exception:
|
||||
pass
|
||||
await adapter.send(
|
||||
source.chat_id, notice,
|
||||
metadata=getattr(event, 'metadata', None),
|
||||
@@ -2175,6 +2193,7 @@ class GatewayRunner:
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
_hyg_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
_compressed, _ = await loop.run_in_executor(
|
||||
@@ -2736,6 +2755,85 @@ class GatewayRunner:
|
||||
# Clear session env
|
||||
self._clear_session_env()
|
||||
|
||||
def _format_session_info(self) -> str:
|
||||
"""Resolve current model config and return a formatted info block.
|
||||
|
||||
Surfaces model, provider, context length, and endpoint so gateway
|
||||
users can immediately see if context detection went wrong (e.g.
|
||||
local models falling to the 128K default).
|
||||
"""
|
||||
from agent.model_metadata import get_model_context_length, DEFAULT_FALLBACK_CONTEXT
|
||||
|
||||
model = _resolve_gateway_model()
|
||||
config_context_length = None
|
||||
provider = None
|
||||
base_url = None
|
||||
api_key = None
|
||||
|
||||
try:
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
import yaml as _info_yaml
|
||||
with open(cfg_path, encoding="utf-8") as f:
|
||||
data = _info_yaml.safe_load(f) or {}
|
||||
model_cfg = data.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
raw_ctx = model_cfg.get("context_length")
|
||||
if raw_ctx is not None:
|
||||
try:
|
||||
config_context_length = int(raw_ctx)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
provider = model_cfg.get("provider") or None
|
||||
base_url = model_cfg.get("base_url") or None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Resolve runtime credentials for probing
|
||||
try:
|
||||
runtime = _resolve_runtime_agent_kwargs()
|
||||
provider = provider or runtime.get("provider")
|
||||
base_url = base_url or runtime.get("base_url")
|
||||
api_key = runtime.get("api_key")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
context_length = get_model_context_length(
|
||||
model,
|
||||
base_url=base_url or "",
|
||||
api_key=api_key or "",
|
||||
config_context_length=config_context_length,
|
||||
provider=provider or "",
|
||||
)
|
||||
|
||||
# Format context source hint
|
||||
if config_context_length is not None:
|
||||
ctx_source = "config"
|
||||
elif context_length == DEFAULT_FALLBACK_CONTEXT:
|
||||
ctx_source = "default — set model.context_length in config to override"
|
||||
else:
|
||||
ctx_source = "detected"
|
||||
|
||||
# Format context length for display
|
||||
if context_length >= 1_000_000:
|
||||
ctx_display = f"{context_length / 1_000_000:.1f}M"
|
||||
elif context_length >= 1_000:
|
||||
ctx_display = f"{context_length // 1_000}K"
|
||||
else:
|
||||
ctx_display = str(context_length)
|
||||
|
||||
lines = [
|
||||
f"◆ Model: `{model}`",
|
||||
f"◆ Provider: {provider or 'openrouter'}",
|
||||
f"◆ Context: {ctx_display} tokens ({ctx_source})",
|
||||
]
|
||||
|
||||
# Show endpoint for local/custom setups
|
||||
if base_url and ("localhost" in base_url or "127.0.0.1" in base_url or "0.0.0.0" in base_url):
|
||||
lines.append(f"◆ Endpoint: {base_url}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_reset_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /new or /reset command."""
|
||||
source = event.source
|
||||
@@ -2776,12 +2874,22 @@ class GatewayRunner:
|
||||
"session_key": session_key,
|
||||
})
|
||||
|
||||
# Resolve session config info to surface to the user
|
||||
try:
|
||||
session_info = self._format_session_info()
|
||||
except Exception:
|
||||
session_info = ""
|
||||
|
||||
if new_entry:
|
||||
return "✨ Session reset! I've started fresh with no memory of our previous conversation."
|
||||
header = "✨ Session reset! Starting fresh."
|
||||
else:
|
||||
# No existing session, just create one
|
||||
self.session_store.get_or_create_session(source, force_new=True)
|
||||
return "✨ New session started!"
|
||||
header = "✨ New session started!"
|
||||
|
||||
if session_info:
|
||||
return f"{header}\n\n{session_info}"
|
||||
return header
|
||||
|
||||
async def _handle_status_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /status command."""
|
||||
@@ -3885,6 +3993,7 @@ class GatewayRunner:
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
tmp_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
compressed, _ = await loop.run_in_executor(
|
||||
@@ -4799,9 +4908,14 @@ class GatewayRunner:
|
||||
enabled_toolsets = sorted(_get_platform_tools(user_config, platform_key))
|
||||
|
||||
# Tool progress mode from config.yaml: "all", "new", "verbose", "off"
|
||||
# Falls back to env vars for backward compatibility
|
||||
# Falls back to env vars for backward compatibility.
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise before
|
||||
# the `or` chain so it doesn't silently fall through to "all".
|
||||
_raw_tp = user_config.get("display", {}).get("tool_progress")
|
||||
if _raw_tp is False:
|
||||
_raw_tp = "off"
|
||||
progress_mode = (
|
||||
user_config.get("display", {}).get("tool_progress")
|
||||
_raw_tp
|
||||
or os.getenv("HERMES_TOOL_PROGRESS_MODE")
|
||||
or "all"
|
||||
)
|
||||
@@ -5128,7 +5242,25 @@ class GatewayRunner:
|
||||
agent.stream_delta_callback = _stream_delta_cb
|
||||
agent.status_callback = _status_callback_sync
|
||||
agent.reasoning_config = reasoning_config
|
||||
|
||||
|
||||
# Background review delivery — send "💾 Memory updated" etc. to user
|
||||
def _bg_review_send(message: str) -> None:
|
||||
if not _status_adapter:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
_status_adapter.send(
|
||||
_status_chat_id,
|
||||
message,
|
||||
metadata=_status_thread_metadata,
|
||||
),
|
||||
_loop_for_step,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("background_review_callback error: %s", _e)
|
||||
|
||||
agent.background_review_callback = _bg_review_send
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
agent_holder[0] = agent
|
||||
# Capture the full tool definitions for transcript logging
|
||||
|
||||
@@ -762,14 +762,16 @@ class SessionStore:
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = _now()
|
||||
entry.input_tokens += input_tokens
|
||||
entry.output_tokens += output_tokens
|
||||
entry.cache_read_tokens += cache_read_tokens
|
||||
entry.cache_write_tokens += cache_write_tokens
|
||||
# Direct assignment — the gateway receives cumulative totals
|
||||
# from the cached agent, not per-call deltas.
|
||||
entry.input_tokens = input_tokens
|
||||
entry.output_tokens = output_tokens
|
||||
entry.cache_read_tokens = cache_read_tokens
|
||||
entry.cache_write_tokens = cache_write_tokens
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
if estimated_cost_usd is not None:
|
||||
entry.estimated_cost_usd += estimated_cost_usd
|
||||
entry.estimated_cost_usd = estimated_cost_usd
|
||||
if cost_status:
|
||||
entry.cost_status = cost_status
|
||||
entry.total_tokens = (
|
||||
@@ -783,7 +785,7 @@ class SessionStore:
|
||||
|
||||
if self._db and db_session_id:
|
||||
try:
|
||||
self._db.update_token_counts(
|
||||
self._db.set_token_counts(
|
||||
db_session_id,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
@@ -795,6 +797,7 @@ class SessionStore:
|
||||
billing_provider=provider,
|
||||
billing_base_url=base_url,
|
||||
model=model,
|
||||
absolute=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
@@ -955,13 +958,17 @@ class SessionStore:
|
||||
try:
|
||||
self._db.clear_messages(session_id)
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
self._db.append_message(
|
||||
session_id=session_id,
|
||||
role=msg.get("role", "unknown"),
|
||||
role=role,
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
reasoning=msg.get("reasoning") if role == "assistant" else None,
|
||||
reasoning_details=msg.get("reasoning_details") if role == "assistant" else None,
|
||||
codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
|
||||
@@ -264,6 +264,7 @@ DEFAULT_CONFIG = {
|
||||
"compact": False,
|
||||
"personality": "kawaii",
|
||||
"resume_display": "full",
|
||||
"busy_input_mode": "interrupt",
|
||||
"bell_on_complete": False,
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
|
||||
@@ -2968,6 +2968,95 @@ def setup_tools(config: dict, first_install: bool = False):
|
||||
tools_command(first_install=first_install, config=config)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Post-Migration Section Skip Logic
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]:
|
||||
"""Return a short summary if a setup section is already configured, else None.
|
||||
|
||||
Used after OpenClaw migration to detect which sections can be skipped.
|
||||
``get_env_value`` is the module-level import from hermes_cli.config
|
||||
so that test patches on ``setup_mod.get_env_value`` take effect.
|
||||
"""
|
||||
if section_key == "model":
|
||||
has_key = bool(
|
||||
get_env_value("OPENROUTER_API_KEY")
|
||||
or get_env_value("OPENAI_API_KEY")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
)
|
||||
if not has_key:
|
||||
# Check for OAuth providers
|
||||
try:
|
||||
from hermes_cli.auth import get_active_provider
|
||||
if get_active_provider():
|
||||
has_key = True
|
||||
except Exception:
|
||||
pass
|
||||
if not has_key:
|
||||
return None
|
||||
model = config.get("model")
|
||||
if isinstance(model, str) and model.strip():
|
||||
return model.strip()
|
||||
if isinstance(model, dict):
|
||||
return str(model.get("default") or model.get("model") or "configured")
|
||||
return "configured"
|
||||
|
||||
elif section_key == "terminal":
|
||||
backend = config.get("terminal", {}).get("backend", "local")
|
||||
return f"backend: {backend}"
|
||||
|
||||
elif section_key == "agent":
|
||||
max_turns = config.get("agent", {}).get("max_turns", 90)
|
||||
return f"max turns: {max_turns}"
|
||||
|
||||
elif section_key == "gateway":
|
||||
platforms = []
|
||||
if get_env_value("TELEGRAM_BOT_TOKEN"):
|
||||
platforms.append("Telegram")
|
||||
if get_env_value("DISCORD_BOT_TOKEN"):
|
||||
platforms.append("Discord")
|
||||
if get_env_value("SLACK_BOT_TOKEN"):
|
||||
platforms.append("Slack")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("SIGNAL_ACCOUNT"):
|
||||
platforms.append("Signal")
|
||||
if platforms:
|
||||
return ", ".join(platforms)
|
||||
return None # No platforms configured — section must run
|
||||
|
||||
elif section_key == "tools":
|
||||
tools = []
|
||||
if get_env_value("ELEVENLABS_API_KEY"):
|
||||
tools.append("TTS/ElevenLabs")
|
||||
if get_env_value("BROWSERBASE_API_KEY"):
|
||||
tools.append("Browser")
|
||||
if get_env_value("FIRECRAWL_API_KEY"):
|
||||
tools.append("Firecrawl")
|
||||
if tools:
|
||||
return ", ".join(tools)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _skip_configured_section(
|
||||
config: dict, section_key: str, label: str
|
||||
) -> bool:
|
||||
"""Show an already-configured section summary and offer to skip.
|
||||
|
||||
Returns True if the user chose to skip, False if the section should run.
|
||||
"""
|
||||
summary = _get_section_config_summary(config, section_key)
|
||||
if not summary:
|
||||
return False
|
||||
print()
|
||||
print_success(f" {label}: {summary}")
|
||||
return not prompt_yes_no(f" Reconfigure {label.lower()}?", default=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenClaw Migration
|
||||
# =============================================================================
|
||||
@@ -3039,7 +3128,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=True,
|
||||
workspace_target=None,
|
||||
overwrite=False,
|
||||
overwrite=True,
|
||||
migrate_secrets=True,
|
||||
output_dir=None,
|
||||
selected_options=selected,
|
||||
@@ -3195,6 +3284,8 @@ def run_setup_wizard(args):
|
||||
)
|
||||
)
|
||||
|
||||
migration_ran = False
|
||||
|
||||
if is_existing:
|
||||
# ── Returning User Menu ──
|
||||
print()
|
||||
@@ -3264,7 +3355,8 @@ def run_setup_wizard(args):
|
||||
return
|
||||
|
||||
# Offer OpenClaw migration before configuration begins
|
||||
if _offer_openclaw_migration(hermes_home):
|
||||
migration_ran = _offer_openclaw_migration(hermes_home)
|
||||
if migration_ran:
|
||||
# Reload config in case migration wrote to it
|
||||
config = load_config()
|
||||
|
||||
@@ -3277,20 +3369,31 @@ def run_setup_wizard(args):
|
||||
print()
|
||||
print_info("You can edit these files directly or use 'hermes config edit'")
|
||||
|
||||
if migration_ran:
|
||||
print()
|
||||
print_info("Settings were imported from OpenClaw.")
|
||||
print_info("Each section below will show what was imported — press Enter to keep,")
|
||||
print_info("or choose to reconfigure if needed.")
|
||||
|
||||
# Section 1: Model & Provider
|
||||
setup_model_provider(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "model", "Model & Provider")):
|
||||
setup_model_provider(config)
|
||||
|
||||
# Section 2: Terminal Backend
|
||||
setup_terminal_backend(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "terminal", "Terminal Backend")):
|
||||
setup_terminal_backend(config)
|
||||
|
||||
# Section 3: Agent Settings
|
||||
setup_agent_settings(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "agent", "Agent Settings")):
|
||||
setup_agent_settings(config)
|
||||
|
||||
# Section 4: Messaging Platforms
|
||||
setup_gateway(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "gateway", "Messaging Platforms")):
|
||||
setup_gateway(config)
|
||||
|
||||
# Section 5: Tools
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
if not (migration_ran and _skip_configured_section(config, "tools", "Tools")):
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
|
||||
# Save and show summary
|
||||
save_config(config)
|
||||
|
||||
116
hermes_state.py
116
hermes_state.py
@@ -284,6 +284,15 @@ class SessionDB:
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def reopen_session(self, session_id: str) -> None:
|
||||
"""Clear ended_at/end_reason so a session can be resumed."""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET ended_at = NULL, end_reason = NULL WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
||||
"""Store the full assembled system prompt snapshot."""
|
||||
with self._lock:
|
||||
@@ -310,11 +319,39 @@ class SessionDB:
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
absolute: bool = False,
|
||||
) -> None:
|
||||
"""Increment token counters and backfill model if not already set."""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
"""Update token counters and backfill model if not already set.
|
||||
|
||||
When *absolute* is False (default), values are **incremented** — use
|
||||
this for per-API-call deltas (CLI path).
|
||||
|
||||
When *absolute* is True, values are **set directly** — use this when
|
||||
the caller already holds cumulative totals (gateway path, where the
|
||||
cached agent accumulates across messages).
|
||||
"""
|
||||
if absolute:
|
||||
sql = """UPDATE sessions SET
|
||||
input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = COALESCE(?, 0),
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?"""
|
||||
else:
|
||||
sql = """UPDATE sessions SET
|
||||
input_tokens = input_tokens + ?,
|
||||
output_tokens = output_tokens + ?,
|
||||
cache_read_tokens = cache_read_tokens + ?,
|
||||
@@ -332,7 +369,10 @@ class SessionDB:
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?""",
|
||||
WHERE id = ?"""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
sql,
|
||||
(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
@@ -375,6 +415,72 @@ class SessionDB:
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def set_token_counts(
|
||||
self,
|
||||
session_id: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
model: str = None,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
actual_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
pricing_version: Optional[str] = None,
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Set token counters to absolute values (not increment).
|
||||
|
||||
Use this when the caller provides cumulative totals from a completed
|
||||
conversation run (e.g. the gateway, where the cached agent's
|
||||
session_prompt_tokens already reflects the running total).
|
||||
"""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = ?,
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?""",
|
||||
(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
actual_cost_usd,
|
||||
cost_status,
|
||||
cost_source,
|
||||
pricing_version,
|
||||
billing_provider,
|
||||
billing_base_url,
|
||||
billing_mode,
|
||||
model,
|
||||
session_id,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a session by ID."""
|
||||
with self._lock:
|
||||
|
||||
@@ -55,7 +55,7 @@ honcho = ["honcho-ai>=2.0.1,<3"]
|
||||
mcp = ["mcp>=1.2.0,<2"]
|
||||
homeassistant = ["aiohttp>=3.9.0,<4"]
|
||||
sms = ["aiohttp>=3.9.0,<4"]
|
||||
acp = ["agent-client-protocol>=0.8.1,<1.0"]
|
||||
acp = ["agent-client-protocol>=0.8.1,<0.9"]
|
||||
dingtalk = ["dingtalk-stream>=0.1.0,<1"]
|
||||
rl = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git",
|
||||
|
||||
53
run_agent.py
53
run_agent.py
@@ -486,6 +486,7 @@ class AIAgent:
|
||||
# instead of going directly to stdout where patch_stdout's StdoutProxy
|
||||
# would mangle the escape sequences. None = use builtins.print.
|
||||
self._print_fn = None
|
||||
self.background_review_callback = None # Optional sync callback for gateway delivery
|
||||
self.skip_context_files = skip_context_files
|
||||
self.pass_session_id = pass_session_id
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
@@ -1525,6 +1526,12 @@ class AIAgent:
|
||||
if actions:
|
||||
summary = " · ".join(dict.fromkeys(actions))
|
||||
self._safe_print(f" 💾 {summary}")
|
||||
_bg_cb = self.background_review_callback
|
||||
if _bg_cb:
|
||||
try:
|
||||
_bg_cb(f"💾 {summary}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Background memory/skill review failed: %s", e)
|
||||
@@ -2048,6 +2055,23 @@ class AIAgent:
|
||||
msg["content"] = self._clean_session_content(msg["content"])
|
||||
cleaned.append(msg)
|
||||
|
||||
# Guard: never overwrite a larger session log with fewer messages.
|
||||
# This protects against data loss when --resume loads a session whose
|
||||
# messages weren't fully written to SQLite — the resumed agent starts
|
||||
# with partial history and would otherwise clobber the full JSON log.
|
||||
if self.session_log_file.exists():
|
||||
try:
|
||||
existing = json.loads(self.session_log_file.read_text(encoding="utf-8"))
|
||||
existing_count = existing.get("message_count", len(existing.get("messages", [])))
|
||||
if existing_count > len(cleaned):
|
||||
logging.debug(
|
||||
"Skipping session log overwrite: existing has %d messages, current has %d",
|
||||
existing_count, len(cleaned),
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass # corrupted existing file — allow the overwrite
|
||||
|
||||
entry = {
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
@@ -4127,6 +4151,25 @@ class AIAgent:
|
||||
or is_native_anthropic
|
||||
)
|
||||
|
||||
# Update context compressor limits for the fallback model.
|
||||
# Without this, compression decisions use the primary model's
|
||||
# context window (e.g. 200K) instead of the fallback's (e.g. 32K),
|
||||
# causing oversized sessions to overflow the fallback.
|
||||
if hasattr(self, 'context_compressor') and self.context_compressor:
|
||||
from agent.model_metadata import get_model_context_length
|
||||
fb_context_length = get_model_context_length(
|
||||
self.model, base_url=self.base_url,
|
||||
api_key=self.api_key, provider=self.provider,
|
||||
)
|
||||
self.context_compressor.model = self.model
|
||||
self.context_compressor.base_url = self.base_url
|
||||
self.context_compressor.api_key = self.api_key
|
||||
self.context_compressor.provider = self.provider
|
||||
self.context_compressor.context_length = fb_context_length
|
||||
self.context_compressor.threshold_tokens = int(
|
||||
fb_context_length * self.context_compressor.threshold_percent
|
||||
)
|
||||
|
||||
self._emit_status(
|
||||
f"🔄 Primary model failed — switching to fallback: "
|
||||
f"{fb_model} via {fb_provider}"
|
||||
@@ -5080,7 +5123,7 @@ class AIAgent:
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots')
|
||||
spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
|
||||
try:
|
||||
@@ -5121,7 +5164,7 @@ class AIAgent:
|
||||
# Print cute message per tool
|
||||
if self.quiet_mode:
|
||||
cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result)
|
||||
print(f" {cute_msg}")
|
||||
self._safe_print(f" {cute_msg}")
|
||||
elif not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s")
|
||||
@@ -5306,7 +5349,7 @@ class AIAgent:
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots')
|
||||
spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
self._delegate_spinner = spinner
|
||||
_delegate_result = None
|
||||
@@ -5336,7 +5379,7 @@ class AIAgent:
|
||||
preview = _build_tool_preview(function_name, function_args) or function_name
|
||||
if len(preview) > 30:
|
||||
preview = preview[:27] + "..."
|
||||
spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots')
|
||||
spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
_spinner_result = None
|
||||
try:
|
||||
@@ -6019,7 +6062,7 @@ class AIAgent:
|
||||
# Raw KawaiiSpinner only when no streaming consumers
|
||||
# (would conflict with streamed token output)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type, print_fn=self._print_fn)
|
||||
thinking_spinner.start()
|
||||
|
||||
# Log request details if verbose
|
||||
|
||||
46
tests/gateway/test_allowlist_startup_check.py
Normal file
46
tests/gateway/test_allowlist_startup_check.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Tests for the startup allowlist warning check in gateway/run.py."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _would_warn():
|
||||
"""Replicate the startup allowlist warning logic. Returns True if warning fires."""
|
||||
_any_allowlist = any(
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(
|
||||
os.getenv(v, "").lower() in ("true", "1", "yes")
|
||||
for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS",
|
||||
"WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS",
|
||||
"SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS",
|
||||
"SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS")
|
||||
)
|
||||
return not _any_allowlist and not _allow_all
|
||||
|
||||
|
||||
class TestAllowlistStartupCheck:
|
||||
|
||||
def test_no_config_emits_warning(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
assert _would_warn() is True
|
||||
|
||||
def test_signal_group_allowed_users_suppresses_warning(self):
|
||||
with patch.dict(os.environ, {"SIGNAL_GROUP_ALLOWED_USERS": "user1"}, clear=True):
|
||||
assert _would_warn() is False
|
||||
|
||||
def test_telegram_allow_all_users_suppresses_warning(self):
|
||||
with patch.dict(os.environ, {"TELEGRAM_ALLOW_ALL_USERS": "true"}, clear=True):
|
||||
assert _would_warn() is False
|
||||
|
||||
def test_gateway_allow_all_users_suppresses_warning(self):
|
||||
with patch.dict(os.environ, {"GATEWAY_ALLOW_ALL_USERS": "yes"}, clear=True):
|
||||
assert _would_warn() is False
|
||||
@@ -69,7 +69,8 @@ class TestApiServerPlatformConfig:
|
||||
|
||||
class TestApiServerAdapterToolset:
|
||||
@patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True)
|
||||
def test_create_agent_uses_api_server_toolset(self):
|
||||
def test_create_agent_reads_config_toolsets(self):
|
||||
"""API server resolves toolsets from config like all other platforms."""
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
@@ -77,17 +78,52 @@ class TestApiServerAdapterToolset:
|
||||
|
||||
with patch("gateway.run._resolve_runtime_agent_kwargs") as mock_kwargs, \
|
||||
patch("gateway.run._resolve_gateway_model") as mock_model, \
|
||||
patch("gateway.run._load_gateway_config") as mock_config, \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
|
||||
mock_kwargs.return_value = {"api_key": "test-key", "base_url": None,
|
||||
"provider": None, "api_mode": None,
|
||||
"command": None, "args": []}
|
||||
mock_model.return_value = "test/model"
|
||||
# No platform_toolsets override — should fall back to hermes-api-server default
|
||||
mock_config.return_value = {}
|
||||
mock_agent_cls.return_value = MagicMock()
|
||||
|
||||
adapter._create_agent()
|
||||
|
||||
mock_agent_cls.assert_called_once()
|
||||
call_kwargs = mock_agent_cls.call_args
|
||||
assert call_kwargs.kwargs.get("enabled_toolsets") == ["hermes-api-server"]
|
||||
toolsets = call_kwargs.kwargs.get("enabled_toolsets")
|
||||
assert isinstance(toolsets, list)
|
||||
assert len(toolsets) > 0
|
||||
assert call_kwargs.kwargs.get("platform") == "api_server"
|
||||
|
||||
@patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True)
|
||||
def test_create_agent_respects_config_override(self):
|
||||
"""User can override API server toolsets via platform_toolsets in config.yaml."""
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
adapter = APIServerAdapter(PlatformConfig())
|
||||
|
||||
with patch("gateway.run._resolve_runtime_agent_kwargs") as mock_kwargs, \
|
||||
patch("gateway.run._resolve_gateway_model") as mock_model, \
|
||||
patch("gateway.run._load_gateway_config") as mock_config, \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
|
||||
mock_kwargs.return_value = {"api_key": "test-key", "base_url": None,
|
||||
"provider": None, "api_mode": None,
|
||||
"command": None, "args": []}
|
||||
mock_model.return_value = "test/model"
|
||||
# User overrides with just web and terminal
|
||||
mock_config.return_value = {
|
||||
"platform_toolsets": {"api_server": ["web", "terminal"]}
|
||||
}
|
||||
mock_agent_cls.return_value = MagicMock()
|
||||
|
||||
adapter._create_agent()
|
||||
|
||||
mock_agent_cls.assert_called_once()
|
||||
call_kwargs = mock_agent_cls.call_args
|
||||
toolsets = call_kwargs.kwargs.get("enabled_toolsets")
|
||||
assert sorted(toolsets) == ["terminal", "web"]
|
||||
|
||||
@@ -7,11 +7,21 @@ Verifies that:
|
||||
3. The flush still works normally when memory files don't exist
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_dotenv(monkeypatch):
|
||||
"""gateway.run imports dotenv at module level; stub it so tests run without the package."""
|
||||
fake = types.ModuleType("dotenv")
|
||||
fake.load_dotenv = lambda *a, **kw: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
@@ -57,105 +67,151 @@ class TestCronSessionBypass:
|
||||
runner.session_store.load_transcript.assert_called_once_with("session_abc123")
|
||||
|
||||
|
||||
def _make_flush_context(monkeypatch, memory_dir=None):
|
||||
"""Return (runner, tmp_agent, fake_run_agent) with run_agent mocked in sys.modules."""
|
||||
tmp_agent = MagicMock()
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = MagicMock(return_value=tmp_agent)
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
return runner, tmp_agent, memory_dir
|
||||
|
||||
|
||||
class TestMemoryInjection:
|
||||
"""The flush prompt should include current memory state from disk."""
|
||||
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path):
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path, monkeypatch):
|
||||
"""When memory files exist, their content appears in the flush prompt."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
|
||||
(memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch, memory_dir)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Intercept `from tools.memory_tool import MEMORY_DIR` inside the function
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_123")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
call_kwargs = tmp_agent.run_conversation.call_args.kwargs
|
||||
flush_prompt = call_kwargs.get("user_message", "")
|
||||
|
||||
# Verify both memory sections appear in the prompt
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
|
||||
assert "Agent knows Python" in flush_prompt
|
||||
assert "User prefers dark mode" in flush_prompt
|
||||
assert "Name: Alice" in flush_prompt
|
||||
assert "Timezone: PST" in flush_prompt
|
||||
# Verify the stale-overwrite warning is present
|
||||
assert "Do NOT overwrite or remove entries" in flush_prompt
|
||||
assert "current live state of memory" in flush_prompt
|
||||
|
||||
def test_flush_works_without_memory_files(self, tmp_path):
|
||||
def test_flush_works_without_memory_files(self, tmp_path, monkeypatch):
|
||||
"""When no memory files exist, flush still runs without the guard."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
empty_dir = tmp_path / "no_memories"
|
||||
empty_dir.mkdir()
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=empty_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_456")
|
||||
|
||||
# Should still run, just without the memory guard section
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "Do NOT overwrite or remove entries" not in flush_prompt
|
||||
assert "Review the conversation above" in flush_prompt
|
||||
|
||||
def test_empty_memory_files_no_injection(self, tmp_path):
|
||||
def test_empty_memory_files_no_injection(self, tmp_path, monkeypatch):
|
||||
"""Empty memory files should not trigger the guard section."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("")
|
||||
(memory_dir / "USER.md").write_text(" \n ") # whitespace only
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_789")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
# No memory content → no guard section
|
||||
assert "current live state of memory" not in flush_prompt
|
||||
|
||||
|
||||
class TestFlushAgentSilenced:
|
||||
"""The flush agent must not produce any terminal output."""
|
||||
|
||||
def test_print_fn_set_to_noop(self, tmp_path, monkeypatch):
|
||||
"""_print_fn on the flush agent must be a no-op so tool output never leaks."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
captured_agent = {}
|
||||
|
||||
def _fake_ai_agent(*args, **kwargs):
|
||||
agent = MagicMock()
|
||||
captured_agent["instance"] = agent
|
||||
return agent
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _fake_ai_agent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=tmp_path)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_silent")
|
||||
|
||||
agent = captured_agent["instance"]
|
||||
assert agent._print_fn is not None, "_print_fn should be overridden to suppress output"
|
||||
# Confirm it is callable and produces no output (no exception)
|
||||
agent._print_fn("should be silenced")
|
||||
|
||||
def test_kawaii_spinner_respects_print_fn(self):
|
||||
"""KawaiiSpinner must route all output through print_fn when supplied."""
|
||||
from agent.display import KawaiiSpinner
|
||||
|
||||
written = []
|
||||
spinner = KawaiiSpinner("test", print_fn=lambda *a, **kw: written.append(a))
|
||||
spinner._write("hello")
|
||||
assert written == [("hello",)], "spinner should route through print_fn"
|
||||
|
||||
# A no-op print_fn must produce no output to stdout
|
||||
import io, sys
|
||||
buf = io.StringIO()
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = buf
|
||||
try:
|
||||
silent_spinner = KawaiiSpinner("silent", print_fn=lambda *a, **kw: None)
|
||||
silent_spinner._write("should not appear")
|
||||
silent_spinner.stop("done")
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
assert buf.getvalue() == "", "no-op print_fn spinner must not write to stdout"
|
||||
|
||||
|
||||
class TestFlushPromptStructure:
|
||||
"""Verify the flush prompt retains its core instructions."""
|
||||
|
||||
def test_core_instructions_present(self):
|
||||
def test_core_instructions_present(self, monkeypatch):
|
||||
"""The flush prompt should still contain the original guidance."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Make the import fail gracefully so we test without memory files
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_struct")
|
||||
|
||||
558
tests/gateway/test_media_download_retry.py
Normal file
558
tests/gateway/test_media_download_retry.py
Normal file
@@ -0,0 +1,558 @@
|
||||
"""
|
||||
Tests for media download retry logic added in PR #2982.
|
||||
|
||||
Covers:
|
||||
- gateway/platforms/base.py: cache_image_from_url
|
||||
- gateway/platforms/slack.py: SlackAdapter._download_slack_file
|
||||
SlackAdapter._download_slack_file_bytes
|
||||
- gateway/platforms/mattermost.py: MattermostAdapter._send_url_as_file
|
||||
|
||||
All async tests use asyncio.run() directly — pytest-asyncio is not installed
|
||||
in this environment.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for building httpx exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_http_status_error(status_code: int) -> httpx.HTTPStatusError:
|
||||
request = httpx.Request("GET", "http://example.com/img.jpg")
|
||||
response = httpx.Response(status_code=status_code, request=request)
|
||||
return httpx.HTTPStatusError(
|
||||
f"HTTP {status_code}", request=request, response=response
|
||||
)
|
||||
|
||||
|
||||
def _make_timeout_error() -> httpx.TimeoutException:
|
||||
return httpx.TimeoutException("timed out")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cache_image_from_url (base.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheImageFromUrl:
|
||||
"""Tests for gateway.platforms.base.cache_image_from_url"""
|
||||
|
||||
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
|
||||
"""A clean 200 response caches the image and returns a path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"\xff\xd8\xff fake jpeg"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A timeout on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"image data"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_timeout_error(), fake_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A 429 response on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"image data"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_http_status_error(429), ok_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt raises after all retries are consumed."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
asyncio.run(run())
|
||||
|
||||
# 3 total calls: initial + 2 retries
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch):
|
||||
"""A 404 (non-retryable) is raised immediately without any retry."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_http_status_error(404))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
asyncio.run(run())
|
||||
|
||||
# Only 1 attempt, no sleep
|
||||
assert mock_client.get.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slack mock setup (mirrors existing test_slack.py approach)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_slack_mock():
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return
|
||||
slack_bolt = MagicMock()
|
||||
slack_bolt.async_app.AsyncApp = MagicMock
|
||||
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
|
||||
slack_sdk = MagicMock()
|
||||
slack_sdk.web.async_client.AsyncWebClient = MagicMock
|
||||
for name, mod in [
|
||||
("slack_bolt", slack_bolt),
|
||||
("slack_bolt.async_app", slack_bolt.async_app),
|
||||
("slack_bolt.adapter", slack_bolt.adapter),
|
||||
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
|
||||
("slack_bolt.adapter.socket_mode.async_handler",
|
||||
slack_bolt.adapter.socket_mode.async_handler),
|
||||
("slack_sdk", slack_sdk),
|
||||
("slack_sdk.web", slack_sdk.web),
|
||||
("slack_sdk.web.async_client", slack_sdk.web.async_client),
|
||||
]:
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_ensure_slack_mock()
|
||||
|
||||
import gateway.platforms.slack as _slack_mod # noqa: E402
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
from gateway.config import Platform, PlatformConfig # noqa: E402
|
||||
|
||||
|
||||
def _make_slack_adapter():
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake-token")
|
||||
adapter = SlackAdapter(config)
|
||||
adapter._app = MagicMock()
|
||||
adapter._app.client = AsyncMock()
|
||||
adapter._bot_user_id = "U_BOT"
|
||||
adapter._running = True
|
||||
return adapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SlackAdapter._download_slack_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSlackDownloadSlackFile:
|
||||
"""Tests for SlackAdapter._download_slack_file"""
|
||||
|
||||
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
|
||||
"""Successful download on first try returns a cached file path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"fake image bytes"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
return await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""Timeout on first attempt triggers retry; success on second."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"image bytes"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_timeout_error(), fake_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
return await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_raises_after_max_retries(self, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt eventually raises after 3 total tries."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
asyncio.run(run())
|
||||
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
def test_non_retryable_403_raises_immediately(self, tmp_path, monkeypatch):
|
||||
"""A 403 is not retried; it raises immediately."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_http_status_error(403))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
asyncio.run(run())
|
||||
|
||||
assert mock_client.get.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SlackAdapter._download_slack_file_bytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSlackDownloadSlackFileBytes:
|
||||
"""Tests for SlackAdapter._download_slack_file_bytes"""
|
||||
|
||||
def test_success_returns_bytes(self):
|
||||
"""Successful download returns raw bytes."""
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"raw bytes here"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
return await adapter._download_slack_file_bytes(
|
||||
"https://files.slack.com/file.bin"
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result == b"raw bytes here"
|
||||
|
||||
def test_retries_on_429_then_succeeds(self):
|
||||
"""429 on first attempt is retried; raw bytes returned on second."""
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"final bytes"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_http_status_error(429), ok_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._download_slack_file_bytes(
|
||||
"https://files.slack.com/file.bin"
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result == b"final bytes"
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries(self):
|
||||
"""Persistent timeouts raise after all 3 attempts are exhausted."""
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter._download_slack_file_bytes(
|
||||
"https://files.slack.com/file.bin"
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
asyncio.run(run())
|
||||
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MattermostAdapter._send_url_as_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_mm_adapter():
|
||||
"""Build a minimal MattermostAdapter with mocked internals."""
|
||||
from gateway.platforms.mattermost import MattermostAdapter
|
||||
config = PlatformConfig(
|
||||
enabled=True, token="mm-token-fake",
|
||||
extra={"url": "https://mm.example.com"},
|
||||
)
|
||||
adapter = MattermostAdapter(config)
|
||||
adapter._session = MagicMock()
|
||||
adapter._upload_file = AsyncMock(return_value="file-id-123")
|
||||
adapter._api_post = AsyncMock(return_value={"id": "post-id-abc"})
|
||||
adapter.send = AsyncMock(return_value=MagicMock(success=True))
|
||||
return adapter
|
||||
|
||||
|
||||
def _make_aiohttp_resp(status: int, content: bytes = b"file bytes",
|
||||
content_type: str = "image/jpeg"):
|
||||
"""Build a context-manager mock for an aiohttp response."""
|
||||
resp = MagicMock()
|
||||
resp.status = status
|
||||
resp.content_type = content_type
|
||||
resp.read = AsyncMock(return_value=content)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return resp
|
||||
|
||||
|
||||
class TestMattermostSendUrlAsFile:
|
||||
"""Tests for MattermostAdapter._send_url_as_file"""
|
||||
|
||||
def test_success_on_first_attempt(self):
|
||||
"""200 on first attempt → file uploaded and post created."""
|
||||
adapter = _make_mm_adapter()
|
||||
resp = _make_aiohttp_resp(200)
|
||||
adapter._session.get = MagicMock(return_value=resp)
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", "caption", None
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result.success
|
||||
adapter._upload_file.assert_called_once()
|
||||
adapter._api_post.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self):
|
||||
"""429 on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_429 = _make_aiohttp_resp(429)
|
||||
resp_200 = _make_aiohttp_resp(200)
|
||||
adapter._session.get = MagicMock(side_effect=[resp_429, resp_200])
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", mock_sleep):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result.success
|
||||
assert adapter._session.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_500_then_succeeds(self):
|
||||
"""5xx on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_500 = _make_aiohttp_resp(500)
|
||||
resp_200 = _make_aiohttp_resp(200)
|
||||
adapter._session.get = MagicMock(side_effect=[resp_500, resp_200])
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result.success
|
||||
assert adapter._session.get.call_count == 2
|
||||
|
||||
def test_falls_back_to_text_after_max_retries_on_5xx(self):
|
||||
"""Three consecutive 500s exhaust retries; falls back to send() with URL text."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_500 = _make_aiohttp_resp(500)
|
||||
adapter._session.get = MagicMock(return_value=resp_500)
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", "my caption", None
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_falls_back_on_client_error(self):
|
||||
"""aiohttp.ClientError on every attempt falls back to send() with URL."""
|
||||
import aiohttp
|
||||
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
error_resp = MagicMock()
|
||||
error_resp.__aenter__ = AsyncMock(
|
||||
side_effect=aiohttp.ClientConnectionError("connection refused")
|
||||
)
|
||||
error_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
adapter._session.get = MagicMock(return_value=error_resp)
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_non_retryable_404_falls_back_immediately(self):
|
||||
"""404 is non-retryable (< 500, != 429); send() is called right away."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_404 = _make_aiohttp_resp(404)
|
||||
adapter._session.get = MagicMock(return_value=resp_404)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", mock_sleep):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
# No sleep — fell back on first attempt
|
||||
mock_sleep.assert_not_called()
|
||||
assert adapter._session.get.call_count == 1
|
||||
231
tests/gateway/test_send_retry.py
Normal file
231
tests/gateway/test_send_retry.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
Tests for BasePlatformAdapter._send_with_retry and _is_retryable_error.
|
||||
|
||||
Verifies that:
|
||||
- Transient network errors trigger retry with backoff
|
||||
- Permanent errors fall back to plain-text immediately (no retry)
|
||||
- User receives a delivery-failure notice when all retries are exhausted
|
||||
- Successful sends on retry return success
|
||||
- SendResult.retryable flag is respected
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult, _RETRYABLE_ERROR_PATTERNS
|
||||
from gateway.platforms.base import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal concrete adapter for testing (no real network)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
cfg = PlatformConfig()
|
||||
super().__init__(cfg, Platform.TELEGRAM)
|
||||
self._send_results = [] # queue of SendResult to return per call
|
||||
self._send_calls = [] # record of (chat_id, content) sent
|
||||
|
||||
def _next_result(self) -> SendResult:
|
||||
if self._send_results:
|
||||
return self._send_results.pop(0)
|
||||
return SendResult(success=True, message_id="ok")
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None, **kwargs) -> SendResult:
|
||||
self._send_calls.append((chat_id, content))
|
||||
return self._next_result()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
pass
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None) -> None:
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"name": "test", "type": "direct", "chat_id": chat_id}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_retryable_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsRetryableError:
|
||||
def test_none_is_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error(None)
|
||||
|
||||
def test_empty_string_is_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error("")
|
||||
|
||||
@pytest.mark.parametrize("pattern", _RETRYABLE_ERROR_PATTERNS)
|
||||
def test_known_pattern_is_retryable(self, pattern):
|
||||
assert _StubAdapter._is_retryable_error(f"httpx.{pattern.title()}: connection dropped")
|
||||
|
||||
def test_permission_error_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error("Forbidden: bot was blocked by the user")
|
||||
|
||||
def test_bad_request_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error("Bad Request: can't parse entities")
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _StubAdapter._is_retryable_error("CONNECTERROR: host unreachable")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — success on first attempt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetrySuccess:
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_first_attempt(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [SendResult(success=True, message_id="123")]
|
||||
result = await adapter._send_with_retry("chat1", "hello")
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_message_id(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [SendResult(success=True, message_id="abc")]
|
||||
result = await adapter._send_with_retry("chat1", "hi")
|
||||
assert result.message_id == "abc"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — network error with successful retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryNetworkRetry:
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_connect_error_and_succeeds(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="httpx.ConnectError: connection refused"),
|
||||
SendResult(success=True, message_id="ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 2 # initial + 1 retry
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_timeout_and_succeeds(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="ReadTimeout: request timed out"),
|
||||
SendResult(success=False, error="ReadTimeout: request timed out"),
|
||||
SendResult(success=True, message_id="ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=3, base_delay=0)
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_flag_respected(self):
|
||||
"""SendResult.retryable=True should trigger retry even if error string doesn't match."""
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="internal platform error", retryable=True),
|
||||
SendResult(success=True, message_id="ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_to_nonnetwork_transition_falls_back_to_plaintext(self):
|
||||
"""If error switches from network to formatting mid-retry, fall through to plain-text fallback."""
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="httpx.ConnectError: host unreachable"),
|
||||
SendResult(success=False, error="Bad Request: can't parse entities"),
|
||||
SendResult(success=True, message_id="fallback_ok"), # plain-text fallback
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "**bold**", max_retries=2, base_delay=0)
|
||||
assert result.success
|
||||
# 3 calls: initial (network) + 1 retry (non-network, breaks loop) + plain-text fallback
|
||||
assert len(adapter._send_calls) == 3
|
||||
assert "plain text" in adapter._send_calls[-1][1].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — all retries exhausted → user notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryExhausted:
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_user_notice_after_exhaustion(self):
|
||||
adapter = _StubAdapter()
|
||||
network_err = SendResult(success=False, error="httpx.ConnectError: host unreachable")
|
||||
# initial + 2 retries + notice attempt
|
||||
adapter._send_results = [network_err, network_err, network_err, SendResult(success=True)]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
# Result is the last failed one (before notice)
|
||||
assert not result.success
|
||||
# 4 total calls: 1 initial + 2 retries + 1 notice
|
||||
assert len(adapter._send_calls) == 4
|
||||
# The notice content should mention delivery failure
|
||||
notice_content = adapter._send_calls[-1][1]
|
||||
assert "delivery failed" in notice_content.lower() or "Message delivery failed" in notice_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notice_send_exception_doesnt_propagate(self):
|
||||
"""If the notice itself throws, _send_with_retry should not raise."""
|
||||
adapter = _StubAdapter()
|
||||
network_err = SendResult(success=False, error="ConnectError")
|
||||
adapter._send_results = [network_err, network_err, network_err]
|
||||
|
||||
original_send = adapter.send
|
||||
call_count = [0]
|
||||
|
||||
async def send_with_notice_failure(chat_id, content, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] > 3:
|
||||
raise RuntimeError("notice send also failed")
|
||||
return network_err
|
||||
|
||||
adapter.send = send_with_notice_failure
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
assert not result.success # still failed, but no exception raised
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — non-network failure → plain-text fallback (no retry)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryFallback:
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_network_error_falls_back_immediately(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="Bad Request: can't parse entities"),
|
||||
SendResult(success=True, message_id="fallback_ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
result = await adapter._send_with_retry("chat1", "**bold**", max_retries=2, base_delay=0)
|
||||
# No sleep — no retry loop for non-network errors
|
||||
mock_sleep.assert_not_called()
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 2
|
||||
# Fallback content should be plain-text notice
|
||||
assert "plain text" in adapter._send_calls[1][1].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_failure_logged_but_not_raised(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="Forbidden: bot blocked"),
|
||||
SendResult(success=False, error="Forbidden: bot blocked"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2)
|
||||
assert not result.success
|
||||
assert len(adapter._send_calls) == 2 # original + fallback only
|
||||
@@ -846,7 +846,7 @@ class TestLastPromptTokens:
|
||||
|
||||
store.update_session("k1", model="openai/gpt-5.4")
|
||||
|
||||
store._db.update_token_counts.assert_called_once_with(
|
||||
store._db.set_token_counts.assert_called_once_with(
|
||||
"s1",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
@@ -858,4 +858,48 @@ class TestLastPromptTokens:
|
||||
billing_provider=None,
|
||||
billing_base_url=None,
|
||||
model="openai/gpt-5.4",
|
||||
absolute=True,
|
||||
)
|
||||
|
||||
|
||||
class TestRewriteTranscriptPreservesReasoning:
|
||||
"""rewrite_transcript must not drop reasoning fields from SQLite."""
|
||||
|
||||
def test_reasoning_survives_rewrite(self, tmp_path):
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "test.db")
|
||||
session_id = "reasoning-test"
|
||||
db.create_session(session_id=session_id, source="cli")
|
||||
|
||||
# Insert a message WITH all three reasoning fields
|
||||
db.append_message(
|
||||
session_id=session_id,
|
||||
role="assistant",
|
||||
content="The answer is 42.",
|
||||
reasoning="I need to think step by step.",
|
||||
reasoning_details=[{"type": "summary", "text": "step by step"}],
|
||||
codex_reasoning_items=[{"id": "r1", "type": "reasoning"}],
|
||||
)
|
||||
|
||||
# Verify all three were stored
|
||||
before = db.get_messages_as_conversation(session_id)
|
||||
assert before[0].get("reasoning") == "I need to think step by step."
|
||||
assert before[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
|
||||
assert before[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
|
||||
|
||||
# Now simulate /retry: build the SessionStore and call rewrite_transcript
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._db = db
|
||||
store._loaded = True
|
||||
|
||||
# rewrite_transcript receives the messages that load_transcript returned
|
||||
store.rewrite_transcript(session_id, before)
|
||||
|
||||
# Load again — all three reasoning fields must survive
|
||||
after = db.get_messages_as_conversation(session_id)
|
||||
assert after[0].get("reasoning") == "I need to think step by step."
|
||||
assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
|
||||
assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
|
||||
|
||||
110
tests/gateway/test_session_info.py
Normal file
110
tests/gateway/test_session_info.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Tests for GatewayRunner._format_session_info — session config surfacing."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from pathlib import Path
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def runner():
|
||||
"""Create a bare GatewayRunner without __init__."""
|
||||
return GatewayRunner.__new__(GatewayRunner)
|
||||
|
||||
|
||||
def _patch_info(tmp_path, config_yaml, model, runtime):
|
||||
"""Return a context-manager stack that patches _format_session_info deps."""
|
||||
cfg_path = tmp_path / "config.yaml"
|
||||
if config_yaml is not None:
|
||||
cfg_path.write_text(config_yaml)
|
||||
return (
|
||||
patch("gateway.run._hermes_home", tmp_path),
|
||||
patch("gateway.run._resolve_gateway_model", return_value=model),
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value=runtime),
|
||||
)
|
||||
|
||||
|
||||
class TestFormatSessionInfo:
|
||||
|
||||
def test_includes_model_name(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: anthropic/claude-opus-4.6\n provider: openrouter\n",
|
||||
"anthropic/claude-opus-4.6",
|
||||
{"provider": "openrouter", "base_url": "https://openrouter.ai/api/v1", "api_key": "k"})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "claude-opus-4.6" in info
|
||||
|
||||
def test_includes_provider(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n provider: openrouter\n",
|
||||
"test-model",
|
||||
{"provider": "openrouter", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "openrouter" in info
|
||||
|
||||
def test_config_context_length(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n context_length: 32768\n",
|
||||
"test-model",
|
||||
{"provider": "custom", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "32K" in info
|
||||
assert "config" in info
|
||||
|
||||
def test_default_fallback_hint(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: unknown-model-xyz\n",
|
||||
"unknown-model-xyz",
|
||||
{"provider": "", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "128K" in info
|
||||
assert "model.context_length" in info
|
||||
|
||||
def test_local_endpoint_shown(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(
|
||||
tmp_path,
|
||||
"model:\n default: qwen3:8b\n provider: custom\n base_url: http://localhost:11434/v1\n context_length: 8192\n",
|
||||
"qwen3:8b",
|
||||
{"provider": "custom", "base_url": "http://localhost:11434/v1", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "localhost:11434" in info
|
||||
assert "8K" in info
|
||||
|
||||
def test_cloud_endpoint_hidden(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n provider: openrouter\n",
|
||||
"test-model",
|
||||
{"provider": "openrouter", "base_url": "https://openrouter.ai/api/v1", "api_key": "k"})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "Endpoint" not in info
|
||||
|
||||
def test_million_context_format(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n context_length: 1000000\n",
|
||||
"test-model",
|
||||
{"provider": "", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "1.0M" in info
|
||||
|
||||
def test_missing_config(self, runner, tmp_path):
|
||||
"""No config.yaml should not crash."""
|
||||
p1, p2, p3 = _patch_info(tmp_path, None, # don't create config
|
||||
"anthropic/claude-sonnet-4.6",
|
||||
{"provider": "openrouter", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "Model" in info
|
||||
assert "Context" in info
|
||||
|
||||
def test_runtime_resolution_failure_doesnt_crash(self, runner, tmp_path):
|
||||
"""If runtime resolution raises, should still produce output."""
|
||||
cfg_path = tmp_path / "config.yaml"
|
||||
cfg_path.write_text("model:\n default: test-model\n context_length: 4096\n")
|
||||
with patch("gateway.run._hermes_home", tmp_path), \
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"), \
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", side_effect=RuntimeError("no creds")):
|
||||
info = runner._format_session_info()
|
||||
assert "4K" in info
|
||||
assert "config" in info
|
||||
@@ -94,7 +94,7 @@ class TestOfferOpenclawMigration:
|
||||
fake_mod.Migrator.assert_called_once()
|
||||
call_kwargs = fake_mod.Migrator.call_args[1]
|
||||
assert call_kwargs["execute"] is True
|
||||
assert call_kwargs["overwrite"] is False
|
||||
assert call_kwargs["overwrite"] is True
|
||||
assert call_kwargs["migrate_secrets"] is True
|
||||
assert call_kwargs["preset_name"] == "full"
|
||||
fake_migrator.migrate.assert_called_once()
|
||||
@@ -285,3 +285,182 @@ class TestSetupWizardOpenclawIntegration:
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
mock_migration.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_section_config_summary / _skip_configured_section — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSectionConfigSummary:
|
||||
"""Test the _get_section_config_summary helper."""
|
||||
|
||||
def test_model_returns_none_without_api_key(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary({}, "model")
|
||||
assert result is None
|
||||
|
||||
def test_model_returns_summary_with_api_key(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENROUTER_API_KEY" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"model": "openai/gpt-4"}, "model"
|
||||
)
|
||||
assert result == "openai/gpt-4"
|
||||
|
||||
def test_model_returns_dict_default_key(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENAI_API_KEY" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"model": {"default": "claude-opus-4", "provider": "anthropic"}},
|
||||
"model",
|
||||
)
|
||||
assert result == "claude-opus-4"
|
||||
|
||||
def test_terminal_always_returns(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"terminal": {"backend": "docker"}}, "terminal"
|
||||
)
|
||||
assert result == "backend: docker"
|
||||
|
||||
def test_agent_always_returns(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"agent": {"max_turns": 120}}, "agent"
|
||||
)
|
||||
assert result == "max turns: 120"
|
||||
|
||||
def test_gateway_returns_none_without_tokens(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
assert result is None
|
||||
|
||||
def test_gateway_lists_platforms(self):
|
||||
def env_side(key):
|
||||
if key == "TELEGRAM_BOT_TOKEN":
|
||||
return "tok123"
|
||||
if key == "DISCORD_BOT_TOKEN":
|
||||
return "disc456"
|
||||
return ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
assert "Telegram" in result
|
||||
assert "Discord" in result
|
||||
|
||||
def test_tools_returns_none_without_keys(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary({}, "tools")
|
||||
assert result is None
|
||||
|
||||
def test_tools_lists_configured(self):
|
||||
def env_side(key):
|
||||
return "key" if key == "BROWSERBASE_API_KEY" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "tools")
|
||||
assert "Browser" in result
|
||||
|
||||
|
||||
class TestSkipConfiguredSection:
|
||||
"""Test the _skip_configured_section helper."""
|
||||
|
||||
def test_returns_false_when_not_configured(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._skip_configured_section({}, "model", "Model")
|
||||
assert result is False
|
||||
|
||||
def test_returns_true_when_user_skips(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENROUTER_API_KEY" else ""
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "get_env_value", side_effect=env_side),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=False),
|
||||
):
|
||||
result = setup_mod._skip_configured_section(
|
||||
{"model": "openai/gpt-4"}, "model", "Model"
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_returns_false_when_user_wants_reconfig(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENROUTER_API_KEY" else ""
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "get_env_value", side_effect=env_side),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=True),
|
||||
):
|
||||
result = setup_mod._skip_configured_section(
|
||||
{"model": "openai/gpt-4"}, "model", "Model"
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestSetupWizardSkipsConfiguredSections:
|
||||
"""After migration, already-configured sections should offer skip."""
|
||||
|
||||
def test_sections_skipped_when_migration_imported_settings(self, tmp_path):
|
||||
"""When migration ran and API key exists, model section should be skippable.
|
||||
|
||||
Simulates the real flow: get_env_value returns "" during the is_existing
|
||||
check (before migration), then returns a key after migration imported it.
|
||||
"""
|
||||
args = _first_time_args()
|
||||
|
||||
# Track whether migration has "run" — after it does, API key is available
|
||||
migration_done = {"value": False}
|
||||
|
||||
def env_side(key):
|
||||
if migration_done["value"] and key == "OPENROUTER_API_KEY":
|
||||
return "sk-xxx"
|
||||
return ""
|
||||
|
||||
def fake_migration(hermes_home):
|
||||
migration_done["value"] = True
|
||||
return True
|
||||
|
||||
reloaded_config = {"model": "openai/gpt-4"}
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "ensure_hermes_home"),
|
||||
patch.object(
|
||||
setup_mod, "load_config",
|
||||
side_effect=[{}, reloaded_config],
|
||||
),
|
||||
patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "get_env_value", side_effect=env_side),
|
||||
patch.object(setup_mod, "is_interactive_stdin", return_value=True),
|
||||
patch("hermes_cli.auth.get_active_provider", return_value=None),
|
||||
patch("builtins.input", return_value=""),
|
||||
# Migration succeeds and flips the env_side flag
|
||||
patch.object(
|
||||
setup_mod, "_offer_openclaw_migration",
|
||||
side_effect=fake_migration,
|
||||
),
|
||||
# User says No to all reconfig prompts
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=False),
|
||||
patch.object(setup_mod, "setup_model_provider") as mock_model,
|
||||
patch.object(setup_mod, "setup_terminal_backend") as mock_terminal,
|
||||
patch.object(setup_mod, "setup_agent_settings") as mock_agent,
|
||||
patch.object(setup_mod, "setup_gateway") as mock_gateway,
|
||||
patch.object(setup_mod, "setup_tools") as mock_tools,
|
||||
patch.object(setup_mod, "save_config"),
|
||||
patch.object(setup_mod, "_print_setup_summary"),
|
||||
):
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
# Model has API key → skip offered, user said No → section NOT called
|
||||
mock_model.assert_not_called()
|
||||
# Terminal/agent always have a summary → skip offered, user said No
|
||||
mock_terminal.assert_not_called()
|
||||
mock_agent.assert_not_called()
|
||||
# Gateway has no tokens (env_side returns "" for gateway keys) → section runs
|
||||
mock_gateway.assert_called_once()
|
||||
# Tools have no keys → section runs
|
||||
mock_tools.assert_called_once()
|
||||
|
||||
@@ -801,6 +801,48 @@ class TestConvertMessages:
|
||||
assert all(not (b.get("type") == "text" and b.get("text") == "") for b in assistant_blocks)
|
||||
assert any(b.get("type") == "tool_use" for b in assistant_blocks)
|
||||
|
||||
def test_empty_user_message_string_gets_placeholder(self):
|
||||
"""Empty user message strings should get '(empty message)' placeholder.
|
||||
|
||||
Anthropic rejects requests with empty user message content.
|
||||
Regression test for #3143 — Discord @mention-only messages.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "user", "content": ""},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert result[0]["content"] == "(empty message)"
|
||||
|
||||
def test_whitespace_only_user_message_gets_placeholder(self):
|
||||
"""Whitespace-only user messages should also get placeholder."""
|
||||
messages = [
|
||||
{"role": "user", "content": " \n\t "},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["content"] == "(empty message)"
|
||||
|
||||
def test_empty_user_message_list_gets_placeholder(self):
|
||||
"""Empty content list for user messages should get placeholder block."""
|
||||
messages = [
|
||||
{"role": "user", "content": []},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert isinstance(result[0]["content"], list)
|
||||
assert len(result[0]["content"]) == 1
|
||||
assert result[0]["content"][0] == {"type": "text", "text": "(empty message)"}
|
||||
|
||||
def test_user_message_with_empty_text_blocks_gets_placeholder(self):
|
||||
"""User message with only empty text blocks should get placeholder."""
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": ""}, {"type": "text", "text": " "}]},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert isinstance(result[0]["content"], list)
|
||||
assert result[0]["content"] == [{"type": "text", "text": "(empty message)"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build kwargs
|
||||
|
||||
@@ -217,10 +217,17 @@ def test_529_overloaded_is_retried_and_recovers(monkeypatch):
|
||||
|
||||
|
||||
def test_429_exhausts_all_retries_before_raising(monkeypatch):
|
||||
"""429 must retry max_retries times, not abort on first attempt."""
|
||||
"""429 must retry max_retries times, then return a failed result.
|
||||
|
||||
The agent no longer re-raises after exhausting retries — it returns a
|
||||
result dict with the error in final_response. This changed when the
|
||||
fallback-provider feature was added (the agent tries a fallback before
|
||||
giving up, and returns a result dict either way).
|
||||
"""
|
||||
agent_cls = _make_agent_cls(_RateLimitError) # always fails
|
||||
with pytest.raises(_RateLimitError):
|
||||
_run_with_agent(monkeypatch, agent_cls)
|
||||
result = _run_with_agent(monkeypatch, agent_cls)
|
||||
resp = str(result.get("final_response", ""))
|
||||
assert "429" in resp or "retries" in resp.lower()
|
||||
|
||||
|
||||
def test_400_bad_request_is_non_retryable(monkeypatch):
|
||||
|
||||
@@ -96,6 +96,59 @@ class TestVerboseAndToolProgress:
|
||||
assert cli.tool_progress_mode in ("off", "new", "all", "verbose")
|
||||
|
||||
|
||||
class TestBusyInputMode:
|
||||
def test_default_busy_input_mode_is_interrupt(self):
|
||||
cli = _make_cli()
|
||||
assert cli.busy_input_mode == "interrupt"
|
||||
|
||||
def test_busy_input_mode_queue_is_honored(self):
|
||||
cli = _make_cli(config_overrides={"display": {"busy_input_mode": "queue"}})
|
||||
assert cli.busy_input_mode == "queue"
|
||||
|
||||
def test_unknown_busy_input_mode_falls_back_to_interrupt(self):
|
||||
cli = _make_cli(config_overrides={"display": {"busy_input_mode": "bogus"}})
|
||||
assert cli.busy_input_mode == "interrupt"
|
||||
|
||||
def test_queue_command_works_while_busy(self):
|
||||
"""When agent is running, /queue should still put the prompt in _pending_input."""
|
||||
cli = _make_cli()
|
||||
cli._agent_running = True
|
||||
cli.process_command("/queue follow up")
|
||||
assert cli._pending_input.get_nowait() == "follow up"
|
||||
|
||||
def test_queue_command_works_while_idle(self):
|
||||
"""When agent is idle, /queue should still queue (not reject)."""
|
||||
cli = _make_cli()
|
||||
cli._agent_running = False
|
||||
cli.process_command("/queue follow up")
|
||||
assert cli._pending_input.get_nowait() == "follow up"
|
||||
|
||||
def test_queue_mode_routes_busy_enter_to_pending(self):
|
||||
"""In queue mode, Enter while busy should go to _pending_input, not _interrupt_queue."""
|
||||
cli = _make_cli(config_overrides={"display": {"busy_input_mode": "queue"}})
|
||||
cli._agent_running = True
|
||||
# Simulate what handle_enter does for non-command input while busy
|
||||
text = "follow up"
|
||||
if cli.busy_input_mode == "queue":
|
||||
cli._pending_input.put(text)
|
||||
else:
|
||||
cli._interrupt_queue.put(text)
|
||||
assert cli._pending_input.get_nowait() == "follow up"
|
||||
assert cli._interrupt_queue.empty()
|
||||
|
||||
def test_interrupt_mode_routes_busy_enter_to_interrupt(self):
|
||||
"""In interrupt mode (default), Enter while busy goes to _interrupt_queue."""
|
||||
cli = _make_cli()
|
||||
cli._agent_running = True
|
||||
text = "redirect"
|
||||
if cli.busy_input_mode == "queue":
|
||||
cli._pending_input.put(text)
|
||||
else:
|
||||
cli._interrupt_queue.put(text)
|
||||
assert cli._interrupt_queue.get_nowait() == "redirect"
|
||||
assert cli._pending_input.empty()
|
||||
|
||||
|
||||
class TestSingleQueryState:
|
||||
def test_voice_and_interrupt_state_initialized_before_run(self):
|
||||
"""Single-query mode calls chat() without going through run()."""
|
||||
|
||||
@@ -182,3 +182,94 @@ class TestCLIUsageReport:
|
||||
assert "Total cost:" in output
|
||||
assert "n/a" in output
|
||||
assert "Pricing unknown for glm-5" in output
|
||||
|
||||
|
||||
class TestStatusBarWidthSource:
|
||||
"""Ensure status bar fragments don't overflow the terminal width."""
|
||||
|
||||
def _make_wide_cli(self):
|
||||
from datetime import datetime, timedelta
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=100_000,
|
||||
completion_tokens=5_000,
|
||||
total_tokens=105_000,
|
||||
api_calls=20,
|
||||
context_tokens=100_000,
|
||||
context_length=200_000,
|
||||
)
|
||||
cli_obj._status_bar_visible = True
|
||||
return cli_obj
|
||||
|
||||
def test_fragments_fit_within_announced_width(self):
|
||||
"""Total fragment text length must not exceed the width used to build them."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
for width in (40, 52, 76, 80, 120, 200):
|
||||
mock_app = MagicMock()
|
||||
mock_app.output.get_size.return_value = MagicMock(columns=width)
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", return_value=mock_app):
|
||||
frags = cli_obj._get_status_bar_fragments()
|
||||
|
||||
total_text = "".join(text for _, text in frags)
|
||||
assert len(total_text) <= width + 4, ( # +4 for minor padding chars
|
||||
f"At width={width}, fragment total {len(total_text)} chars overflows "
|
||||
f"({total_text!r})"
|
||||
)
|
||||
|
||||
def test_fragments_use_pt_width_over_shutil(self):
|
||||
"""When prompt_toolkit reports a width, shutil.get_terminal_size must not be used."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.output.get_size.return_value = MagicMock(columns=120)
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", return_value=mock_app) as mock_get_app, \
|
||||
patch("shutil.get_terminal_size") as mock_shutil:
|
||||
cli_obj._get_status_bar_fragments()
|
||||
|
||||
mock_shutil.assert_not_called()
|
||||
|
||||
def test_fragments_fall_back_to_shutil_when_no_app(self):
|
||||
"""Outside a TUI context (no running app), shutil must be used as fallback."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", side_effect=Exception("no app")), \
|
||||
patch("shutil.get_terminal_size", return_value=MagicMock(columns=100)) as mock_shutil:
|
||||
frags = cli_obj._get_status_bar_fragments()
|
||||
|
||||
mock_shutil.assert_called()
|
||||
assert len(frags) > 0
|
||||
|
||||
def test_build_status_bar_text_uses_pt_width(self):
|
||||
"""_build_status_bar_text() must also prefer prompt_toolkit width."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.output.get_size.return_value = MagicMock(columns=80)
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", return_value=mock_app), \
|
||||
patch("shutil.get_terminal_size") as mock_shutil:
|
||||
text = cli_obj._build_status_bar_text() # no explicit width
|
||||
|
||||
mock_shutil.assert_not_called()
|
||||
assert isinstance(text, str)
|
||||
assert len(text) > 0
|
||||
|
||||
def test_explicit_width_skips_pt_lookup(self):
|
||||
"""An explicit width= argument must bypass both PT and shutil lookups."""
|
||||
from unittest.mock import patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
with patch("prompt_toolkit.application.get_app") as mock_get_app, \
|
||||
patch("shutil.get_terminal_size") as mock_shutil:
|
||||
text = cli_obj._build_status_bar_text(width=100)
|
||||
|
||||
mock_get_app.assert_not_called()
|
||||
mock_shutil.assert_not_called()
|
||||
assert len(text) > 0
|
||||
|
||||
89
tests/test_compressor_fallback_update.py
Normal file
89
tests/test_compressor_fallback_update.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Tests that _try_activate_fallback updates the context compressor."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
|
||||
def _make_agent_with_compressor() -> AIAgent:
|
||||
"""Build a minimal AIAgent with a context_compressor, skipping __init__."""
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
|
||||
# Primary model settings
|
||||
agent.model = "primary-model"
|
||||
agent.provider = "openrouter"
|
||||
agent.base_url = "https://openrouter.ai/api/v1"
|
||||
agent.api_key = "sk-primary"
|
||||
agent.api_mode = "chat_completions"
|
||||
agent.client = MagicMock()
|
||||
agent.quiet_mode = True
|
||||
|
||||
# Fallback config
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
|
||||
# Context compressor with primary model values
|
||||
compressor = ContextCompressor(
|
||||
model="primary-model",
|
||||
threshold_percent=0.50,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key="sk-primary",
|
||||
provider="openrouter",
|
||||
quiet_mode=True,
|
||||
)
|
||||
agent.context_compressor = compressor
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.resolve_provider_client")
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
|
||||
def test_compressor_updated_on_fallback(mock_ctx_len, mock_resolve):
|
||||
"""After fallback activation, the compressor must reflect the fallback model."""
|
||||
agent = _make_agent_with_compressor()
|
||||
|
||||
assert agent.context_compressor.model == "primary-model"
|
||||
|
||||
fb_client = MagicMock()
|
||||
fb_client.base_url = "https://api.openai.com/v1"
|
||||
fb_client.api_key = "sk-fallback"
|
||||
mock_resolve.return_value = (fb_client, None)
|
||||
|
||||
agent._is_direct_openai_url = lambda url: "api.openai.com" in url
|
||||
agent._emit_status = lambda msg: None
|
||||
|
||||
result = agent._try_activate_fallback()
|
||||
|
||||
assert result is True
|
||||
assert agent._fallback_activated is True
|
||||
|
||||
c = agent.context_compressor
|
||||
assert c.model == "gpt-4o"
|
||||
assert c.base_url == "https://api.openai.com/v1"
|
||||
assert c.api_key == "sk-fallback"
|
||||
assert c.provider == "openai"
|
||||
assert c.context_length == 128_000
|
||||
assert c.threshold_tokens == int(128_000 * c.threshold_percent)
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.resolve_provider_client")
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
|
||||
def test_compressor_not_present_does_not_crash(mock_ctx_len, mock_resolve):
|
||||
"""If the agent has no compressor, fallback should still succeed."""
|
||||
agent = _make_agent_with_compressor()
|
||||
agent.context_compressor = None
|
||||
|
||||
fb_client = MagicMock()
|
||||
fb_client.base_url = "https://api.openai.com/v1"
|
||||
fb_client.api_key = "sk-fallback"
|
||||
mock_resolve.return_value = (fb_client, None)
|
||||
|
||||
agent._is_direct_openai_url = lambda url: "api.openai.com" in url
|
||||
agent._emit_status = lambda msg: None
|
||||
|
||||
result = agent._try_activate_fallback()
|
||||
assert result is True
|
||||
111
tests/tools/test_config_null_guard.py
Normal file
111
tests/tools/test_config_null_guard.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Tests for config.get() null-coalescing in tool configuration.
|
||||
|
||||
YAML ``null`` values (or ``~``) for a present key make ``dict.get(key, default)``
|
||||
return ``None`` instead of the default — calling ``.lower()`` on that raises
|
||||
``AttributeError``. These tests verify the ``or`` coalescing guards.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
|
||||
# ── TTS tool ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestTTSProviderNullGuard:
|
||||
"""tools/tts_tool.py — _get_provider()"""
|
||||
|
||||
def test_explicit_null_provider_returns_default(self):
|
||||
"""YAML ``tts: {provider: null}`` should fall back to default."""
|
||||
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
|
||||
|
||||
result = _get_provider({"provider": None})
|
||||
assert result == DEFAULT_PROVIDER.lower().strip()
|
||||
|
||||
def test_missing_provider_returns_default(self):
|
||||
"""No ``provider`` key at all should also return default."""
|
||||
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
|
||||
|
||||
result = _get_provider({})
|
||||
assert result == DEFAULT_PROVIDER.lower().strip()
|
||||
|
||||
def test_valid_provider_passed_through(self):
|
||||
from tools.tts_tool import _get_provider
|
||||
|
||||
result = _get_provider({"provider": "OPENAI"})
|
||||
assert result == "openai"
|
||||
|
||||
|
||||
# ── Web tools ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestWebBackendNullGuard:
|
||||
"""tools/web_tools.py — _get_backend()"""
|
||||
|
||||
@patch("tools.web_tools._load_web_config", return_value={"backend": None})
|
||||
def test_explicit_null_backend_does_not_crash(self, _cfg):
|
||||
"""YAML ``web: {backend: null}`` should not raise AttributeError."""
|
||||
from tools.web_tools import _get_backend
|
||||
|
||||
# Should not raise — the exact return depends on env key fallback
|
||||
result = _get_backend()
|
||||
assert isinstance(result, str)
|
||||
|
||||
@patch("tools.web_tools._load_web_config", return_value={})
|
||||
def test_missing_backend_does_not_crash(self, _cfg):
|
||||
from tools.web_tools import _get_backend
|
||||
|
||||
result = _get_backend()
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
# ── MCP tool ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestMCPAuthNullGuard:
|
||||
"""tools/mcp_tool.py — MCPServerTask.__init__() auth config line"""
|
||||
|
||||
def test_explicit_null_auth_does_not_crash(self):
|
||||
"""YAML ``auth: null`` in MCP server config should not raise."""
|
||||
# Test the expression directly — MCPServerTask.__init__ has many deps
|
||||
config = {"auth": None, "timeout": 30}
|
||||
auth_type = (config.get("auth") or "").lower().strip()
|
||||
assert auth_type == ""
|
||||
|
||||
def test_missing_auth_defaults_to_empty(self):
|
||||
config = {"timeout": 30}
|
||||
auth_type = (config.get("auth") or "").lower().strip()
|
||||
assert auth_type == ""
|
||||
|
||||
def test_valid_auth_passed_through(self):
|
||||
config = {"auth": "OAUTH", "timeout": 30}
|
||||
auth_type = (config.get("auth") or "").lower().strip()
|
||||
assert auth_type == "oauth"
|
||||
|
||||
|
||||
# ── Trajectory compressor ─────────────────────────────────────────────────
|
||||
|
||||
class TestTrajectoryCompressorNullGuard:
|
||||
"""trajectory_compressor.py — _detect_provider() and config loading"""
|
||||
|
||||
def test_null_base_url_does_not_crash(self):
|
||||
"""base_url=None should not crash _detect_provider()."""
|
||||
from trajectory_compressor import CompressionConfig, TrajectoryCompressor
|
||||
|
||||
config = CompressionConfig()
|
||||
config.base_url = None
|
||||
|
||||
compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
|
||||
compressor.config = config
|
||||
|
||||
# Should not raise AttributeError; returns empty string (no match)
|
||||
result = compressor._detect_provider()
|
||||
assert result == ""
|
||||
|
||||
def test_config_loading_null_base_url_keeps_default(self):
|
||||
"""YAML ``summarization: {base_url: null}`` should keep default."""
|
||||
from trajectory_compressor import CompressionConfig
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
config = CompressionConfig()
|
||||
data = {"summarization": {"base_url": None}}
|
||||
|
||||
config.base_url = data["summarization"].get("base_url") or config.base_url
|
||||
assert config.base_url == OPENROUTER_BASE_URL
|
||||
@@ -185,3 +185,71 @@ class TestApplyUpdate:
|
||||
' result = 1\n'
|
||||
' return result + 1'
|
||||
)
|
||||
|
||||
|
||||
class TestAdditionOnlyHunks:
|
||||
"""Regression tests for #3081 — addition-only hunks were silently dropped."""
|
||||
|
||||
def test_addition_only_hunk_with_context_hint(self):
|
||||
"""A hunk with only + lines should insert at the context hint location."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: src/app.py
|
||||
@@ def main @@
|
||||
+def helper():
|
||||
+ return 42
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
assert len(ops[0].hunks) == 1
|
||||
|
||||
hunk = ops[0].hunks[0]
|
||||
# All lines should be additions
|
||||
assert all(l.prefix == '+' for l in hunk.lines)
|
||||
|
||||
# Apply to a file that contains the context hint
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
return SimpleNamespace(
|
||||
content="def main():\n pass\n",
|
||||
error=None,
|
||||
)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
assert "def helper():" in file_ops.written
|
||||
assert "return 42" in file_ops.written
|
||||
|
||||
def test_addition_only_hunk_without_context_hint(self):
|
||||
"""A hunk with only + lines and no context hint appends at end of file."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: src/app.py
|
||||
+def new_func():
|
||||
+ return True
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
return SimpleNamespace(
|
||||
content="existing = True\n",
|
||||
error=None,
|
||||
)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
assert file_ops.written.endswith("def new_func():\n return True\n")
|
||||
assert "existing = True" in file_ops.written
|
||||
|
||||
@@ -797,7 +797,7 @@ class MCPServerTask:
|
||||
"""
|
||||
self._config = config
|
||||
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
||||
self._auth_type = config.get("auth", "").lower().strip()
|
||||
self._auth_type = (config.get("auth") or "").lower().strip()
|
||||
|
||||
# Set up sampling handler if enabled and SDK types are available
|
||||
sampling_config = config.get("sampling", {})
|
||||
|
||||
@@ -419,6 +419,23 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
|
||||
if error:
|
||||
return False, f"Could not apply hunk: {error}"
|
||||
else:
|
||||
# Addition-only hunk (no context or removed lines).
|
||||
# Insert at the location indicated by the context hint, or at end of file.
|
||||
insert_text = '\n'.join(replace_lines)
|
||||
if hunk.context_hint:
|
||||
hint_pos = new_content.find(hunk.context_hint)
|
||||
if hint_pos != -1:
|
||||
# Insert after the line containing the context hint
|
||||
eol = new_content.find('\n', hint_pos)
|
||||
if eol != -1:
|
||||
new_content = new_content[:eol + 1] + insert_text + '\n' + new_content[eol + 1:]
|
||||
else:
|
||||
new_content = new_content + '\n' + insert_text
|
||||
else:
|
||||
new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n'
|
||||
else:
|
||||
new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n'
|
||||
|
||||
# Write new content
|
||||
write_result = file_ops.write_file(op.file_path, new_content)
|
||||
|
||||
@@ -102,7 +102,7 @@ def _load_tts_config() -> Dict[str, Any]:
|
||||
|
||||
def _get_provider(tts_config: Dict[str, Any]) -> str:
|
||||
"""Get the configured TTS provider name."""
|
||||
return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip()
|
||||
return (tts_config.get("provider") or DEFAULT_PROVIDER).lower().strip()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
|
||||
@@ -73,7 +73,7 @@ def _get_backend() -> str:
|
||||
Falls back to whichever API key is present for users who configured
|
||||
keys manually without running setup.
|
||||
"""
|
||||
configured = _load_web_config().get("backend", "").lower().strip()
|
||||
configured = (_load_web_config().get("backend") or "").lower().strip()
|
||||
if configured in ("parallel", "firecrawl", "tavily"):
|
||||
return configured
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ class CompressionConfig:
|
||||
# Summarization
|
||||
if 'summarization' in data:
|
||||
config.summarization_model = data['summarization'].get('model', config.summarization_model)
|
||||
config.base_url = data['summarization'].get('base_url', config.base_url)
|
||||
config.base_url = data['summarization'].get('base_url') or config.base_url
|
||||
config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env)
|
||||
config.temperature = data['summarization'].get('temperature', config.temperature)
|
||||
config.max_retries = data['summarization'].get('max_retries', config.max_retries)
|
||||
@@ -386,7 +386,7 @@ class TrajectoryCompressor:
|
||||
|
||||
def _detect_provider(self) -> str:
|
||||
"""Detect the provider name from the configured base_url."""
|
||||
url = self.config.base_url.lower()
|
||||
url = (self.config.base_url or "").lower()
|
||||
if "openrouter" in url:
|
||||
return "openrouter"
|
||||
if "nousresearch.com" in url:
|
||||
|
||||
Reference in New Issue
Block a user