Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4511322f56 | |||
| 934fc9df22 | |||
| 5847c180c6 | |||
| 93a0c0cddd | |||
| 23e8fdd167 | |||
| 3268b98779 | |||
| 20f381cfb6 | |||
| 77bfa252b9 | |||
| f24c00a5bf | |||
| 463239ed85 | |||
| 60cce9ca6d | |||
| 2d57946ee9 | |||
| 5f32fd8b6d | |||
| 3ea039684e | |||
| 63f0ec96ec | |||
| 1cacaccca6 | |||
| 773f3c1137 | |||
| 0cc784068d | |||
| f1b4d0b280 | |||
| 5254d0bba1 | |||
| 21c20aeaa5 | |||
| dc095f8491 | |||
| 621fd80b1e | |||
| 2b8fd9a8e3 | |||
| fef710aca8 | |||
| 4ae1334287 | |||
| db3e3aa6c5 | |||
| 633488e0c0 | |||
| 0de200cf4d | |||
| f6fdb18fe6 | |||
| b177b4abad | |||
| 232ba441d7 | |||
| 34e120bcbb | |||
| 779f8df6a6 | |||
| 62abb453d3 | |||
| 735a6e7651 | |||
| e5ddca1c8b | |||
| 214827a594 | |||
| fd0e1aac72 | |||
| 678e0bd9cc | |||
| 8ccd14a0d4 | |||
| 6c611c852e | |||
| df9020dfa3 | |||
| e266530c7d | |||
| 879b7d3fbf | |||
| 9f36483bf4 | |||
| 7be314c456 | |||
| 9001b34146 | |||
| 861202b56c | |||
| 9d63dcc3f9 |
@@ -42,19 +42,16 @@ def _setup_logging() -> None:
|
||||
|
||||
def _load_env() -> None:
|
||||
"""Load .env from HERMES_HOME (default ``~/.hermes``)."""
|
||||
from dotenv import load_dotenv
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
env_file = hermes_home / ".env"
|
||||
if env_file.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=env_file, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=env_file, encoding="latin-1")
|
||||
logging.getLogger(__name__).info("Loaded env from %s", env_file)
|
||||
loaded = load_hermes_dotenv(hermes_home=hermes_home)
|
||||
if loaded:
|
||||
for env_file in loaded:
|
||||
logging.getLogger(__name__).info("Loaded env from %s", env_file)
|
||||
else:
|
||||
logging.getLogger(__name__).info(
|
||||
"No .env found at %s, using system env", env_file
|
||||
"No .env found at %s, using system env", hermes_home / ".env"
|
||||
)
|
||||
|
||||
|
||||
|
||||
+68
-11
@@ -497,6 +497,66 @@ def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
|
||||
return result
|
||||
|
||||
|
||||
def _image_source_from_openai_url(url: str) -> Dict[str, str]:
|
||||
"""Convert an OpenAI-style image URL/data URL into Anthropic image source."""
|
||||
url = str(url or "").strip()
|
||||
if not url:
|
||||
return {"type": "url", "url": ""}
|
||||
|
||||
if url.startswith("data:"):
|
||||
header, _, data = url.partition(",")
|
||||
media_type = "image/jpeg"
|
||||
if header.startswith("data:"):
|
||||
mime_part = header[len("data:"):].split(";", 1)[0].strip()
|
||||
if mime_part.startswith("image/"):
|
||||
media_type = mime_part
|
||||
return {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
}
|
||||
|
||||
return {"type": "url", "url": url}
|
||||
|
||||
|
||||
def _convert_content_part_to_anthropic(part: Any) -> Optional[Dict[str, Any]]:
|
||||
"""Convert a single OpenAI-style content part to Anthropic format."""
|
||||
if part is None:
|
||||
return None
|
||||
if isinstance(part, str):
|
||||
return {"type": "text", "text": part}
|
||||
if not isinstance(part, dict):
|
||||
return {"type": "text", "text": str(part)}
|
||||
|
||||
ptype = part.get("type")
|
||||
|
||||
if ptype == "input_text":
|
||||
block: Dict[str, Any] = {"type": "text", "text": part.get("text", "")}
|
||||
elif ptype in {"image_url", "input_image"}:
|
||||
image_value = part.get("image_url", {})
|
||||
url = image_value.get("url", "") if isinstance(image_value, dict) else str(image_value or "")
|
||||
block = {"type": "image", "source": _image_source_from_openai_url(url)}
|
||||
else:
|
||||
block = dict(part)
|
||||
|
||||
if isinstance(part.get("cache_control"), dict) and "cache_control" not in block:
|
||||
block["cache_control"] = dict(part["cache_control"])
|
||||
return block
|
||||
|
||||
|
||||
def _convert_content_to_anthropic(content: Any) -> Any:
|
||||
"""Convert OpenAI-style multimodal content arrays to Anthropic blocks."""
|
||||
if not isinstance(content, list):
|
||||
return content
|
||||
|
||||
converted = []
|
||||
for part in content:
|
||||
block = _convert_content_part_to_anthropic(part)
|
||||
if block is not None:
|
||||
converted.append(block)
|
||||
return converted
|
||||
|
||||
|
||||
def convert_messages_to_anthropic(
|
||||
messages: List[Dict],
|
||||
) -> Tuple[Optional[Any], List[Dict]]:
|
||||
@@ -533,11 +593,9 @@ def convert_messages_to_anthropic(
|
||||
blocks = []
|
||||
if content:
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
blocks.append(dict(part))
|
||||
elif part is not None:
|
||||
blocks.append({"type": "text", "text": str(part)})
|
||||
converted_content = _convert_content_to_anthropic(content)
|
||||
if isinstance(converted_content, list):
|
||||
blocks.extend(converted_content)
|
||||
else:
|
||||
blocks.append({"type": "text", "text": str(content)})
|
||||
for tc in m.get("tool_calls", []):
|
||||
@@ -587,12 +645,11 @@ def convert_messages_to_anthropic(
|
||||
|
||||
# Regular user message
|
||||
if isinstance(content, list):
|
||||
converted_blocks = []
|
||||
for part in content:
|
||||
converted = _convert_user_content_part_to_anthropic(part)
|
||||
if converted is not None:
|
||||
converted_blocks.append(converted)
|
||||
result.append({"role": "user", "content": converted_blocks or [{"type": "text", "text": ""}]})
|
||||
converted_blocks = _convert_content_to_anthropic(content)
|
||||
result.append({
|
||||
"role": "user",
|
||||
"content": converted_blocks or [{"type": "text", "text": ""}],
|
||||
})
|
||||
else:
|
||||
result.append({"role": "user", "content": content})
|
||||
|
||||
|
||||
@@ -83,7 +83,10 @@ _AUTH_JSON_PATH = get_hermes_home() / "auth.json"
|
||||
|
||||
# Codex fallback: uses the Responses API (the only endpoint the Codex
|
||||
# OAuth token can access) with a fast model for auxiliary tasks.
|
||||
_CODEX_AUX_MODEL = "gpt-5.3-codex"
|
||||
# ChatGPT-backed Codex accounts currently reject gpt-5.3-codex for these
|
||||
# auxiliary flows, while gpt-5.2-codex remains broadly available and supports
|
||||
# vision via Responses.
|
||||
_CODEX_AUX_MODEL = "gpt-5.2-codex"
|
||||
_CODEX_AUX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
|
||||
@@ -61,23 +61,14 @@ import queue
|
||||
_COMMAND_SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏")
|
||||
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback
|
||||
from dotenv import load_dotenv
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
_user_env = _hermes_home / ".env"
|
||||
_project_env = Path(__file__).parent / '.env'
|
||||
if _user_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="latin-1")
|
||||
elif _project_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="latin-1")
|
||||
load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
|
||||
# Point mini-swe-agent at ~/.hermes/ so it shares our config
|
||||
os.environ.setdefault("MSWEA_GLOBAL_CONFIG_DIR", str(_hermes_home))
|
||||
|
||||
@@ -288,6 +288,7 @@ class MessageEvent:
|
||||
message_id: Optional[str] = None
|
||||
|
||||
# Media attachments
|
||||
# media_urls: local file paths (for vision tool access)
|
||||
media_urls: List[str] = field(default_factory=list)
|
||||
media_types: List[str] = field(default_factory=list)
|
||||
|
||||
@@ -355,6 +356,10 @@ class BasePlatformAdapter(ABC):
|
||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||
# Background message-processing tasks spawned by handle_message().
|
||||
# Gateway shutdown cancels these so an old gateway instance doesn't keep
|
||||
# working on a task after --replace or manual restarts.
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||
self._auto_tts_disabled_chats: set = set()
|
||||
|
||||
@@ -751,7 +756,25 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
# Store this as a pending message - it will interrupt the running agent
|
||||
# Special case: photo bursts/albums frequently arrive as multiple near-
|
||||
# simultaneous messages. Queue them without interrupting the active run,
|
||||
# then process them immediately after the current task finishes.
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
print(f"[{self.name}] 🖼️ Queuing photo follow-up for session {session_key} without interrupt")
|
||||
existing = self._pending_messages.get(session_key)
|
||||
if existing and existing.message_type == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
else:
|
||||
self._pending_messages[session_key] = event
|
||||
return # Don't interrupt now - will run after current task completes
|
||||
|
||||
# Default behavior for non-photo follow-ups: interrupt the running agent
|
||||
print(f"[{self.name}] ⚡ New message while session {session_key} is active - triggering interrupt")
|
||||
self._pending_messages[session_key] = event
|
||||
# Signal the interrupt (the processing task checks this)
|
||||
@@ -759,7 +782,15 @@ class BasePlatformAdapter(ABC):
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
# Spawn background task to process this message
|
||||
asyncio.create_task(self._process_message_background(event, session_key))
|
||||
task = asyncio.create_task(self._process_message_background(event, session_key))
|
||||
try:
|
||||
self._background_tasks.add(task)
|
||||
except TypeError:
|
||||
# Some tests stub create_task() with lightweight sentinels that are not
|
||||
# hashable and do not support lifecycle callbacks.
|
||||
return
|
||||
if hasattr(task, "add_done_callback"):
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
@staticmethod
|
||||
def _get_human_delay() -> float:
|
||||
@@ -969,6 +1000,21 @@ class BasePlatformAdapter(ABC):
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
async def cancel_background_tasks(self) -> None:
|
||||
"""Cancel any in-flight background message-processing tasks.
|
||||
|
||||
Used during gateway shutdown/replacement so active sessions from the old
|
||||
process do not keep running after adapters are being torn down.
|
||||
"""
|
||||
tasks = [task for task in self._background_tasks if not task.done()]
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
self._pending_messages.clear()
|
||||
self._active_sessions.clear()
|
||||
|
||||
def has_pending_interrupt(self, session_key: str) -> bool:
|
||||
"""Check if there's a pending interrupt for a session."""
|
||||
return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
|
||||
|
||||
@@ -87,8 +87,9 @@ class VoiceReceiver:
|
||||
SAMPLE_RATE = 48000 # Discord native rate
|
||||
CHANNELS = 2 # Discord sends stereo
|
||||
|
||||
def __init__(self, voice_client):
|
||||
def __init__(self, voice_client, allowed_user_ids: set = None):
|
||||
self._vc = voice_client
|
||||
self._allowed_user_ids = allowed_user_ids or set()
|
||||
self._running = False
|
||||
|
||||
# Decryption
|
||||
@@ -274,19 +275,21 @@ class VoiceReceiver:
|
||||
if self._dave_session:
|
||||
with self._lock:
|
||||
user_id = self._ssrc_to_user.get(ssrc, 0)
|
||||
if user_id == 0:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc)
|
||||
return # unknown user, can't DAVE-decrypt
|
||||
try:
|
||||
import davey
|
||||
decrypted = self._dave_session.decrypt(
|
||||
user_id, davey.MediaType.audio, decrypted
|
||||
)
|
||||
except Exception as e:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
|
||||
return
|
||||
if user_id:
|
||||
try:
|
||||
import davey
|
||||
decrypted = self._dave_session.decrypt(
|
||||
user_id, davey.MediaType.audio, decrypted
|
||||
)
|
||||
except Exception as e:
|
||||
# Unencrypted passthrough — use NaCl-decrypted data as-is
|
||||
if "Unencrypted" not in str(e):
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
|
||||
return
|
||||
# If SSRC unknown (no SPEAKING event yet), skip DAVE and try
|
||||
# Opus decode directly — audio may be in passthrough mode.
|
||||
# Buffer will get a user_id when SPEAKING event arrives later.
|
||||
|
||||
# --- Opus decode -> PCM ---
|
||||
try:
|
||||
@@ -304,6 +307,32 @@ class VoiceReceiver:
|
||||
# Silence detection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _infer_user_for_ssrc(self, ssrc: int) -> int:
|
||||
"""Try to infer user_id for an unmapped SSRC.
|
||||
|
||||
When the bot rejoins a voice channel, Discord may not resend
|
||||
SPEAKING events for users already speaking. If exactly one
|
||||
allowed user is in the channel, map the SSRC to them.
|
||||
"""
|
||||
try:
|
||||
channel = self._vc.channel
|
||||
if not channel:
|
||||
return 0
|
||||
bot_id = self._vc.user.id if self._vc.user else 0
|
||||
allowed = self._allowed_user_ids
|
||||
candidates = [
|
||||
m.id for m in channel.members
|
||||
if m.id != bot_id and (not allowed or str(m.id) in allowed)
|
||||
]
|
||||
if len(candidates) == 1:
|
||||
uid = candidates[0]
|
||||
self._ssrc_to_user[ssrc] = uid
|
||||
logger.info("Auto-mapped ssrc=%d -> user=%d (sole allowed member)", ssrc, uid)
|
||||
return uid
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
def check_silence(self) -> list:
|
||||
"""Return list of (user_id, pcm_bytes) for completed utterances."""
|
||||
now = time.monotonic()
|
||||
@@ -322,6 +351,10 @@ class VoiceReceiver:
|
||||
|
||||
if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION:
|
||||
user_id = ssrc_user_map.get(ssrc, 0)
|
||||
if not user_id:
|
||||
# SSRC not mapped (SPEAKING event missing after bot rejoin).
|
||||
# Infer from allowed users in the voice channel.
|
||||
user_id = self._infer_user_for_ssrc(ssrc)
|
||||
if user_id:
|
||||
completed.append((user_id, bytes(buf)))
|
||||
self._buffers[ssrc] = bytearray()
|
||||
@@ -400,6 +433,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._voice_listen_tasks: Dict[int, asyncio.Task] = {} # guild_id -> listen loop
|
||||
self._voice_input_callback: Optional[Callable] = None # set by run.py
|
||||
self._on_voice_disconnect: Optional[Callable] = None # set by run.py
|
||||
# Track threads where the bot has participated so follow-up messages
|
||||
# in those threads don't require @mention.
|
||||
self._bot_participated_threads: set = set()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
@@ -580,7 +616,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
"""Send a message to a Discord channel."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Get the channel
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
@@ -695,13 +731,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
) -> SendResult:
|
||||
"""Play auto-TTS audio.
|
||||
|
||||
When the bot is in a voice channel for this chat's guild, skip the
|
||||
file attachment — the gateway runner plays audio in the VC instead.
|
||||
When the bot is in a voice channel for this chat's guild, play
|
||||
directly in the VC instead of sending as a file attachment.
|
||||
"""
|
||||
for gid, text_ch_id in self._voice_text_channels.items():
|
||||
if str(text_ch_id) == str(chat_id) and self.is_in_voice_channel(gid):
|
||||
logger.debug("[%s] Skipping play_tts for %s — VC playback handled by runner", self.name, chat_id)
|
||||
return SendResult(success=True)
|
||||
logger.info("[%s] Playing TTS in voice channel (guild=%d)", self.name, gid)
|
||||
success = await self.play_in_voice_channel(gid, audio_path)
|
||||
return SendResult(success=success)
|
||||
return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs)
|
||||
|
||||
async def send_voice(
|
||||
@@ -805,7 +842,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Start voice receiver (Phase 2: listen to users)
|
||||
try:
|
||||
receiver = VoiceReceiver(vc)
|
||||
receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids)
|
||||
receiver.start()
|
||||
self._voice_receivers[guild_id] = receiver
|
||||
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
|
||||
@@ -1001,14 +1038,32 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Voice listening (Phase 2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# UDP keepalive interval in seconds — prevents Discord from dropping
|
||||
# the UDP route after ~60s of silence.
|
||||
_KEEPALIVE_INTERVAL = 15
|
||||
|
||||
async def _voice_listen_loop(self, guild_id: int):
|
||||
"""Periodically check for completed utterances and process them."""
|
||||
receiver = self._voice_receivers.get(guild_id)
|
||||
if not receiver:
|
||||
return
|
||||
last_keepalive = time.monotonic()
|
||||
try:
|
||||
while receiver._running:
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Send periodic UDP keepalive to prevent Discord from
|
||||
# dropping the UDP session after ~60s of silence.
|
||||
now = time.monotonic()
|
||||
if now - last_keepalive >= self._KEEPALIVE_INTERVAL:
|
||||
last_keepalive = now
|
||||
try:
|
||||
vc = self._voice_clients.get(guild_id)
|
||||
if vc and vc.is_connected():
|
||||
vc._connection.send_packet(b'\xf8\xff\xfe')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
completed = receiver.check_silence()
|
||||
for user_id, pcm_data in completed:
|
||||
if not self._is_allowed_user(str(user_id)):
|
||||
@@ -1746,14 +1801,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
# UNLESS the channel is in the free-response list.
|
||||
# UNLESS the channel is in the free-response list or the message is
|
||||
# in a thread where the bot has already participated.
|
||||
#
|
||||
# Config:
|
||||
# DISCORD_FREE_RESPONSE_CHANNELS: Comma-separated channel IDs where the
|
||||
# bot responds to every message without needing a mention.
|
||||
# DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement
|
||||
# globally (all channels become free-response). Default: "true".
|
||||
# Can also be set via discord.require_mention in config.yaml.
|
||||
# Config (all settable via discord.* in config.yaml):
|
||||
# discord.require_mention: Require @mention in server channels (default: true)
|
||||
# discord.free_response_channels: Channel IDs where bot responds without mention
|
||||
# discord.auto_thread: Auto-create thread on @mention in channels (default: true)
|
||||
|
||||
thread_id = None
|
||||
parent_channel_id = None
|
||||
@@ -1772,7 +1826,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
is_free_channel = bool(channel_ids & free_channels)
|
||||
|
||||
if require_mention and not is_free_channel:
|
||||
# Skip the mention check if the message is in a thread where
|
||||
# the bot has previously participated (auto-created or replied in).
|
||||
in_bot_thread = is_thread and thread_id in self._bot_participated_threads
|
||||
|
||||
if require_mention and not is_free_channel and not in_bot_thread:
|
||||
if self._client.user not in message.mentions:
|
||||
return
|
||||
|
||||
@@ -1781,17 +1839,18 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||||
|
||||
# Auto-thread: when enabled, automatically create a thread for every
|
||||
# new message in a text channel so each conversation is isolated.
|
||||
# @mention in a text channel so each conversation is isolated (like Slack).
|
||||
# Messages already inside threads or DMs are unaffected.
|
||||
auto_threaded_channel = None
|
||||
if not is_thread and not isinstance(message.channel, discord.DMChannel):
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "").lower() in ("true", "1", "yes")
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in ("true", "1", "yes")
|
||||
if auto_thread:
|
||||
thread = await self._auto_create_thread(message)
|
||||
if thread:
|
||||
is_thread = True
|
||||
thread_id = str(thread.id)
|
||||
auto_threaded_channel = thread
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
@@ -1891,7 +1950,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
|
||||
timestamp=message.created_at,
|
||||
)
|
||||
|
||||
|
||||
# Track thread participation so the bot won't require @mention for
|
||||
# follow-up messages in threads it has already engaged in.
|
||||
if thread_id:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
|
||||
@@ -111,6 +111,11 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
super().__init__(config, Platform.TELEGRAM)
|
||||
self._app: Optional[Application] = None
|
||||
self._bot: Optional[Bot] = None
|
||||
# Buffer rapid/album photo updates so Telegram image bursts are handled
|
||||
# as a single MessageEvent instead of self-interrupting multiple turns.
|
||||
self._media_batch_delay_seconds = float(os.getenv("HERMES_TELEGRAM_MEDIA_BATCH_DELAY_SECONDS", "0.8"))
|
||||
self._pending_photo_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_photo_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._media_group_events: Dict[str, MessageEvent] = {}
|
||||
self._media_group_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
@@ -289,13 +294,19 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True)
|
||||
|
||||
|
||||
for task in self._pending_photo_batch_tasks.values():
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
self._pending_photo_batch_tasks.clear()
|
||||
self._pending_photo_batches.clear()
|
||||
|
||||
self._mark_disconnected()
|
||||
self._app = None
|
||||
self._bot = None
|
||||
self._token_lock_identity = None
|
||||
logger.info("[%s] Disconnected from Telegram", self.name)
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -807,6 +818,49 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
event.text = "\n".join(parts)
|
||||
await self.handle_message(event)
|
||||
|
||||
def _photo_batch_key(self, event: MessageEvent, msg: Message) -> str:
|
||||
"""Return a batching key for Telegram photos/albums."""
|
||||
from gateway.session import build_session_key
|
||||
session_key = build_session_key(event.source)
|
||||
media_group_id = getattr(msg, "media_group_id", None)
|
||||
if media_group_id:
|
||||
return f"{session_key}:album:{media_group_id}"
|
||||
return f"{session_key}:photo-burst"
|
||||
|
||||
async def _flush_photo_batch(self, batch_key: str) -> None:
|
||||
"""Send a buffered photo burst/album as a single MessageEvent."""
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
await asyncio.sleep(self._media_batch_delay_seconds)
|
||||
event = self._pending_photo_batches.pop(batch_key, None)
|
||||
if not event:
|
||||
return
|
||||
logger.info("[Telegram] Flushing photo batch %s with %d image(s)", batch_key, len(event.media_urls))
|
||||
await self.handle_message(event)
|
||||
finally:
|
||||
if self._pending_photo_batch_tasks.get(batch_key) is current_task:
|
||||
self._pending_photo_batch_tasks.pop(batch_key, None)
|
||||
|
||||
def _enqueue_photo_event(self, batch_key: str, event: MessageEvent) -> None:
|
||||
"""Merge photo events into a pending batch and schedule flush."""
|
||||
existing = self._pending_photo_batches.get(batch_key)
|
||||
if existing is None:
|
||||
self._pending_photo_batches[batch_key] = event
|
||||
else:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
|
||||
prior_task = self._pending_photo_batch_tasks.get(batch_key)
|
||||
if prior_task and not prior_task.done():
|
||||
prior_task.cancel()
|
||||
|
||||
self._pending_photo_batch_tasks[batch_key] = asyncio.create_task(self._flush_photo_batch(batch_key))
|
||||
|
||||
async def _handle_media_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming media messages, downloading images to local cache."""
|
||||
if not update.message:
|
||||
@@ -858,14 +912,22 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if file_obj.file_path.lower().endswith(candidate):
|
||||
ext = candidate
|
||||
break
|
||||
# Save to cache and populate media_urls with the local path
|
||||
# Save to local cache (for vision tool access)
|
||||
cached_path = cache_image_from_bytes(bytes(image_bytes), ext=ext)
|
||||
event.media_urls = [cached_path]
|
||||
event.media_types = [f"image/{ext.lstrip('.')}"]
|
||||
event.media_types = [f"image/{ext.lstrip('.')}" ]
|
||||
logger.info("[Telegram] Cached user photo at %s", cached_path)
|
||||
media_group_id = getattr(msg, "media_group_id", None)
|
||||
if media_group_id:
|
||||
await self._queue_media_group_event(str(media_group_id), event)
|
||||
else:
|
||||
batch_key = self._photo_batch_key(event, msg)
|
||||
self._enqueue_photo_event(batch_key, event)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("[Telegram] Failed to cache photo: %s", e, exc_info=True)
|
||||
|
||||
|
||||
# Download voice/audio messages to cache for STT transcription
|
||||
if msg.voice:
|
||||
try:
|
||||
|
||||
+73
-28
@@ -35,16 +35,12 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
# Resolve Hermes home directory (respects HERMES_HOME override)
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
# Load environment variables from ~/.hermes/.env first
|
||||
from dotenv import load_dotenv
|
||||
# Load environment variables from ~/.hermes/.env first.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
from dotenv import load_dotenv # backward-compat for tests that monkeypatch this symbol
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
_env_path = _hermes_home / '.env'
|
||||
if _env_path.exists():
|
||||
try:
|
||||
load_dotenv(_env_path, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(_env_path, encoding="latin-1")
|
||||
# Also try project .env as fallback
|
||||
load_dotenv()
|
||||
load_hermes_dotenv(hermes_home=_hermes_home, project_env=Path(__file__).resolve().parents[1] / '.env')
|
||||
|
||||
# Bridge config.yaml values into the environment so os.getenv() picks them up.
|
||||
# config.yaml is authoritative for terminal settings — overrides .env.
|
||||
@@ -900,8 +896,19 @@ class GatewayRunner:
|
||||
"""Stop the gateway and disconnect all adapters."""
|
||||
logger.info("Stopping gateway...")
|
||||
self._running = False
|
||||
|
||||
|
||||
for session_key, agent in list(self._running_agents.items()):
|
||||
try:
|
||||
agent.interrupt("Gateway shutting down")
|
||||
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
|
||||
except Exception as e:
|
||||
logger.debug("Failed interrupting agent during shutdown: %s", e)
|
||||
|
||||
for platform, adapter in list(self.adapters.items()):
|
||||
try:
|
||||
await adapter.cancel_background_tasks()
|
||||
except Exception as e:
|
||||
logger.debug("✗ %s background-task cancel error: %s", platform.value, e)
|
||||
try:
|
||||
await adapter.disconnect()
|
||||
logger.info("✓ %s disconnected", platform.value)
|
||||
@@ -909,6 +916,9 @@ class GatewayRunner:
|
||||
logger.error("✗ %s disconnect error: %s", platform.value, e)
|
||||
|
||||
self.adapters.clear()
|
||||
self._running_agents.clear()
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
self._shutdown_all_gateway_honcho()
|
||||
self._shutdown_event.set()
|
||||
|
||||
@@ -1095,11 +1105,36 @@ class GatewayRunner:
|
||||
)
|
||||
return None
|
||||
|
||||
# PRIORITY: If an agent is already running for this session, interrupt it
|
||||
# immediately. This is before command parsing to minimize latency -- the
|
||||
# user's "stop" message reaches the agent as fast as possible.
|
||||
# PRIORITY handling when an agent is already running for this session.
|
||||
# Default behavior is to interrupt immediately so user text/stop messages
|
||||
# are handled with minimal latency.
|
||||
#
|
||||
# Special case: Telegram/photo bursts often arrive as multiple near-
|
||||
# simultaneous updates. Do NOT interrupt for photo-only follow-ups here;
|
||||
# let the adapter-level batching/queueing logic absorb them.
|
||||
_quick_key = build_session_key(source)
|
||||
if _quick_key in self._running_agents:
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
# Reuse adapter queue semantics so photo bursts merge cleanly.
|
||||
if _quick_key in adapter._pending_messages:
|
||||
existing = adapter._pending_messages[_quick_key]
|
||||
if getattr(existing, "message_type", None) == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
return None
|
||||
|
||||
running_agent = self._running_agents[_quick_key]
|
||||
logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
|
||||
running_agent.interrupt(event.text)
|
||||
@@ -2396,6 +2431,13 @@ class GatewayRunner:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to join voice channel: %s", e)
|
||||
adapter._voice_input_callback = None
|
||||
err_lower = str(e).lower()
|
||||
if "pynacl" in err_lower or "nacl" in err_lower or "davey" in err_lower:
|
||||
return (
|
||||
"Voice dependencies are missing (PyNaCl / davey). "
|
||||
"Install or reinstall Hermes with the messaging extra, e.g. "
|
||||
"`pip install hermes-agent[messaging]`."
|
||||
)
|
||||
return f"Failed to join voice channel: {e}"
|
||||
|
||||
if success:
|
||||
@@ -2536,18 +2578,9 @@ class GatewayRunner:
|
||||
if has_agent_tts:
|
||||
return False
|
||||
|
||||
# Dedup: base adapter auto-TTS already handles voice input.
|
||||
# Exception: Discord voice channel — play_tts override is a no-op,
|
||||
# so the runner must handle VC playback.
|
||||
skip_double = is_voice_input
|
||||
if skip_double:
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if (guild_id and adapter
|
||||
and hasattr(adapter, "is_in_voice_channel")
|
||||
and adapter.is_in_voice_channel(guild_id)):
|
||||
skip_double = False
|
||||
if skip_double:
|
||||
# Dedup: base adapter auto-TTS already handles voice input
|
||||
# (play_tts plays in VC when connected, so runner can skip).
|
||||
if is_voice_input:
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -3469,10 +3502,12 @@ class GatewayRunner:
|
||||
os.environ["HERMES_SESSION_CHAT_ID"] = context.source.chat_id
|
||||
if context.source.chat_name:
|
||||
os.environ["HERMES_SESSION_CHAT_NAME"] = context.source.chat_name
|
||||
if context.source.thread_id:
|
||||
os.environ["HERMES_SESSION_THREAD_ID"] = str(context.source.thread_id)
|
||||
|
||||
def _clear_session_env(self) -> None:
|
||||
"""Clear session environment variables."""
|
||||
for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME"]:
|
||||
for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME", "HERMES_SESSION_THREAD_ID"]:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
@@ -3490,9 +3525,13 @@ class GatewayRunner:
|
||||
1. Immediately understand what the user sent (no extra tool call).
|
||||
2. Re-examine the image with vision_analyze if it needs more detail.
|
||||
|
||||
Athabasca persistence should happen through Athabasca's own POST
|
||||
/api/uploads flow, using the returned asset.publicUrl rather than local
|
||||
cache paths.
|
||||
|
||||
Args:
|
||||
user_text: The user's original caption / message text.
|
||||
image_paths: List of local file paths to cached images.
|
||||
user_text: The user's original caption / message text.
|
||||
image_paths: List of local file paths to cached images.
|
||||
|
||||
Returns:
|
||||
The enriched message string with vision descriptions prepended.
|
||||
@@ -3517,10 +3556,16 @@ class GatewayRunner:
|
||||
result = _json.loads(result_json)
|
||||
if result.get("success"):
|
||||
description = result.get("analysis", "")
|
||||
athabasca_note = (
|
||||
"\n[If this image needs to persist in Athabasca state, upload the cached file "
|
||||
"through Athabasca POST /api/uploads and use the returned asset.publicUrl. "
|
||||
"Do not store the local cache path as the canonical imageUrl.]"
|
||||
)
|
||||
enriched_parts.append(
|
||||
f"[The user sent an image~ Here's what I can see:\n{description}]\n"
|
||||
f"[If you need a closer look, use vision_analyze with "
|
||||
f"image_url: {path} ~]"
|
||||
f"{athabasca_note}"
|
||||
)
|
||||
else:
|
||||
enriched_parts.append(
|
||||
|
||||
+17
-10
@@ -321,25 +321,32 @@ def build_session_key(source: SessionSource) -> str:
|
||||
This is the single source of truth for session key construction.
|
||||
|
||||
DM rules:
|
||||
- WhatsApp DMs include chat_id (multi-user support).
|
||||
- Other DMs include thread_id when present (e.g. Slack threaded DMs),
|
||||
so each DM thread gets its own session while top-level DMs share one.
|
||||
- Without thread_id or chat_id, all DMs share a single session.
|
||||
- DMs include chat_id when present, so each private conversation is isolated.
|
||||
- thread_id further differentiates threaded DMs within the same DM chat.
|
||||
- Without chat_id, thread_id is used as a best-effort fallback.
|
||||
- Without thread_id or chat_id, DMs share a single session.
|
||||
|
||||
Group/channel rules:
|
||||
- thread_id differentiates threads within a channel.
|
||||
- Without thread_id, all messages in a channel share one session.
|
||||
- chat_id identifies the parent group/channel.
|
||||
- thread_id differentiates threads within that parent chat.
|
||||
- Without identifiers, messages fall back to one session per platform/chat_type.
|
||||
"""
|
||||
platform = source.platform.value
|
||||
if source.chat_type == "dm":
|
||||
if source.chat_id:
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:dm:{source.chat_id}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:dm:{source.chat_id}"
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:dm:{source.thread_id}"
|
||||
if platform == "whatsapp" and source.chat_id:
|
||||
return f"agent:main:{platform}:dm:{source.chat_id}"
|
||||
return f"agent:main:{platform}:dm"
|
||||
if source.chat_id:
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}"
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}"
|
||||
|
||||
|
||||
class SessionStore:
|
||||
|
||||
@@ -280,6 +280,7 @@ DEFAULT_CONFIG = {
|
||||
"discord": {
|
||||
"require_mention": True, # Require @mention to respond in server channels
|
||||
"free_response_channels": "", # Comma-separated channel IDs where bot responds without mention
|
||||
"auto_thread": True, # Auto-create threads on @mention in channels (like Slack)
|
||||
},
|
||||
|
||||
# Permanently allowed dangerous command patterns (added via "always" approval)
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Helpers for loading Hermes .env files consistently across entrypoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
def _load_dotenv_with_fallback(path: Path, *, override: bool) -> None:
|
||||
try:
|
||||
load_dotenv(dotenv_path=path, override=override, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=path, override=override, encoding="latin-1")
|
||||
|
||||
|
||||
def load_hermes_dotenv(
|
||||
*,
|
||||
hermes_home: str | os.PathLike | None = None,
|
||||
project_env: str | os.PathLike | None = None,
|
||||
) -> list[Path]:
|
||||
"""Load Hermes environment files with user config taking precedence.
|
||||
|
||||
Behavior:
|
||||
- `~/.hermes/.env` overrides stale shell-exported values when present.
|
||||
- project `.env` acts as a dev fallback and only fills missing values when
|
||||
the user env exists.
|
||||
- if no user env exists, the project `.env` also overrides stale shell vars.
|
||||
"""
|
||||
loaded: list[Path] = []
|
||||
|
||||
home_path = Path(hermes_home or os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
user_env = home_path / ".env"
|
||||
project_env_path = Path(project_env) if project_env else None
|
||||
|
||||
if user_env.exists():
|
||||
_load_dotenv_with_fallback(user_env, override=True)
|
||||
loaded.append(user_env)
|
||||
|
||||
if project_env_path and project_env_path.exists():
|
||||
_load_dotenv_with_fallback(project_env_path, override=not loaded)
|
||||
loaded.append(project_env_path)
|
||||
|
||||
return loaded
|
||||
+23
-16
@@ -54,16 +54,11 @@ from typing import Optional
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback
|
||||
from dotenv import load_dotenv
|
||||
from hermes_cli.config import get_env_path, get_hermes_home
|
||||
_user_env = get_env_path()
|
||||
if _user_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="latin-1")
|
||||
load_dotenv(dotenv_path=PROJECT_ROOT / '.env', override=False)
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
from hermes_cli.config import get_hermes_home
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
load_hermes_dotenv(project_env=PROJECT_ROOT / '.env')
|
||||
|
||||
# Point mini-swe-agent at ~/.hermes/ so it shares our config
|
||||
os.environ.setdefault("MSWEA_GLOBAL_CONFIG_DIR", str(get_hermes_home()))
|
||||
@@ -3103,7 +3098,11 @@ For more help on a command:
|
||||
|
||||
elif action == "export":
|
||||
if args.session_id:
|
||||
data = db.export_session(args.session_id)
|
||||
resolved_session_id = db.resolve_session_id(args.session_id)
|
||||
if not resolved_session_id:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
return
|
||||
data = db.export_session(resolved_session_id)
|
||||
if not data:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
return
|
||||
@@ -3118,13 +3117,17 @@ For more help on a command:
|
||||
print(f"Exported {len(sessions)} sessions to {args.output}")
|
||||
|
||||
elif action == "delete":
|
||||
resolved_session_id = db.resolve_session_id(args.session_id)
|
||||
if not resolved_session_id:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
return
|
||||
if not args.yes:
|
||||
confirm = input(f"Delete session '{args.session_id}' and all its messages? [y/N] ")
|
||||
confirm = input(f"Delete session '{resolved_session_id}' and all its messages? [y/N] ")
|
||||
if confirm.lower() not in ("y", "yes"):
|
||||
print("Cancelled.")
|
||||
return
|
||||
if db.delete_session(args.session_id):
|
||||
print(f"Deleted session '{args.session_id}'.")
|
||||
if db.delete_session(resolved_session_id):
|
||||
print(f"Deleted session '{resolved_session_id}'.")
|
||||
else:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
|
||||
@@ -3140,10 +3143,14 @@ For more help on a command:
|
||||
print(f"Pruned {count} session(s).")
|
||||
|
||||
elif action == "rename":
|
||||
resolved_session_id = db.resolve_session_id(args.session_id)
|
||||
if not resolved_session_id:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
return
|
||||
title = " ".join(args.title)
|
||||
try:
|
||||
if db.set_session_title(args.session_id, title):
|
||||
print(f"Session '{args.session_id}' renamed to: {title}")
|
||||
if db.set_session_title(resolved_session_id, title):
|
||||
print(f"Session '{resolved_session_id}' renamed to: {title}")
|
||||
else:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
except ValueError as e:
|
||||
|
||||
@@ -354,9 +354,29 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
|
||||
|
||||
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[str]):
|
||||
"""Save the selected toolset keys for a platform to config."""
|
||||
"""Save the selected toolset keys for a platform to config.
|
||||
|
||||
Preserves any non-configurable toolset entries (like MCP server names)
|
||||
that were already in the config for this platform.
|
||||
"""
|
||||
config.setdefault("platform_toolsets", {})
|
||||
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys)
|
||||
|
||||
# Get the set of all configurable toolset keys
|
||||
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
|
||||
# Get existing toolsets for this platform
|
||||
existing_toolsets = config.get("platform_toolsets", {}).get(platform, [])
|
||||
if not isinstance(existing_toolsets, list):
|
||||
existing_toolsets = []
|
||||
|
||||
# Preserve any entries that are NOT configurable toolsets (i.e. MCP server names)
|
||||
preserved_entries = {
|
||||
entry for entry in existing_toolsets
|
||||
if entry not in configurable_keys
|
||||
}
|
||||
|
||||
# Merge preserved entries with new enabled toolsets
|
||||
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys | preserved_entries)
|
||||
save_config(config)
|
||||
|
||||
|
||||
|
||||
@@ -249,6 +249,32 @@ class SessionDB:
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
|
||||
"""Resolve an exact or uniquely prefixed session ID to the full ID.
|
||||
|
||||
Returns the exact ID when it exists. Otherwise treats the input as a
|
||||
prefix and returns the single matching session ID if the prefix is
|
||||
unambiguous. Returns None for no matches or ambiguous prefixes.
|
||||
"""
|
||||
exact = self.get_session(session_id_or_prefix)
|
||||
if exact:
|
||||
return exact["id"]
|
||||
|
||||
escaped = (
|
||||
session_id_or_prefix
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\' ORDER BY started_at DESC LIMIT 2",
|
||||
(f"{escaped}%",),
|
||||
)
|
||||
matches = [row["id"] for row in cursor.fetchall()]
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
return None
|
||||
|
||||
# Maximum length for session titles
|
||||
MAX_TITLE_LENGTH = 100
|
||||
|
||||
|
||||
@@ -27,25 +27,16 @@ from pathlib import Path
|
||||
import fire
|
||||
import yaml
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
_user_env = _hermes_home / ".env"
|
||||
_project_env = Path(__file__).parent / '.env'
|
||||
|
||||
if _user_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="latin-1")
|
||||
print(f"✅ Loaded environment variables from {_user_env}")
|
||||
elif _project_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="latin-1")
|
||||
print(f"✅ Loaded environment variables from {_project_env}")
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
_loaded_env_paths = load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
for _env_path in _loaded_env_paths:
|
||||
print(f"✅ Loaded environment variables from {_env_path}")
|
||||
|
||||
# Set terminal working directory to tinker-atropos submodule
|
||||
# This ensures terminal commands run in the right context for RL work
|
||||
|
||||
+165
-16
@@ -21,6 +21,8 @@ Usage:
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import hashlib
|
||||
@@ -31,6 +33,7 @@ import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import threading
|
||||
import weakref
|
||||
@@ -42,24 +45,16 @@ import fire
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback
|
||||
from dotenv import load_dotenv
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
_user_env = _hermes_home / ".env"
|
||||
_project_env = Path(__file__).parent / '.env'
|
||||
if _user_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="latin-1")
|
||||
logger.info("Loaded environment variables from %s", _user_env)
|
||||
elif _project_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="latin-1")
|
||||
logger.info("Loaded environment variables from %s", _project_env)
|
||||
_loaded_env_paths = load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
if _loaded_env_paths:
|
||||
for _env_path in _loaded_env_paths:
|
||||
logger.info("Loaded environment variables from %s", _env_path)
|
||||
else:
|
||||
logger.info("No .env file found. Using system environment variables.")
|
||||
|
||||
@@ -504,6 +499,11 @@ class AIAgent:
|
||||
self._persist_user_message_idx = None
|
||||
self._persist_user_message_override = None
|
||||
|
||||
# Cache anthropic image-to-text fallbacks per image payload/URL so a
|
||||
# single tool loop does not repeatedly re-run auxiliary vision on the
|
||||
# same image history.
|
||||
self._anthropic_image_fallback_cache: Dict[str, str] = {}
|
||||
|
||||
# Initialize LLM client via centralized provider router.
|
||||
# The router handles auth resolution, base URL, headers, and
|
||||
# Codex/Anthropic wrapping for all known providers.
|
||||
@@ -3034,13 +3034,156 @@ class AIAgent:
|
||||
|
||||
# ── End provider fallback ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _content_has_image_parts(content: Any) -> bool:
|
||||
if not isinstance(content, list):
|
||||
return False
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") in {"image_url", "input_image"}:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _materialize_data_url_for_vision(image_url: str) -> tuple[str, Optional[Path]]:
|
||||
header, _, data = str(image_url or "").partition(",")
|
||||
mime = "image/jpeg"
|
||||
if header.startswith("data:"):
|
||||
mime_part = header[len("data:"):].split(";", 1)[0].strip()
|
||||
if mime_part.startswith("image/"):
|
||||
mime = mime_part
|
||||
suffix = {
|
||||
"image/png": ".png",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/jpg": ".jpg",
|
||||
}.get(mime, ".jpg")
|
||||
tmp = tempfile.NamedTemporaryFile(prefix="anthropic_image_", suffix=suffix, delete=False)
|
||||
with tmp:
|
||||
tmp.write(base64.b64decode(data))
|
||||
path = Path(tmp.name)
|
||||
return str(path), path
|
||||
|
||||
def _describe_image_for_anthropic_fallback(self, image_url: str, role: str) -> str:
|
||||
cache_key = hashlib.sha256(str(image_url or "").encode("utf-8")).hexdigest()
|
||||
cached = self._anthropic_image_fallback_cache.get(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
role_label = {
|
||||
"assistant": "assistant",
|
||||
"tool": "tool result",
|
||||
}.get(role, "user")
|
||||
analysis_prompt = (
|
||||
"Describe everything visible in this image in thorough detail. "
|
||||
"Include any text, code, UI, data, objects, people, layout, colors, "
|
||||
"and any other notable visual information."
|
||||
)
|
||||
|
||||
vision_source = str(image_url or "")
|
||||
cleanup_path: Optional[Path] = None
|
||||
if vision_source.startswith("data:"):
|
||||
vision_source, cleanup_path = self._materialize_data_url_for_vision(vision_source)
|
||||
|
||||
description = ""
|
||||
try:
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
|
||||
result_json = asyncio.run(
|
||||
vision_analyze_tool(image_url=vision_source, user_prompt=analysis_prompt)
|
||||
)
|
||||
result = json.loads(result_json) if isinstance(result_json, str) else {}
|
||||
description = (result.get("analysis") or "").strip()
|
||||
except Exception as e:
|
||||
description = f"Image analysis failed: {e}"
|
||||
finally:
|
||||
if cleanup_path and cleanup_path.exists():
|
||||
try:
|
||||
cleanup_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if not description:
|
||||
description = "Image analysis failed."
|
||||
|
||||
note = f"[The {role_label} attached an image. Here's what it contains:\n{description}]"
|
||||
if vision_source and not str(image_url or "").startswith("data:"):
|
||||
note += (
|
||||
f"\n[If you need a closer look, use vision_analyze with image_url: {vision_source}]"
|
||||
)
|
||||
|
||||
self._anthropic_image_fallback_cache[cache_key] = note
|
||||
return note
|
||||
|
||||
def _preprocess_anthropic_content(self, content: Any, role: str) -> Any:
|
||||
if not self._content_has_image_parts(content):
|
||||
return content
|
||||
|
||||
text_parts: List[str] = []
|
||||
image_notes: List[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
if part.strip():
|
||||
text_parts.append(part.strip())
|
||||
continue
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
|
||||
ptype = part.get("type")
|
||||
if ptype in {"text", "input_text"}:
|
||||
text = str(part.get("text", "") or "").strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
continue
|
||||
|
||||
if ptype in {"image_url", "input_image"}:
|
||||
image_data = part.get("image_url", {})
|
||||
image_url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data or "")
|
||||
if image_url:
|
||||
image_notes.append(self._describe_image_for_anthropic_fallback(image_url, role))
|
||||
else:
|
||||
image_notes.append("[An image was attached but no image source was available.]")
|
||||
continue
|
||||
|
||||
text = str(part.get("text", "") or "").strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
|
||||
prefix = "\n\n".join(note for note in image_notes if note).strip()
|
||||
suffix = "\n".join(text for text in text_parts if text).strip()
|
||||
if prefix and suffix:
|
||||
return f"{prefix}\n\n{suffix}"
|
||||
if prefix:
|
||||
return prefix
|
||||
if suffix:
|
||||
return suffix
|
||||
return "[A multimodal message was converted to text for Anthropic compatibility.]"
|
||||
|
||||
def _prepare_anthropic_messages_for_api(self, api_messages: list) -> list:
|
||||
if not any(
|
||||
isinstance(msg, dict) and self._content_has_image_parts(msg.get("content"))
|
||||
for msg in api_messages
|
||||
):
|
||||
return api_messages
|
||||
|
||||
transformed = copy.deepcopy(api_messages)
|
||||
for msg in transformed:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
msg["content"] = self._preprocess_anthropic_content(
|
||||
msg.get("content"),
|
||||
str(msg.get("role", "user") or "user"),
|
||||
)
|
||||
return transformed
|
||||
|
||||
def _build_api_kwargs(self, api_messages: list) -> dict:
|
||||
"""Build the keyword arguments dict for the active API mode."""
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
anthropic_messages = self._prepare_anthropic_messages_for_api(api_messages)
|
||||
return build_anthropic_kwargs(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
messages=anthropic_messages,
|
||||
tools=self.tools,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_config=self.reasoning_config,
|
||||
@@ -5439,6 +5582,12 @@ class AIAgent:
|
||||
invalid_json_args = []
|
||||
for tc in assistant_message.tool_calls:
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, (dict, list)):
|
||||
tc.function.arguments = json.dumps(args)
|
||||
continue
|
||||
if args is not None and not isinstance(args, str):
|
||||
tc.function.arguments = str(args)
|
||||
args = tc.function.arguments
|
||||
# Treat empty/whitespace strings as empty object
|
||||
if not args or not args.strip():
|
||||
tc.function.arguments = "{}"
|
||||
|
||||
Executable
+389
@@ -0,0 +1,389 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Discord Voice Doctor — diagnostic tool for voice channel support.
|
||||
|
||||
Checks all dependencies, configuration, and bot permissions needed
|
||||
for Discord voice mode to work correctly.
|
||||
|
||||
Usage:
|
||||
python scripts/discord-voice-doctor.py
|
||||
.venv/bin/python scripts/discord-voice-doctor.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# Resolve project root
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_ROOT = SCRIPT_DIR.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
ENV_FILE = HERMES_HOME / ".env"
|
||||
|
||||
OK = "\033[92m\u2713\033[0m"
|
||||
FAIL = "\033[91m\u2717\033[0m"
|
||||
WARN = "\033[93m!\033[0m"
|
||||
|
||||
# Track whether discord.py is available for later sections
|
||||
_discord_available = False
|
||||
|
||||
|
||||
def mask(value):
|
||||
"""Mask sensitive value: show only first 4 chars."""
|
||||
if not value or len(value) < 8:
|
||||
return "****"
|
||||
return f"{value[:4]}{'*' * (len(value) - 4)}"
|
||||
|
||||
|
||||
def check(label, ok, detail=""):
|
||||
symbol = OK if ok else FAIL
|
||||
msg = f" {symbol} {label}"
|
||||
if detail:
|
||||
msg += f" ({detail})"
|
||||
print(msg)
|
||||
return ok
|
||||
|
||||
|
||||
def warn(label, detail=""):
|
||||
msg = f" {WARN} {label}"
|
||||
if detail:
|
||||
msg += f" ({detail})"
|
||||
print(msg)
|
||||
|
||||
|
||||
def section(title):
|
||||
print(f"\n\033[1m{title}\033[0m")
|
||||
|
||||
|
||||
def check_packages():
|
||||
"""Check Python package dependencies. Returns True if all critical deps OK."""
|
||||
global _discord_available
|
||||
section("Python Packages")
|
||||
ok = True
|
||||
|
||||
# discord.py
|
||||
try:
|
||||
import discord
|
||||
_discord_available = True
|
||||
check("discord.py", True, f"v{discord.__version__}")
|
||||
except ImportError:
|
||||
check("discord.py", False, "pip install discord.py[voice]")
|
||||
ok = False
|
||||
|
||||
# PyNaCl
|
||||
try:
|
||||
import nacl
|
||||
ver = getattr(nacl, "__version__", "unknown")
|
||||
try:
|
||||
import nacl.secret
|
||||
nacl.secret.Aead(bytes(32))
|
||||
check("PyNaCl", True, f"v{ver}")
|
||||
except (AttributeError, Exception):
|
||||
check("PyNaCl (Aead)", False, f"v{ver} — need >=1.5.0")
|
||||
ok = False
|
||||
except ImportError:
|
||||
check("PyNaCl", False, "pip install PyNaCl>=1.5.0")
|
||||
ok = False
|
||||
|
||||
# davey (DAVE E2EE)
|
||||
try:
|
||||
import davey
|
||||
check("davey (DAVE E2EE)", True, f"v{getattr(davey, '__version__', '?')}")
|
||||
except ImportError:
|
||||
check("davey (DAVE E2EE)", False, "pip install davey")
|
||||
ok = False
|
||||
|
||||
# Optional: local STT
|
||||
try:
|
||||
import faster_whisper
|
||||
check("faster-whisper (local STT)", True)
|
||||
except ImportError:
|
||||
warn("faster-whisper (local STT)", "not installed — local STT unavailable")
|
||||
|
||||
# Optional: TTS providers
|
||||
try:
|
||||
import edge_tts
|
||||
check("edge-tts", True)
|
||||
except ImportError:
|
||||
warn("edge-tts", "not installed — edge TTS unavailable")
|
||||
|
||||
try:
|
||||
import elevenlabs
|
||||
check("elevenlabs SDK", True)
|
||||
except ImportError:
|
||||
warn("elevenlabs SDK", "not installed — premium TTS unavailable")
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def check_system_tools():
|
||||
"""Check system-level tools (opus, ffmpeg). Returns True if all OK."""
|
||||
section("System Tools")
|
||||
ok = True
|
||||
|
||||
# Opus codec
|
||||
if _discord_available:
|
||||
try:
|
||||
import discord
|
||||
opus_loaded = discord.opus.is_loaded()
|
||||
if not opus_loaded:
|
||||
import ctypes.util
|
||||
opus_path = ctypes.util.find_library("opus")
|
||||
if not opus_path:
|
||||
# Platform-specific fallback paths
|
||||
candidates = [
|
||||
"/opt/homebrew/lib/libopus.dylib", # macOS Apple Silicon
|
||||
"/usr/local/lib/libopus.dylib", # macOS Intel
|
||||
"/usr/lib/x86_64-linux-gnu/libopus.so.0", # Debian/Ubuntu x86
|
||||
"/usr/lib/aarch64-linux-gnu/libopus.so.0", # Debian/Ubuntu ARM
|
||||
"/usr/lib/libopus.so", # Arch Linux
|
||||
"/usr/lib64/libopus.so", # RHEL/Fedora
|
||||
]
|
||||
for p in candidates:
|
||||
if os.path.isfile(p):
|
||||
opus_path = p
|
||||
break
|
||||
if opus_path:
|
||||
discord.opus.load_opus(opus_path)
|
||||
opus_loaded = discord.opus.is_loaded()
|
||||
if opus_loaded:
|
||||
check("Opus codec", True)
|
||||
else:
|
||||
check("Opus codec", False, "brew install opus / apt install libopus0")
|
||||
ok = False
|
||||
except Exception as e:
|
||||
check("Opus codec", False, str(e))
|
||||
ok = False
|
||||
else:
|
||||
warn("Opus codec", "skipped — discord.py not installed")
|
||||
|
||||
# ffmpeg
|
||||
ffmpeg_path = shutil.which("ffmpeg")
|
||||
if ffmpeg_path:
|
||||
check("ffmpeg", True, ffmpeg_path)
|
||||
else:
|
||||
check("ffmpeg", False, "brew install ffmpeg / apt install ffmpeg")
|
||||
ok = False
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def check_env_vars():
|
||||
"""Check environment variables. Returns (ok, token, groq_key, eleven_key)."""
|
||||
section("Environment Variables")
|
||||
|
||||
# Load .env
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
if ENV_FILE.exists():
|
||||
load_dotenv(ENV_FILE)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
ok = True
|
||||
|
||||
token = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||
if token:
|
||||
check("DISCORD_BOT_TOKEN", True, mask(token))
|
||||
else:
|
||||
check("DISCORD_BOT_TOKEN", False, "not set")
|
||||
ok = False
|
||||
|
||||
# Allowed users — resolve usernames if possible
|
||||
allowed = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed:
|
||||
users = [u.strip() for u in allowed.split(",") if u.strip()]
|
||||
user_labels = []
|
||||
for uid in users:
|
||||
label = mask(uid)
|
||||
if token and uid.isdigit():
|
||||
try:
|
||||
import requests
|
||||
r = requests.get(
|
||||
f"https://discord.com/api/v10/users/{uid}",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
timeout=3,
|
||||
)
|
||||
if r.status_code == 200:
|
||||
label = f"{r.json().get('username', '?')} ({mask(uid)})"
|
||||
except Exception:
|
||||
pass
|
||||
user_labels.append(label)
|
||||
check("DISCORD_ALLOWED_USERS", True, f"{len(users)} user(s): {', '.join(user_labels)}")
|
||||
else:
|
||||
warn("DISCORD_ALLOWED_USERS", "not set — all users can use voice")
|
||||
|
||||
groq_key = os.getenv("GROQ_API_KEY", "")
|
||||
eleven_key = os.getenv("ELEVENLABS_API_KEY", "")
|
||||
|
||||
if groq_key:
|
||||
check("GROQ_API_KEY (STT)", True, mask(groq_key))
|
||||
else:
|
||||
warn("GROQ_API_KEY", "not set — Groq STT unavailable")
|
||||
|
||||
if eleven_key:
|
||||
check("ELEVENLABS_API_KEY (TTS)", True, mask(eleven_key))
|
||||
else:
|
||||
warn("ELEVENLABS_API_KEY", "not set — ElevenLabs TTS unavailable")
|
||||
|
||||
return ok, token, groq_key, eleven_key
|
||||
|
||||
|
||||
def check_config(groq_key, eleven_key):
|
||||
"""Check hermes config.yaml."""
|
||||
section("Configuration")
|
||||
|
||||
config_path = HERMES_HOME / "config.yaml"
|
||||
if config_path.exists():
|
||||
try:
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
stt_provider = cfg.get("stt", {}).get("provider", "local")
|
||||
tts_provider = cfg.get("tts", {}).get("provider", "edge")
|
||||
check("STT provider", True, stt_provider)
|
||||
check("TTS provider", True, tts_provider)
|
||||
|
||||
if stt_provider == "groq" and not groq_key:
|
||||
warn("STT config says groq but GROQ_API_KEY is missing")
|
||||
if tts_provider == "elevenlabs" and not eleven_key:
|
||||
warn("TTS config says elevenlabs but ELEVENLABS_API_KEY is missing")
|
||||
except Exception as e:
|
||||
warn("config.yaml", f"parse error: {e}")
|
||||
else:
|
||||
warn("config.yaml", "not found — using defaults")
|
||||
|
||||
# Voice mode state
|
||||
voice_mode_path = HERMES_HOME / "gateway_voice_mode.json"
|
||||
if voice_mode_path.exists():
|
||||
try:
|
||||
import json
|
||||
modes = json.loads(voice_mode_path.read_text())
|
||||
off_count = sum(1 for v in modes.values() if v == "off")
|
||||
all_count = sum(1 for v in modes.values() if v == "all")
|
||||
check("Voice mode state", True, f"{all_count} on, {off_count} off, {len(modes)} total")
|
||||
except Exception:
|
||||
warn("Voice mode state", "parse error")
|
||||
else:
|
||||
check("Voice mode state", True, "no saved state (fresh)")
|
||||
|
||||
|
||||
def check_bot_permissions(token):
|
||||
"""Check bot permissions via Discord API. Returns True if all OK."""
|
||||
section("Bot Permissions")
|
||||
|
||||
if not token:
|
||||
warn("Bot permissions", "no token — skipping")
|
||||
return True
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
warn("Bot permissions", "requests not installed — skipping")
|
||||
return True
|
||||
|
||||
VOICE_PERMS = {
|
||||
"Priority Speaker": 8,
|
||||
"Stream": 9,
|
||||
"View Channel": 10,
|
||||
"Send Messages": 11,
|
||||
"Embed Links": 14,
|
||||
"Attach Files": 15,
|
||||
"Read Message History": 16,
|
||||
"Connect": 20,
|
||||
"Speak": 21,
|
||||
"Mute Members": 22,
|
||||
"Deafen Members": 23,
|
||||
"Move Members": 24,
|
||||
"Use VAD": 25,
|
||||
"Send Voice Messages": 46,
|
||||
}
|
||||
REQUIRED_PERMS = {"Connect", "Speak", "View Channel", "Send Messages"}
|
||||
ok = True
|
||||
|
||||
try:
|
||||
headers = {"Authorization": f"Bot {token}"}
|
||||
r = requests.get("https://discord.com/api/v10/users/@me", headers=headers, timeout=5)
|
||||
|
||||
if r.status_code == 401:
|
||||
check("Bot login", False, "invalid token (401)")
|
||||
return False
|
||||
if r.status_code != 200:
|
||||
check("Bot login", False, f"HTTP {r.status_code}")
|
||||
return False
|
||||
|
||||
bot = r.json()
|
||||
bot_name = bot.get("username", "?")
|
||||
check("Bot login", True, f"{bot_name[:3]}{'*' * (len(bot_name) - 3)}")
|
||||
|
||||
# Check guilds
|
||||
r2 = requests.get("https://discord.com/api/v10/users/@me/guilds", headers=headers, timeout=5)
|
||||
if r2.status_code != 200:
|
||||
warn("Guilds", f"HTTP {r2.status_code}")
|
||||
return ok
|
||||
|
||||
guilds = r2.json()
|
||||
check("Guilds", True, f"{len(guilds)} guild(s)")
|
||||
|
||||
for g in guilds[:5]:
|
||||
perms = int(g.get("permissions", 0))
|
||||
is_admin = bool(perms & (1 << 3))
|
||||
|
||||
if is_admin:
|
||||
print(f" {OK} {g['name']}: Administrator (all permissions)")
|
||||
continue
|
||||
|
||||
has = []
|
||||
missing = []
|
||||
for name, bit in sorted(VOICE_PERMS.items(), key=lambda x: x[1]):
|
||||
if perms & (1 << bit):
|
||||
has.append(name)
|
||||
elif name in REQUIRED_PERMS:
|
||||
missing.append(name)
|
||||
|
||||
if missing:
|
||||
print(f" {FAIL} {g['name']}: missing {', '.join(missing)}")
|
||||
ok = False
|
||||
else:
|
||||
print(f" {OK} {g['name']}: {', '.join(has)}")
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
warn("Bot permissions", "Discord API timeout")
|
||||
except requests.exceptions.ConnectionError:
|
||||
warn("Bot permissions", "cannot reach Discord API")
|
||||
except Exception as e:
|
||||
warn("Bot permissions", f"check failed: {e}")
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def main():
|
||||
print()
|
||||
print("\033[1m" + "=" * 50 + "\033[0m")
|
||||
print("\033[1m Discord Voice Doctor\033[0m")
|
||||
print("\033[1m" + "=" * 50 + "\033[0m")
|
||||
|
||||
all_ok = True
|
||||
|
||||
all_ok &= check_packages()
|
||||
all_ok &= check_system_tools()
|
||||
env_ok, token, groq_key, eleven_key = check_env_vars()
|
||||
all_ok &= env_ok
|
||||
check_config(groq_key, eleven_key)
|
||||
all_ok &= check_bot_permissions(token)
|
||||
|
||||
# Summary
|
||||
print()
|
||||
print("\033[1m" + "-" * 50 + "\033[0m")
|
||||
if all_ok:
|
||||
print(f" {OK} \033[92mAll checks passed — voice mode ready!\033[0m")
|
||||
else:
|
||||
print(f" {FAIL} \033[91mSome checks failed — fix issues above.\033[0m")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -195,7 +195,7 @@ class TestGetTextAuxiliaryClient:
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client()
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
# Returns a CodexAuxiliaryClient wrapper, not a raw OpenAI client
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
@@ -288,7 +288,7 @@ class TestVisionClientFallback:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_vision_auto_falls_back_to_custom_endpoint(self, monkeypatch):
|
||||
"""Custom endpoint is used as fallback in vision auto mode.
|
||||
@@ -371,7 +371,7 @@ class TestVisionClientFallback:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
class TestGetAuxiliaryProvider:
|
||||
@@ -489,7 +489,7 @@ class TestResolveForcedProvider:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_forced_codex(self, codex_auth_dir, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
@@ -497,7 +497,7 @@ class TestResolveForcedProvider:
|
||||
client, model = _resolve_forced_provider("codex")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_forced_codex_no_token(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
|
||||
|
||||
@@ -252,3 +252,109 @@ async def test_discord_dms_ignore_mention_requirement(adapter, monkeypatch):
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "dm without mention"
|
||||
assert event.source.chat_type == "dm"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_enabled_by_default(adapter, monkeypatch):
|
||||
"""Auto-threading should be enabled by default (DISCORD_AUTO_THREAD defaults to 'true')."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
# Patch _auto_create_thread to return a fake thread
|
||||
fake_thread = FakeThread(channel_id=999, name="auto-thread")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=fake_thread)
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="hello")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_awaited_once()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.chat_type == "thread"
|
||||
assert event.source.thread_id == "999"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_can_be_disabled(adapter, monkeypatch):
|
||||
"""Setting auto_thread to false skips thread creation."""
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="hello")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch):
|
||||
"""Messages in a thread the bot has participated in should not require @mention."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
# Simulate bot having previously participated in thread 456
|
||||
adapter._bot_participated_threads.add("456")
|
||||
|
||||
thread = FakeThread(channel_id=456, name="existing thread")
|
||||
message = make_message(channel=thread, content="follow-up without mention")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "follow-up without mention"
|
||||
assert event.source.chat_type == "thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_unknown_thread_still_requires_mention(adapter, monkeypatch):
|
||||
"""Messages in a thread the bot hasn't participated in should still require @mention."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
# Bot has NOT participated in thread 789
|
||||
thread = FakeThread(channel_id=789, name="some thread")
|
||||
message = make_message(channel=thread, content="hello from unknown thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
|
||||
"""Auto-created threads should be tracked for future mention-free replies."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
fake_thread = FakeThread(channel_id=555, name="auto-thread")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=fake_thread)
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="start a thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "555" in adapter._bot_participated_threads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeypatch):
|
||||
"""When the bot processes a message in a thread, it tracks participation."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
thread = FakeThread(channel_id=777, name="manually created thread")
|
||||
message = make_message(channel=thread, content="hello in thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "777" in adapter._bot_participated_threads
|
||||
|
||||
@@ -363,11 +363,37 @@ async def test_auto_thread_creates_thread_and_redirects(adapter, monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_disabled_by_default(adapter, monkeypatch):
|
||||
"""Without DISCORD_AUTO_THREAD, messages stay in the channel."""
|
||||
async def test_auto_thread_enabled_by_default_slash_commands(adapter, monkeypatch):
|
||||
"""Without DISCORD_AUTO_THREAD env var, auto-threading is enabled (default: true)."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
fake_thread = _FakeThreadChannel(channel_id=999, name="auto-thread")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=fake_thread)
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def capture_handle(event):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = capture_handle
|
||||
|
||||
msg = _fake_message(_FakeTextChannel())
|
||||
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
adapter._auto_create_thread.assert_awaited_once()
|
||||
assert len(captured_events) == 1
|
||||
assert captured_events[0].source.chat_id == "999" # redirected to thread
|
||||
assert captured_events[0].source.chat_type == "thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_can_be_disabled(adapter, monkeypatch):
|
||||
"""Setting DISCORD_AUTO_THREAD=false keeps messages in the channel."""
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
captured_events = []
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _source(chat_id="123456", chat_type="dm"):
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||
adapter = StubAdapter()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
await release.wait()
|
||||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
session_key = build_session_key(event.source)
|
||||
assert session_key in adapter._active_sessions
|
||||
assert adapter._background_tasks
|
||||
|
||||
await adapter.cancel_background_tasks()
|
||||
|
||||
assert adapter._background_tasks == set()
|
||||
assert adapter._active_sessions == {}
|
||||
assert adapter._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._pending_messages = {"session": "pending text"}
|
||||
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
|
||||
adapter = StubAdapter()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
await release.wait()
|
||||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
disconnect_mock = AsyncMock()
|
||||
adapter.disconnect = disconnect_mock
|
||||
|
||||
session_key = build_session_key(event.source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents = {session_key: running_agent}
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
|
||||
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
|
||||
disconnect_mock.assert_awaited_once()
|
||||
assert runner.adapters == {}
|
||||
assert runner._running_agents == {}
|
||||
assert runner._pending_messages == {}
|
||||
assert runner._pending_approvals == {}
|
||||
assert runner._shutdown_event.is_set() is True
|
||||
@@ -0,0 +1,25 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_enrichment_uses_athabasca_upload_guidance_without_stale_r2_warning():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
||||
with patch(
|
||||
"tools.vision_tools.vision_analyze_tool",
|
||||
return_value='{"success": true, "analysis": "A painted serpent warrior."}',
|
||||
):
|
||||
enriched = await runner._enrich_message_with_vision(
|
||||
"caption",
|
||||
["/tmp/test.jpg"],
|
||||
)
|
||||
|
||||
assert "R2 not configured" not in enriched
|
||||
assert "Gateway media URL available for reference" not in enriched
|
||||
assert "POST /api/uploads" in enriched
|
||||
assert "Do not store the local cache path" in enriched
|
||||
assert "caption" in enriched
|
||||
@@ -11,7 +11,7 @@ import asyncio
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
@@ -50,11 +50,11 @@ class TestInterruptKeyConsistency:
|
||||
"""Ensure adapter interrupt methods are queried with session_key, not chat_id."""
|
||||
|
||||
def test_session_key_differs_from_chat_id_for_dm(self):
|
||||
"""Session key for a DM is NOT the same as chat_id."""
|
||||
"""Session key for a DM is namespaced and includes the DM chat_id."""
|
||||
source = _source("123456", "dm")
|
||||
session_key = build_session_key(source)
|
||||
assert session_key != source.chat_id
|
||||
assert session_key == "agent:main:telegram:dm"
|
||||
assert session_key == "agent:main:telegram:dm:123456"
|
||||
|
||||
def test_session_key_differs_from_chat_id_for_group(self):
|
||||
"""Session key for a group chat includes prefix, unlike raw chat_id."""
|
||||
@@ -122,3 +122,29 @@ class TestInterruptKeyConsistency:
|
||||
|
||||
# Interrupt event was set
|
||||
assert adapter._active_sessions[session_key].is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_photo_followup_is_queued_without_interrupt(self):
|
||||
"""Photo follow-ups should queue behind the active run instead of interrupting it."""
|
||||
adapter = StubAdapter()
|
||||
adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
|
||||
|
||||
source = _source("-1001234", "group")
|
||||
session_key = build_session_key(source)
|
||||
interrupt_event = asyncio.Event()
|
||||
adapter._active_sessions[session_key] = interrupt_event
|
||||
|
||||
event = MessageEvent(
|
||||
text="caption",
|
||||
source=source,
|
||||
message_type=MessageType.PHOTO,
|
||||
message_id="2",
|
||||
media_urls=["/tmp/photo-a.jpg"],
|
||||
media_types=["image/jpeg"],
|
||||
)
|
||||
await adapter.handle_message(event)
|
||||
|
||||
queued = adapter._pending_messages[session_key]
|
||||
assert queued is event
|
||||
assert queued.media_urls == ["/tmp/photo-a.jpg"]
|
||||
assert interrupt_event.is_set() is False
|
||||
|
||||
@@ -338,7 +338,7 @@ class TestSessionStoreRewriteTranscript:
|
||||
|
||||
class TestWhatsAppDMSessionKeyConsistency:
|
||||
"""Regression: all session-key construction must go through build_session_key
|
||||
so WhatsApp DMs include chat_id while other DMs do not."""
|
||||
so DMs are isolated by chat_id across platforms."""
|
||||
|
||||
@pytest.fixture()
|
||||
def store(self, tmp_path):
|
||||
@@ -369,15 +369,24 @@ class TestWhatsAppDMSessionKeyConsistency:
|
||||
)
|
||||
assert store._generate_session_key(source) == build_session_key(source)
|
||||
|
||||
def test_telegram_dm_omits_chat_id(self):
|
||||
"""Non-WhatsApp DMs should still omit chat_id (single owner DM)."""
|
||||
def test_telegram_dm_includes_chat_id(self):
|
||||
"""Non-WhatsApp DMs should also include chat_id to separate users."""
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="99",
|
||||
chat_type="dm",
|
||||
)
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:telegram:dm"
|
||||
assert key == "agent:main:telegram:dm:99"
|
||||
|
||||
def test_distinct_dm_chat_ids_get_distinct_session_keys(self):
|
||||
"""Different DM chats must not collapse into one shared session."""
|
||||
first = SessionSource(platform=Platform.TELEGRAM, chat_id="99", chat_type="dm")
|
||||
second = SessionSource(platform=Platform.TELEGRAM, chat_id="100", chat_type="dm")
|
||||
|
||||
assert build_session_key(first) == "agent:main:telegram:dm:99"
|
||||
assert build_session_key(second) == "agent:main:telegram:dm:100"
|
||||
assert build_session_key(first) != build_session_key(second)
|
||||
|
||||
def test_discord_group_includes_chat_id(self):
|
||||
"""Group/channel keys include chat_type and chat_id."""
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionContext, SessionSource
|
||||
|
||||
|
||||
def test_set_session_env_includes_thread_id(monkeypatch):
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
runner._set_session_env(context)
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
|
||||
|
||||
def test_clear_session_env_removes_thread_id(monkeypatch):
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
||||
monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "-1001")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_NAME", "Group")
|
||||
monkeypatch.setenv("HERMES_SESSION_THREAD_ID", "17585")
|
||||
|
||||
runner._clear_session_env()
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") is None
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") is None
|
||||
@@ -12,6 +12,7 @@ import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -351,6 +352,26 @@ class TestDocumentDownloadBlock:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMediaGroups:
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_album_photo_burst_is_buffered_and_combined(self, adapter):
|
||||
first_photo = _make_photo(_make_file_obj(b"first"))
|
||||
second_photo = _make_photo(_make_file_obj(b"second"))
|
||||
|
||||
msg1 = _make_message(caption="two images", photo=[first_photo])
|
||||
msg2 = _make_message(photo=[second_photo])
|
||||
|
||||
with patch("gateway.platforms.telegram.cache_image_from_bytes", side_effect=["/tmp/burst-one.jpg", "/tmp/burst-two.jpg"]):
|
||||
await adapter._handle_media_message(_make_update(msg1), MagicMock())
|
||||
await adapter._handle_media_message(_make_update(msg2), MagicMock())
|
||||
assert adapter.handle_message.await_count == 0
|
||||
await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "two images"
|
||||
assert event.media_urls == ["/tmp/burst-one.jpg", "/tmp/burst-two.jpg"]
|
||||
assert len(event.media_types) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_photo_album_is_buffered_and_combined(self, adapter):
|
||||
first_photo = _make_photo(_make_file_obj(b"first"))
|
||||
@@ -537,6 +558,51 @@ class TestSendDocument:
|
||||
assert call_kwargs["reply_to_message_id"] == 50
|
||||
|
||||
|
||||
class TestTelegramPhotoBatching:
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_photo_batch_does_not_drop_newer_scheduled_task(self, adapter):
|
||||
old_task = MagicMock()
|
||||
new_task = MagicMock()
|
||||
batch_key = "session:photo-burst"
|
||||
adapter._pending_photo_batch_tasks[batch_key] = new_task
|
||||
adapter._pending_photo_batches[batch_key] = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.PHOTO,
|
||||
source=SimpleNamespace(channel_id="chat-1"),
|
||||
media_urls=["/tmp/a.jpg"],
|
||||
media_types=["image/jpeg"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("gateway.platforms.telegram.asyncio.current_task", return_value=old_task),
|
||||
patch("gateway.platforms.telegram.asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
await adapter._flush_photo_batch(batch_key)
|
||||
|
||||
assert adapter._pending_photo_batch_tasks[batch_key] is new_task
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cancels_pending_photo_batch_tasks(self, adapter):
|
||||
task = MagicMock()
|
||||
task.done.return_value = False
|
||||
adapter._pending_photo_batch_tasks["session:photo-burst"] = task
|
||||
adapter._pending_photo_batches["session:photo-burst"] = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.PHOTO,
|
||||
source=SimpleNamespace(channel_id="chat-1"),
|
||||
)
|
||||
adapter._app = MagicMock()
|
||||
adapter._app.updater.stop = AsyncMock()
|
||||
adapter._app.stop = AsyncMock()
|
||||
adapter._app.shutdown = AsyncMock()
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
task.cancel.assert_called_once()
|
||||
assert adapter._pending_photo_batch_tasks == {}
|
||||
assert adapter._pending_photo_batches == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendVideo — outbound video delivery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
class _PendingAdapter:
|
||||
def __init__(self):
|
||||
self._pending_messages = {}
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
||||
runner.adapters = {Platform.TELEGRAM: _PendingAdapter()}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_does_not_priority_interrupt_photo_followup():
|
||||
runner = _make_runner()
|
||||
source = SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
|
||||
session_key = build_session_key(source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents[session_key] = running_agent
|
||||
|
||||
event = MessageEvent(
|
||||
text="caption",
|
||||
message_type=MessageType.PHOTO,
|
||||
source=source,
|
||||
media_urls=["/tmp/photo-a.jpg"],
|
||||
media_types=["image/jpeg"],
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result is None
|
||||
running_agent.interrupt.assert_not_called()
|
||||
assert runner.adapters[Platform.TELEGRAM]._pending_messages[session_key] is event
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for the /voice command and auto voice reply in the gateway."""
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
@@ -206,9 +207,11 @@ class TestAutoVoiceReply:
|
||||
2. gateway _send_voice_reply: fires based on voice_mode setting
|
||||
|
||||
To prevent double audio, _send_voice_reply is skipped when voice input
|
||||
already triggered base adapter auto-TTS (skip_double = is_voice_input).
|
||||
Exception: Discord voice channel — both auto-TTS and Discord play_tts
|
||||
override skip, so the runner must handle it via play_in_voice_channel.
|
||||
already triggered base adapter auto-TTS.
|
||||
|
||||
For Discord voice channels, the base adapter now routes play_tts directly
|
||||
into VC playback, so the runner should still skip voice-input follow-ups to
|
||||
avoid double playback.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
@@ -292,14 +295,14 @@ class TestAutoVoiceReply:
|
||||
|
||||
# -- Discord VC exception: runner must handle --------------------------
|
||||
|
||||
def test_discord_vc_voice_input_runner_fires(self, runner):
|
||||
"""Discord VC + voice input: base play_tts skips (VC override),
|
||||
so runner must handle via play_in_voice_channel."""
|
||||
assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is True
|
||||
def test_discord_vc_voice_input_base_handles(self, runner):
|
||||
"""Discord VC + voice input: base adapter play_tts plays in VC,
|
||||
so runner skips to avoid double playback."""
|
||||
assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is False
|
||||
|
||||
def test_discord_vc_voice_only_runner_fires(self, runner):
|
||||
"""Discord VC + voice_only + voice: runner must handle."""
|
||||
assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is True
|
||||
def test_discord_vc_voice_only_base_handles(self, runner):
|
||||
"""Discord VC + voice_only + voice: base adapter handles."""
|
||||
assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is False
|
||||
|
||||
# -- Edge cases --------------------------------------------------------
|
||||
|
||||
@@ -422,17 +425,23 @@ class TestDiscordPlayTtsSkip:
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_skipped_when_in_vc(self):
|
||||
async def test_play_tts_plays_in_vc_when_connected(self):
|
||||
adapter = self._make_discord_adapter()
|
||||
# Simulate bot in voice channel for guild 111, text channel 123
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
mock_vc.is_playing.return_value = False
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
|
||||
# Mock play_in_voice_channel to avoid actual ffmpeg call
|
||||
async def fake_play(gid, path):
|
||||
return True
|
||||
adapter.play_in_voice_channel = fake_play
|
||||
|
||||
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/test.ogg")
|
||||
# play_tts now plays in VC instead of being a no-op
|
||||
assert result.success is True
|
||||
# send_voice should NOT have been called (no client, would fail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_not_skipped_when_not_in_vc(self):
|
||||
@@ -728,6 +737,24 @@ class TestVoiceChannelCommands:
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "failed" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_missing_voice_dependencies(self, runner):
|
||||
"""Missing PyNaCl/davey should return a user-actionable install hint."""
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.name = "General"
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.join_voice_channel = AsyncMock(
|
||||
side_effect=RuntimeError("PyNaCl library needed in order to use voice")
|
||||
)
|
||||
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
|
||||
event = self._make_discord_event()
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
|
||||
assert "voice dependencies are missing" in result.lower()
|
||||
assert "hermes-agent[messaging]" in result
|
||||
|
||||
# -- _handle_voice_channel_leave --
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -2031,3 +2058,534 @@ class TestDisconnectVoiceCleanup:
|
||||
assert len(adapter._voice_receivers) == 0
|
||||
assert len(adapter._voice_listen_tasks) == 0
|
||||
assert len(adapter._voice_timeout_tasks) == 0
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Discord Voice Channel Flow Tests
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("nacl") is None,
|
||||
reason="PyNaCl not installed",
|
||||
)
|
||||
class TestVoiceReception:
|
||||
"""Audio reception: SSRC mapping, DAVE passthrough, buffer lifecycle."""
|
||||
|
||||
@staticmethod
|
||||
def _make_receiver(allowed_ids=None, members=None, dave=False, bot_id=9999):
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = MagicMock() if dave else None
|
||||
vc._connection.ssrc = bot_id
|
||||
vc._connection.add_socket_listener = MagicMock()
|
||||
vc._connection.remove_socket_listener = MagicMock()
|
||||
vc._connection.hook = None
|
||||
vc.user = SimpleNamespace(id=bot_id)
|
||||
vc.channel = MagicMock()
|
||||
vc.channel.members = members or []
|
||||
receiver = VoiceReceiver(vc, allowed_user_ids=allowed_ids)
|
||||
return receiver
|
||||
|
||||
@staticmethod
|
||||
def _fill_buffer(receiver, ssrc, duration_s=1.0, age_s=3.0):
|
||||
"""Add PCM data to buffer. 48kHz stereo 16-bit = 192000 bytes/sec."""
|
||||
size = int(192000 * duration_s)
|
||||
receiver._buffers[ssrc] = bytearray(b"\x00" * size)
|
||||
receiver._last_packet_time[ssrc] = time.monotonic() - age_s
|
||||
|
||||
# -- Known SSRC (normal flow) --
|
||||
|
||||
def test_known_ssrc_returns_completed(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.map_ssrc(100, 42)
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
assert len(receiver._buffers[100]) == 0 # cleared
|
||||
|
||||
def test_known_ssrc_short_buffer_ignored(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.map_ssrc(100, 42)
|
||||
self._fill_buffer(receiver, 100, duration_s=0.1) # too short
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_known_ssrc_recent_audio_waits(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.map_ssrc(100, 42)
|
||||
self._fill_buffer(receiver, 100, age_s=0.0) # just arrived
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
# -- Unknown SSRC + DAVE passthrough --
|
||||
|
||||
def test_unknown_ssrc_no_automap_no_completed(self):
|
||||
"""Unknown SSRC, no members to infer — buffer cleared, not returned."""
|
||||
receiver = self._make_receiver(dave=True, members=[])
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
assert len(receiver._buffers[100]) == 0
|
||||
|
||||
def test_unknown_ssrc_late_speaking_event(self):
|
||||
"""Audio buffered before SPEAKING → SPEAKING maps → next check returns it."""
|
||||
receiver = self._make_receiver(dave=True)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100, age_s=0.0) # still receiving
|
||||
# No user yet
|
||||
assert receiver.check_silence() == []
|
||||
# SPEAKING event arrives
|
||||
receiver.map_ssrc(100, 42)
|
||||
# Silence kicks in
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
# -- SSRC auto-mapping --
|
||||
|
||||
def test_automap_single_allowed_user(self):
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids={"42"}, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
assert receiver._ssrc_to_user[100] == 42
|
||||
|
||||
def test_automap_multiple_allowed_users_no_map(self):
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
SimpleNamespace(id=43, name="Bob"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids={"42", "43"}, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_automap_no_allowlist_single_member(self):
|
||||
"""No allowed_user_ids → sole non-bot member inferred."""
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids=None, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
def test_automap_unallowed_user_rejected(self):
|
||||
"""User in channel but not in allowed list — not mapped."""
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids={"99"}, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_automap_only_bot_in_channel(self):
|
||||
"""Only bot in channel — no one to map to."""
|
||||
members = [SimpleNamespace(id=9999, name="Bot")]
|
||||
receiver = self._make_receiver(allowed_ids=None, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_automap_persists_across_calls(self):
|
||||
"""Auto-mapped SSRC stays mapped for subsequent checks."""
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids={"42"}, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
receiver.check_silence()
|
||||
assert receiver._ssrc_to_user[100] == 42
|
||||
# Second utterance — should use cached mapping
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
# -- Stale buffer cleanup --
|
||||
|
||||
def test_stale_unknown_buffer_discarded(self):
|
||||
"""Buffer with no user and very old timestamp is discarded."""
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver._buffers[200] = bytearray(b"\x00" * 100)
|
||||
receiver._last_packet_time[200] = time.monotonic() - 10.0
|
||||
receiver.check_silence()
|
||||
assert 200 not in receiver._buffers
|
||||
|
||||
# -- Pause / resume (echo prevention) --
|
||||
|
||||
def test_paused_receiver_ignores_packets(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.pause()
|
||||
receiver._on_packet(b"\x00" * 100)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_resumed_receiver_accepts_packets(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.pause()
|
||||
receiver.resume()
|
||||
assert receiver._paused is False
|
||||
|
||||
# -- _on_packet DAVE passthrough behavior --
|
||||
|
||||
def _make_receiver_with_nacl(self, dave_session=None, mapped_ssrcs=None):
|
||||
"""Create a receiver that can process _on_packet with mocked NaCl + Opus."""
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = dave_session
|
||||
vc._connection.ssrc = 9999
|
||||
vc._connection.add_socket_listener = MagicMock()
|
||||
vc._connection.remove_socket_listener = MagicMock()
|
||||
vc._connection.hook = None
|
||||
vc.user = SimpleNamespace(id=9999)
|
||||
vc.channel = MagicMock()
|
||||
vc.channel.members = []
|
||||
receiver = VoiceReceiver(vc)
|
||||
receiver.start()
|
||||
# Pre-map SSRCs if provided
|
||||
if mapped_ssrcs:
|
||||
for ssrc, uid in mapped_ssrcs.items():
|
||||
receiver.map_ssrc(ssrc, uid)
|
||||
return receiver
|
||||
|
||||
@staticmethod
|
||||
def _build_rtp_packet(ssrc=100, seq=1, timestamp=960):
|
||||
"""Build a minimal valid RTP packet for _on_packet.
|
||||
|
||||
We need: RTP header (12 bytes) + encrypted payload + 4-byte nonce.
|
||||
NaCl decrypt is mocked so payload content doesn't matter.
|
||||
"""
|
||||
import struct
|
||||
# RTP header: version=2, payload_type=0x78, no extension, no CSRC
|
||||
header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc)
|
||||
# Fake encrypted payload (NaCl will be mocked) + 4 byte nonce
|
||||
payload = b"\x00" * 20 + b"\x00\x00\x00\x01"
|
||||
return header + payload
|
||||
|
||||
def _inject_mock_decoder(self, receiver, ssrc):
|
||||
"""Pre-inject a mock Opus decoder for the given SSRC."""
|
||||
mock_decoder = MagicMock()
|
||||
mock_decoder.decode.return_value = b"\x00" * 3840
|
||||
receiver._decoders[ssrc] = mock_decoder
|
||||
return mock_decoder
|
||||
|
||||
def test_on_packet_dave_known_user_decrypt_ok(self):
|
||||
"""Known SSRC + DAVE decrypt success → audio buffered."""
|
||||
dave = MagicMock()
|
||||
dave.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver = self._make_receiver_with_nacl(
|
||||
dave_session=dave, mapped_ssrcs={100: 42}
|
||||
)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
dave.decrypt.assert_called_once()
|
||||
|
||||
def test_on_packet_dave_unknown_ssrc_passthrough(self):
|
||||
"""Unknown SSRC + DAVE → skip DAVE, attempt Opus decode (passthrough)."""
|
||||
dave = MagicMock()
|
||||
receiver = self._make_receiver_with_nacl(dave_session=dave)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
dave.decrypt.assert_not_called()
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_on_packet_dave_unencrypted_error_passthrough(self):
|
||||
"""DAVE decrypt 'Unencrypted' error → use data as-is, don't drop."""
|
||||
dave = MagicMock()
|
||||
dave.decrypt.side_effect = Exception(
|
||||
"Failed to decrypt: DecryptionFailed(UnencryptedWhenPassthroughDisabled)"
|
||||
)
|
||||
receiver = self._make_receiver_with_nacl(
|
||||
dave_session=dave, mapped_ssrcs={100: 42}
|
||||
)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_on_packet_dave_other_error_drops(self):
|
||||
"""DAVE decrypt non-Unencrypted error → packet dropped."""
|
||||
dave = MagicMock()
|
||||
dave.decrypt.side_effect = Exception("KeyRotationFailed")
|
||||
receiver = self._make_receiver_with_nacl(
|
||||
dave_session=dave, mapped_ssrcs={100: 42}
|
||||
)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
def test_on_packet_no_dave_direct_decode(self):
|
||||
"""No DAVE session → decode directly."""
|
||||
receiver = self._make_receiver_with_nacl(dave_session=None)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_on_packet_bot_own_ssrc_ignored(self):
|
||||
"""Bot's own SSRC → dropped (echo prevention)."""
|
||||
receiver = self._make_receiver_with_nacl()
|
||||
with patch("nacl.secret.Aead"):
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=9999))
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_on_packet_multiple_ssrcs_separate_buffers(self):
|
||||
"""Different SSRCs → separate buffers."""
|
||||
receiver = self._make_receiver_with_nacl(dave_session=None)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
self._inject_mock_decoder(receiver, 200)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=200))
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert 200 in receiver._buffers
|
||||
|
||||
|
||||
class TestVoiceTTSPlayback:
|
||||
"""TTS playback: play_tts in VC, dedup, fallback."""
|
||||
|
||||
@staticmethod
|
||||
def _make_discord_adapter():
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_receivers = {}
|
||||
return adapter
|
||||
|
||||
# -- play_tts behavior --
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_plays_in_vc(self):
|
||||
"""play_tts calls play_in_voice_channel when bot is in VC."""
|
||||
adapter = self._make_discord_adapter()
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
|
||||
played = []
|
||||
async def fake_play(gid, path):
|
||||
played.append((gid, path))
|
||||
return True
|
||||
adapter.play_in_voice_channel = fake_play
|
||||
|
||||
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg")
|
||||
assert result.success is True
|
||||
assert played == [(111, "/tmp/tts.ogg")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_fallback_when_not_in_vc(self):
|
||||
"""play_tts sends as file attachment when bot is not in VC."""
|
||||
adapter = self._make_discord_adapter()
|
||||
from gateway.platforms.base import SendResult
|
||||
adapter.send_voice = AsyncMock(return_value=SendResult(success=False, error="no client"))
|
||||
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg")
|
||||
assert result.success is False
|
||||
adapter.send_voice.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_wrong_channel_no_match(self):
|
||||
"""play_tts doesn't match if chat_id is for a different channel."""
|
||||
adapter = self._make_discord_adapter()
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
|
||||
from gateway.platforms.base import SendResult
|
||||
adapter.send_voice = AsyncMock(return_value=SendResult(success=True))
|
||||
# Different chat_id — shouldn't match VC
|
||||
result = await adapter.play_tts(chat_id="999", audio_path="/tmp/tts.ogg")
|
||||
adapter.send_voice.assert_called_once()
|
||||
|
||||
# -- Runner dedup --
|
||||
|
||||
@staticmethod
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._voice_mode = {}
|
||||
runner.adapters = {}
|
||||
return runner
|
||||
|
||||
def _call_should_reply(self, runner, voice_mode, msg_type, response="Hello", agent_msgs=None):
|
||||
from gateway.platforms.base import MessageType, MessageEvent, SessionSource
|
||||
from gateway.config import Platform
|
||||
runner._voice_mode["ch1"] = voice_mode
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD, chat_id="ch1",
|
||||
user_id="1", user_name="test", chat_type="channel",
|
||||
)
|
||||
event = MessageEvent(source=source, text="test", message_type=msg_type)
|
||||
return runner._should_send_voice_reply(event, response, agent_msgs or [])
|
||||
|
||||
def test_voice_input_runner_skips(self):
|
||||
"""Voice input: runner skips — base adapter handles via play_tts."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "all", MessageType.VOICE) is False
|
||||
|
||||
def test_text_input_voice_all_runner_fires(self):
|
||||
"""Text input + voice_mode=all: runner generates TTS."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "all", MessageType.TEXT) is True
|
||||
|
||||
def test_text_input_voice_off_no_tts(self):
|
||||
"""Text input + voice_mode=off: no TTS."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "off", MessageType.TEXT) is False
|
||||
|
||||
def test_text_input_voice_only_no_tts(self):
|
||||
"""Text input + voice_mode=voice_only: no TTS for text."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "voice_only", MessageType.TEXT) is False
|
||||
|
||||
def test_error_response_no_tts(self):
|
||||
"""Error response: no TTS regardless of voice_mode."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "all", MessageType.TEXT, response="Error: boom") is False
|
||||
|
||||
def test_empty_response_no_tts(self):
|
||||
"""Empty response: no TTS."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "all", MessageType.TEXT, response="") is False
|
||||
|
||||
def test_agent_tts_tool_dedup(self):
|
||||
"""Agent already called text_to_speech tool: runner skips."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
agent_msgs = [{"role": "assistant", "tool_calls": [
|
||||
{"id": "1", "type": "function", "function": {"name": "text_to_speech", "arguments": "{}"}}
|
||||
]}]
|
||||
assert self._call_should_reply(runner, "all", MessageType.TEXT, agent_msgs=agent_msgs) is False
|
||||
|
||||
|
||||
class TestUDPKeepalive:
|
||||
"""UDP keepalive prevents Discord from dropping the voice session."""
|
||||
|
||||
def test_keepalive_interval_is_reasonable(self):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
interval = DiscordAdapter._KEEPALIVE_INTERVAL
|
||||
assert 5 <= interval <= 30, f"Keepalive interval {interval}s should be between 5-30s"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keepalive_sends_silence_frame(self):
|
||||
"""Listen loop sends silence frame via send_packet after interval."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake"
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
|
||||
# Mock VC and receiver
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
mock_conn = MagicMock()
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
mock_vc._connection = mock_conn
|
||||
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
mock_receiver_vc = MagicMock()
|
||||
mock_receiver_vc._connection.secret_key = [0] * 32
|
||||
mock_receiver_vc._connection.dave_session = None
|
||||
mock_receiver_vc._connection.ssrc = 9999
|
||||
mock_receiver_vc._connection.add_socket_listener = MagicMock()
|
||||
mock_receiver_vc._connection.remove_socket_listener = MagicMock()
|
||||
mock_receiver_vc._connection.hook = None
|
||||
receiver = VoiceReceiver(mock_receiver_vc)
|
||||
receiver.start()
|
||||
adapter._voice_receivers[111] = receiver
|
||||
|
||||
# Set keepalive interval very short for test
|
||||
original_interval = DiscordAdapter._KEEPALIVE_INTERVAL
|
||||
DiscordAdapter._KEEPALIVE_INTERVAL = 0.1
|
||||
|
||||
try:
|
||||
# Run listen loop briefly
|
||||
import asyncio
|
||||
loop_task = asyncio.create_task(adapter._voice_listen_loop(111))
|
||||
await asyncio.sleep(0.3)
|
||||
receiver._running = False # stop loop
|
||||
await asyncio.sleep(0.1)
|
||||
loop_task.cancel()
|
||||
try:
|
||||
await loop_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# send_packet should have been called with silence frame
|
||||
mock_conn.send_packet.assert_called_with(b'\xf8\xff\xfe')
|
||||
finally:
|
||||
DiscordAdapter._KEEPALIVE_INTERVAL = original_interval
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
|
||||
def test_user_env_overrides_stale_shell_values(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
env_file = home / ".env"
|
||||
env_file.write_text("OPENAI_BASE_URL=https://new.example/v1\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://old.example/v1")
|
||||
|
||||
loaded = load_hermes_dotenv(hermes_home=home)
|
||||
|
||||
assert loaded == [env_file]
|
||||
assert os.getenv("OPENAI_BASE_URL") == "https://new.example/v1"
|
||||
|
||||
|
||||
def test_project_env_overrides_stale_shell_values_when_user_env_missing(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
project_env = tmp_path / ".env"
|
||||
project_env.write_text("OPENAI_BASE_URL=https://project.example/v1\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://old.example/v1")
|
||||
|
||||
loaded = load_hermes_dotenv(hermes_home=home, project_env=project_env)
|
||||
|
||||
assert loaded == [project_env]
|
||||
assert os.getenv("OPENAI_BASE_URL") == "https://project.example/v1"
|
||||
|
||||
|
||||
def test_user_env_takes_precedence_over_project_env(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
user_env = home / ".env"
|
||||
project_env = tmp_path / ".env"
|
||||
user_env.write_text("OPENAI_BASE_URL=https://user.example/v1\n", encoding="utf-8")
|
||||
project_env.write_text("OPENAI_BASE_URL=https://project.example/v1\nOPENAI_API_KEY=project-key\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://old.example/v1")
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
|
||||
loaded = load_hermes_dotenv(hermes_home=home, project_env=project_env)
|
||||
|
||||
assert loaded == [user_env, project_env]
|
||||
assert os.getenv("OPENAI_BASE_URL") == "https://user.example/v1"
|
||||
assert os.getenv("OPENAI_API_KEY") == "project-key"
|
||||
|
||||
|
||||
def test_main_import_applies_user_env_over_shell_values(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
(home / ".env").write_text(
|
||||
"OPENAI_BASE_URL=https://new.example/v1\nHERMES_INFERENCE_PROVIDER=custom\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://old.example/v1")
|
||||
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openrouter")
|
||||
|
||||
sys.modules.pop("hermes_cli.main", None)
|
||||
importlib.import_module("hermes_cli.main")
|
||||
|
||||
assert os.getenv("OPENAI_BASE_URL") == "https://new.example/v1"
|
||||
assert os.getenv("HERMES_INFERENCE_PROVIDER") == "custom"
|
||||
@@ -0,0 +1,64 @@
|
||||
import sys
|
||||
|
||||
|
||||
def test_sessions_delete_accepts_unique_id_prefix(monkeypatch, capsys):
|
||||
import hermes_cli.main as main_mod
|
||||
import hermes_state
|
||||
|
||||
captured = {}
|
||||
|
||||
class FakeDB:
|
||||
def resolve_session_id(self, session_id):
|
||||
captured["resolved_from"] = session_id
|
||||
return "20260315_092437_c9a6ff"
|
||||
|
||||
def delete_session(self, session_id):
|
||||
captured["deleted"] = session_id
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
captured["closed"] = True
|
||||
|
||||
monkeypatch.setattr(hermes_state, "SessionDB", lambda: FakeDB())
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
["hermes", "sessions", "delete", "20260315_092437_c9a6", "--yes"],
|
||||
)
|
||||
|
||||
main_mod.main()
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert captured == {
|
||||
"resolved_from": "20260315_092437_c9a6",
|
||||
"deleted": "20260315_092437_c9a6ff",
|
||||
"closed": True,
|
||||
}
|
||||
assert "Deleted session '20260315_092437_c9a6ff'." in output
|
||||
|
||||
|
||||
def test_sessions_delete_reports_not_found_when_prefix_is_unknown(monkeypatch, capsys):
|
||||
import hermes_cli.main as main_mod
|
||||
import hermes_state
|
||||
|
||||
class FakeDB:
|
||||
def resolve_session_id(self, session_id):
|
||||
return None
|
||||
|
||||
def delete_session(self, session_id):
|
||||
raise AssertionError("delete_session should not be called when resolution fails")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(hermes_state, "SessionDB", lambda: FakeDB())
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
["hermes", "sessions", "delete", "missing-prefix", "--yes"],
|
||||
)
|
||||
|
||||
main_mod.main()
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "Session 'missing-prefix' not found." in output
|
||||
@@ -1,6 +1,13 @@
|
||||
"""Tests for hermes_cli.tools_config platform tool persistence."""
|
||||
|
||||
from hermes_cli.tools_config import _get_platform_tools, _platform_toolset_summary, _toolset_has_keys
|
||||
from unittest.mock import patch
|
||||
|
||||
from hermes_cli.tools_config import (
|
||||
_get_platform_tools,
|
||||
_platform_toolset_summary,
|
||||
_save_platform_tools,
|
||||
_toolset_has_keys,
|
||||
)
|
||||
|
||||
|
||||
def test_get_platform_tools_uses_default_when_platform_not_configured():
|
||||
@@ -31,7 +38,7 @@ def test_platform_toolset_summary_uses_explicit_platform_list():
|
||||
def test_toolset_has_keys_for_vision_accepts_codex_auth(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "auth.json").write_text(
|
||||
'{"active_provider":"openai-codex","providers":{"openai-codex":{"tokens":{"access_token":"codex-access-token","refresh_token":"codex-refresh-token"}}}}'
|
||||
'{"active_provider":"openai-codex","providers":{"openai-codex":{"tokens":{"access_token": "codex-...oken","refresh_token": "codex-...oken"}}}}'
|
||||
)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
@@ -40,3 +47,56 @@ def test_toolset_has_keys_for_vision_accepts_codex_auth(tmp_path, monkeypatch):
|
||||
monkeypatch.delenv("CONTEXT_VISION_PROVIDER", raising=False)
|
||||
|
||||
assert _toolset_has_keys("vision") is True
|
||||
|
||||
|
||||
def test_save_platform_tools_preserves_mcp_server_names():
|
||||
"""Ensure MCP server names are preserved when saving platform tools.
|
||||
|
||||
Regression test for https://github.com/NousResearch/hermes-agent/issues/1247
|
||||
"""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"cli": ["web", "terminal", "time", "github", "custom-mcp-server"]
|
||||
}
|
||||
}
|
||||
|
||||
new_selection = {"web", "browser"}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "cli", new_selection)
|
||||
|
||||
saved_toolsets = config["platform_toolsets"]["cli"]
|
||||
|
||||
assert "time" in saved_toolsets
|
||||
assert "github" in saved_toolsets
|
||||
assert "custom-mcp-server" in saved_toolsets
|
||||
assert "web" in saved_toolsets
|
||||
assert "browser" in saved_toolsets
|
||||
assert "terminal" not in saved_toolsets
|
||||
|
||||
|
||||
def test_save_platform_tools_handles_empty_existing_config():
|
||||
"""Saving platform tools works when no existing config exists."""
|
||||
config = {}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "telegram", {"web", "terminal"})
|
||||
|
||||
saved_toolsets = config["platform_toolsets"]["telegram"]
|
||||
assert "web" in saved_toolsets
|
||||
assert "terminal" in saved_toolsets
|
||||
|
||||
|
||||
def test_save_platform_tools_handles_invalid_existing_config():
|
||||
"""Saving platform tools works when existing config is not a list."""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"cli": "invalid-string-value"
|
||||
}
|
||||
}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "cli", {"web"})
|
||||
|
||||
saved_toolsets = config["platform_toolsets"]["cli"]
|
||||
assert "web" in saved_toolsets
|
||||
|
||||
@@ -0,0 +1,611 @@
|
||||
"""Integration tests for Discord voice channel audio flow.
|
||||
|
||||
Uses real NaCl encryption and Opus codec (no mocks for crypto/codec).
|
||||
Does NOT require a Discord connection — tests the VoiceReceiver
|
||||
packet processing pipeline end-to-end.
|
||||
|
||||
Requires: PyNaCl>=1.5.0, discord.py[voice] (opus codec)
|
||||
"""
|
||||
|
||||
import struct
|
||||
import time
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
# Skip entire module if voice deps are missing
|
||||
pytest.importorskip("nacl.secret", reason="PyNaCl required for voice integration tests")
|
||||
discord = pytest.importorskip("discord", reason="discord.py required for voice integration tests")
|
||||
|
||||
import nacl.secret
|
||||
|
||||
try:
|
||||
if not discord.opus.is_loaded():
|
||||
import ctypes.util
|
||||
opus_path = ctypes.util.find_library("opus")
|
||||
if not opus_path:
|
||||
import sys
|
||||
for p in ("/opt/homebrew/lib/libopus.dylib", "/usr/local/lib/libopus.dylib"):
|
||||
import os
|
||||
if os.path.isfile(p):
|
||||
opus_path = p
|
||||
break
|
||||
if opus_path:
|
||||
discord.opus.load_opus(opus_path)
|
||||
OPUS_AVAILABLE = discord.opus.is_loaded()
|
||||
except Exception:
|
||||
OPUS_AVAILABLE = False
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_secret_key():
|
||||
"""Generate a random 32-byte key."""
|
||||
import os
|
||||
return os.urandom(32)
|
||||
|
||||
|
||||
def _build_encrypted_rtp_packet(secret_key, opus_payload, ssrc=100, seq=1, timestamp=960):
|
||||
"""Build a real NaCl-encrypted RTP packet matching Discord's format.
|
||||
|
||||
Format: RTP header (12 bytes) + encrypted(opus) + 4-byte nonce
|
||||
Encryption: aead_xchacha20_poly1305 with RTP header as AAD.
|
||||
"""
|
||||
# RTP header: version=2, payload_type=0x78, no extension, no CSRC
|
||||
header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc)
|
||||
|
||||
# Encrypt with NaCl AEAD
|
||||
box = nacl.secret.Aead(secret_key)
|
||||
nonce_counter = struct.pack(">I", seq) # 4-byte counter as nonce seed
|
||||
# Full 24-byte nonce: counter in first 4 bytes, rest zeros
|
||||
full_nonce = nonce_counter + b'\x00' * 20
|
||||
|
||||
enc_msg = box.encrypt(opus_payload, header, full_nonce)
|
||||
ciphertext = enc_msg.ciphertext # without nonce prefix
|
||||
|
||||
# Discord format: header + ciphertext + 4-byte nonce
|
||||
return header + ciphertext + nonce_counter
|
||||
|
||||
|
||||
def _make_voice_receiver(secret_key, dave_session=None, bot_ssrc=9999,
|
||||
allowed_user_ids=None, members=None):
|
||||
"""Create a VoiceReceiver with real secret key."""
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = list(secret_key)
|
||||
vc._connection.dave_session = dave_session
|
||||
vc._connection.ssrc = bot_ssrc
|
||||
vc._connection.add_socket_listener = MagicMock()
|
||||
vc._connection.remove_socket_listener = MagicMock()
|
||||
vc._connection.hook = None
|
||||
vc.user = SimpleNamespace(id=bot_ssrc)
|
||||
vc.channel = MagicMock()
|
||||
vc.channel.members = members or []
|
||||
receiver = VoiceReceiver(vc, allowed_user_ids=allowed_user_ids)
|
||||
receiver.start()
|
||||
return receiver
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRealNaClDecrypt:
|
||||
"""End-to-end: real NaCl encrypt → _on_packet decrypt → buffer."""
|
||||
|
||||
def test_valid_encrypted_packet_buffered(self):
|
||||
"""Real NaCl encrypted packet → decrypted → buffered."""
|
||||
key = _make_secret_key()
|
||||
opus_silence = b'\xf8\xff\xfe'
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, opus_silence, ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_wrong_key_packet_dropped(self):
|
||||
"""Packet encrypted with wrong key → NaCl fails → not buffered."""
|
||||
real_key = _make_secret_key()
|
||||
wrong_key = _make_secret_key()
|
||||
opus_silence = b'\xf8\xff\xfe'
|
||||
receiver = _make_voice_receiver(real_key)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(wrong_key, opus_silence, ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
def test_bot_ssrc_ignored(self):
|
||||
"""Packet from bot's own SSRC → ignored."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key, bot_ssrc=9999)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=9999)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_multiple_packets_accumulate(self):
|
||||
"""Multiple valid packets → buffer grows."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
for seq in range(1, 6):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
buf_size = len(receiver._buffers[100])
|
||||
assert buf_size > 0, "Multiple packets should accumulate in buffer"
|
||||
|
||||
def test_different_ssrcs_separate_buffers(self):
|
||||
"""Packets from different SSRCs → separate buffers."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
for ssrc in [100, 200, 300]:
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=ssrc)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers) == 3
|
||||
for ssrc in [100, 200, 300]:
|
||||
assert ssrc in receiver._buffers
|
||||
|
||||
|
||||
class TestRealNaClWithDAVE:
|
||||
"""NaCl decrypt + DAVE passthrough scenarios with real crypto."""
|
||||
|
||||
def test_dave_unknown_ssrc_passthrough(self):
|
||||
"""DAVE enabled but SSRC unknown → skip DAVE, buffer audio."""
|
||||
key = _make_secret_key()
|
||||
dave = MagicMock() # DAVE session present but SSRC not mapped
|
||||
receiver = _make_voice_receiver(key, dave_session=dave)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
# DAVE decrypt not called (SSRC unknown)
|
||||
dave.decrypt.assert_not_called()
|
||||
# Audio still buffered via passthrough
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_dave_unencrypted_error_passthrough(self):
|
||||
"""DAVE raises 'Unencrypted' → use NaCl-decrypted data as-is."""
|
||||
key = _make_secret_key()
|
||||
dave = MagicMock()
|
||||
dave.decrypt.side_effect = Exception(
|
||||
"DecryptionFailed(UnencryptedWhenPassthroughDisabled)"
|
||||
)
|
||||
receiver = _make_voice_receiver(key, dave_session=dave)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
# DAVE was called but failed → passthrough
|
||||
dave.decrypt.assert_called_once()
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_dave_real_error_drops(self):
|
||||
"""DAVE raises non-Unencrypted error → packet dropped."""
|
||||
key = _make_secret_key()
|
||||
dave = MagicMock()
|
||||
dave.decrypt.side_effect = Exception("KeyRotationFailed")
|
||||
receiver = _make_voice_receiver(key, dave_session=dave)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
|
||||
class TestFullVoiceFlow:
|
||||
"""End-to-end: encrypt → receive → buffer → silence detect → complete."""
|
||||
|
||||
def test_single_utterance_flow(self):
|
||||
"""Encrypt packets → buffer → silence → check_silence returns utterance."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
# Send enough packets to exceed MIN_SPEECH_DURATION (0.5s)
|
||||
# At 48kHz stereo 16-bit, each Opus silence frame decodes to ~3840 bytes
|
||||
# Need 96000 bytes = ~25 frames
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
# Simulate silence by setting last_packet_time in the past
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
user_id, pcm_data = completed[0]
|
||||
assert user_id == 42
|
||||
assert len(pcm_data) > 0
|
||||
|
||||
def test_utterance_with_ssrc_automap(self):
|
||||
"""No SPEAKING event → auto-map sole allowed user → utterance processed."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = _make_voice_receiver(
|
||||
key, allowed_user_ids={"42"}, members=members
|
||||
)
|
||||
# No map_ssrc call — simulating missing SPEAKING event
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42 # auto-mapped to sole allowed user
|
||||
|
||||
def test_pause_blocks_during_playback(self):
|
||||
"""Pause receiver → packets ignored → resume → packets accepted."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
# Pause (echo prevention during TTS playback)
|
||||
receiver.pause()
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
# Resume
|
||||
receiver.resume()
|
||||
receiver._on_packet(packet)
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_corrupted_packet_ignored(self):
|
||||
"""Corrupted/truncated packet → silently ignored."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
# Too short
|
||||
receiver._on_packet(b"\x00" * 5)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
# Wrong RTP version
|
||||
bad_header = struct.pack(">BBHII", 0x00, 0x78, 1, 960, 100)
|
||||
receiver._on_packet(bad_header + b"\x00" * 20)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
# Wrong payload type
|
||||
bad_pt = struct.pack(">BBHII", 0x80, 0x00, 1, 960, 100)
|
||||
receiver._on_packet(bad_pt + b"\x00" * 20)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_stop_cleans_everything(self):
|
||||
"""stop() clears all state cleanly."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
for seq in range(1, 10):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
receiver.stop()
|
||||
assert receiver._running is False
|
||||
assert len(receiver._buffers) == 0
|
||||
assert len(receiver._ssrc_to_user) == 0
|
||||
assert len(receiver._decoders) == 0
|
||||
|
||||
|
||||
class TestSPEAKINGHook:
|
||||
"""SPEAKING event hook correctly maps SSRC to user_id."""
|
||||
|
||||
def test_speaking_hook_installed(self):
|
||||
"""start() installs speaking hook on connection."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
conn = receiver._vc._connection
|
||||
# hook should be set (wrapped)
|
||||
assert conn.hook is not None
|
||||
|
||||
def test_map_ssrc_via_speaking(self):
|
||||
"""SPEAKING op 5 event maps SSRC to user_id."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(500, 12345)
|
||||
assert receiver._ssrc_to_user[500] == 12345
|
||||
|
||||
def test_map_ssrc_overwrites(self):
|
||||
"""New SPEAKING event for same SSRC overwrites old mapping."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(500, 111)
|
||||
receiver.map_ssrc(500, 222)
|
||||
assert receiver._ssrc_to_user[500] == 222
|
||||
|
||||
def test_speaking_mapped_audio_processed(self):
|
||||
"""After SSRC is mapped, audio from that SSRC gets correct user_id."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
|
||||
class TestAuthFiltering:
|
||||
"""Only allowed users' audio should be processed."""
|
||||
|
||||
def test_allowed_user_audio_processed(self):
|
||||
"""Allowed user's utterance is returned by check_silence."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = _make_voice_receiver(
|
||||
key, allowed_user_ids={"42"}, members=members,
|
||||
)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
def test_automap_rejects_unallowed_user(self):
|
||||
"""Auto-map refuses to map SSRC to user not in allowed list."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = _make_voice_receiver(
|
||||
key, allowed_user_ids={"99"}, # Alice not allowed
|
||||
members=members,
|
||||
)
|
||||
# No map_ssrc — SSRC unknown, auto-map should reject
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_empty_allowlist_allows_all(self):
|
||||
"""Empty allowed_user_ids means no restriction."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = _make_voice_receiver(
|
||||
key, allowed_user_ids=None, members=members,
|
||||
)
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
# Auto-mapped to sole non-bot member
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
|
||||
class TestRejoinFlow:
|
||||
"""Leave and rejoin: state cleanup and fresh receiver."""
|
||||
|
||||
def test_stop_then_new_receiver_clean_state(self):
|
||||
"""After stop(), a new receiver starts with empty state."""
|
||||
key = _make_secret_key()
|
||||
receiver1 = _make_voice_receiver(key)
|
||||
receiver1.map_ssrc(100, 42)
|
||||
|
||||
for seq in range(1, 10):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver1._on_packet(packet)
|
||||
|
||||
assert len(receiver1._buffers[100]) > 0
|
||||
receiver1.stop()
|
||||
|
||||
# New receiver (simulates rejoin)
|
||||
receiver2 = _make_voice_receiver(key)
|
||||
assert len(receiver2._buffers) == 0
|
||||
assert len(receiver2._ssrc_to_user) == 0
|
||||
assert len(receiver2._decoders) == 0
|
||||
|
||||
def test_rejoin_new_ssrc_works(self):
|
||||
"""After rejoin, user may get new SSRC — still works."""
|
||||
key = _make_secret_key()
|
||||
receiver1 = _make_voice_receiver(key)
|
||||
receiver1.map_ssrc(100, 42) # old SSRC
|
||||
receiver1.stop()
|
||||
|
||||
receiver2 = _make_voice_receiver(key)
|
||||
receiver2.map_ssrc(200, 42) # new SSRC after rejoin
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=200, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver2._on_packet(packet)
|
||||
|
||||
receiver2._last_packet_time[200] = time.monotonic() - 3.0
|
||||
completed = receiver2.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
def test_rejoin_without_speaking_event_automap(self):
|
||||
"""Rejoin without SPEAKING event — auto-map sole allowed user."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
|
||||
# First session
|
||||
receiver1 = _make_voice_receiver(
|
||||
key, allowed_user_ids={"42"}, members=members,
|
||||
)
|
||||
receiver1.stop()
|
||||
|
||||
# Rejoin — new key (Discord may assign new secret_key)
|
||||
new_key = _make_secret_key()
|
||||
receiver2 = _make_voice_receiver(
|
||||
new_key, allowed_user_ids={"42"}, members=members,
|
||||
)
|
||||
# No map_ssrc — simulating missing SPEAKING event
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
new_key, b'\xf8\xff\xfe', ssrc=300, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver2._on_packet(packet)
|
||||
|
||||
receiver2._last_packet_time[300] = time.monotonic() - 3.0
|
||||
completed = receiver2.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
|
||||
class TestMultiGuildIsolation:
|
||||
"""Each guild has independent voice state."""
|
||||
|
||||
def test_separate_receivers_independent(self):
|
||||
"""Two receivers (different guilds) don't interfere."""
|
||||
key1 = _make_secret_key()
|
||||
key2 = _make_secret_key()
|
||||
|
||||
receiver1 = _make_voice_receiver(key1, bot_ssrc=1111)
|
||||
receiver2 = _make_voice_receiver(key2, bot_ssrc=2222)
|
||||
|
||||
receiver1.map_ssrc(100, 42)
|
||||
receiver2.map_ssrc(200, 99)
|
||||
|
||||
# Send to receiver1
|
||||
for seq in range(1, 10):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key1, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver1._on_packet(packet)
|
||||
|
||||
# receiver2 should be empty
|
||||
assert len(receiver2._buffers) == 0
|
||||
assert 100 in receiver1._buffers
|
||||
|
||||
def test_stop_one_doesnt_affect_other(self):
|
||||
"""Stopping one receiver doesn't affect another."""
|
||||
key1 = _make_secret_key()
|
||||
key2 = _make_secret_key()
|
||||
|
||||
receiver1 = _make_voice_receiver(key1)
|
||||
receiver2 = _make_voice_receiver(key2)
|
||||
|
||||
receiver1.map_ssrc(100, 42)
|
||||
receiver2.map_ssrc(200, 99)
|
||||
|
||||
for seq in range(1, 10):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key2, b'\xf8\xff\xfe', ssrc=200, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver2._on_packet(packet)
|
||||
|
||||
receiver1.stop()
|
||||
|
||||
# receiver2 still has data
|
||||
assert receiver2._running is True
|
||||
assert len(receiver2._buffers[200]) > 0
|
||||
|
||||
|
||||
class TestEchoPreventionFlow:
|
||||
"""Receiver pause/resume during TTS playback prevents echo."""
|
||||
|
||||
def test_audio_during_pause_ignored(self):
|
||||
"""Audio arriving while paused is completely ignored."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
receiver.pause()
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
def test_audio_after_resume_processed(self):
|
||||
"""Audio arriving after resume is processed normally."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
# Pause → send packets → resume → send more packets
|
||||
receiver.pause()
|
||||
for seq in range(1, 5):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
receiver.resume()
|
||||
for seq in range(5, 35):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
@@ -495,6 +495,59 @@ class TestConvertMessages:
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_converts_user_image_url_blocks_to_anthropic_image_blocks(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you see this?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you see this?"},
|
||||
{"type": "image", "source": {"type": "url", "url": "https://example.com/cat.png"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def test_converts_data_url_image_blocks_to_base64_anthropic_image_blocks(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "input_text", "text": "What is in this screenshot?"},
|
||||
{"type": "input_image", "image_url": "data:image/png;base64,AAAA"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is in this screenshot?"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "AAAA",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def test_converts_tool_calls(self):
|
||||
messages = [
|
||||
{
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
def _tool_call(name: str, arguments):
|
||||
return SimpleNamespace(
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=SimpleNamespace(name=name, arguments=arguments),
|
||||
)
|
||||
|
||||
|
||||
def _response_with_tool_call(arguments):
|
||||
assistant = SimpleNamespace(
|
||||
content=None,
|
||||
reasoning=None,
|
||||
tool_calls=[_tool_call("read_file", arguments)],
|
||||
)
|
||||
choice = SimpleNamespace(message=assistant, finish_reason="tool_calls")
|
||||
return SimpleNamespace(choices=[choice], usage=None)
|
||||
|
||||
|
||||
class _FakeChatCompletions:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
def create(self, **kwargs):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
return _response_with_tool_call({"path": "README.md"})
|
||||
return SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
message=SimpleNamespace(content="done", reasoning=None, tool_calls=[]),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=None,
|
||||
)
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self):
|
||||
self.chat = SimpleNamespace(completions=_FakeChatCompletions())
|
||||
|
||||
|
||||
def test_tool_call_validation_accepts_dict_arguments(monkeypatch):
|
||||
from run_agent import AIAgent
|
||||
|
||||
monkeypatch.setattr("run_agent.OpenAI", lambda **kwargs: _FakeClient())
|
||||
monkeypatch.setattr(
|
||||
"run_agent.get_tool_definitions",
|
||||
lambda *args, **kwargs: [{"function": {"name": "read_file"}}],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"run_agent.handle_function_call",
|
||||
lambda name, args, task_id=None, **kwargs: json.dumps({"ok": True, "args": args}),
|
||||
)
|
||||
|
||||
agent = AIAgent(
|
||||
model="test-model",
|
||||
api_key="test-key",
|
||||
base_url="http://localhost:8080/v1",
|
||||
platform="cli",
|
||||
max_iterations=3,
|
||||
quiet_mode=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
result = agent.run_conversation("read the file")
|
||||
|
||||
assert result["final_response"] == "done"
|
||||
@@ -361,6 +361,24 @@ class TestDeleteAndExport:
|
||||
def test_delete_nonexistent(self, db):
|
||||
assert db.delete_session("nope") is False
|
||||
|
||||
def test_resolve_session_id_exact(self, db):
|
||||
db.create_session(session_id="20260315_092437_c9a6ff", source="cli")
|
||||
assert db.resolve_session_id("20260315_092437_c9a6ff") == "20260315_092437_c9a6ff"
|
||||
|
||||
def test_resolve_session_id_unique_prefix(self, db):
|
||||
db.create_session(session_id="20260315_092437_c9a6ff", source="cli")
|
||||
assert db.resolve_session_id("20260315_092437_c9a6") == "20260315_092437_c9a6ff"
|
||||
|
||||
def test_resolve_session_id_ambiguous_prefix_returns_none(self, db):
|
||||
db.create_session(session_id="20260315_092437_c9a6aa", source="cli")
|
||||
db.create_session(session_id="20260315_092437_c9a6bb", source="cli")
|
||||
assert db.resolve_session_id("20260315_092437_c9a6") is None
|
||||
|
||||
def test_resolve_session_id_escapes_like_wildcards(self, db):
|
||||
db.create_session(session_id="20260315_092437_c9a6ff", source="cli")
|
||||
db.create_session(session_id="20260315X092437_c9a6ff", source="cli")
|
||||
assert db.resolve_session_id("20260315_092437") == "20260315_092437_c9a6ff"
|
||||
|
||||
def test_export_session(self, db):
|
||||
db.create_session(session_id="s1", source="cli", model="test")
|
||||
db.append_message("s1", role="user", content="Hello")
|
||||
|
||||
@@ -543,7 +543,7 @@ class TestAuxiliaryClientProviderPriority:
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value="codex-tok"), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_text_auxiliary_client()
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
|
||||
|
||||
|
||||
+64
-1
@@ -12,7 +12,7 @@ import uuid
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -1986,6 +1986,69 @@ class TestBuildApiKwargsAnthropicMaxTokens:
|
||||
assert call_args[0][3] is None
|
||||
|
||||
|
||||
class TestAnthropicImageFallback:
|
||||
def test_build_api_kwargs_converts_multimodal_user_image_to_text(self, agent):
|
||||
agent.api_mode = "anthropic_messages"
|
||||
agent.reasoning_config = None
|
||||
|
||||
api_messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you see this now?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
|
||||
],
|
||||
}]
|
||||
|
||||
with (
|
||||
patch("tools.vision_tools.vision_analyze_tool", new=AsyncMock(return_value=json.dumps({"success": True, "analysis": "A cat sitting on a chair."}))),
|
||||
patch("agent.anthropic_adapter.build_anthropic_kwargs") as mock_build,
|
||||
):
|
||||
mock_build.return_value = {"model": "claude-sonnet-4-20250514", "messages": [], "max_tokens": 4096}
|
||||
agent._build_api_kwargs(api_messages)
|
||||
|
||||
kwargs = mock_build.call_args.kwargs or dict(zip(
|
||||
["model", "messages", "tools", "max_tokens", "reasoning_config"],
|
||||
mock_build.call_args.args,
|
||||
))
|
||||
transformed = kwargs["messages"]
|
||||
assert isinstance(transformed[0]["content"], str)
|
||||
assert "A cat sitting on a chair." in transformed[0]["content"]
|
||||
assert "Can you see this now?" in transformed[0]["content"]
|
||||
assert "vision_analyze with image_url: https://example.com/cat.png" in transformed[0]["content"]
|
||||
|
||||
def test_build_api_kwargs_reuses_cached_image_analysis_for_duplicate_images(self, agent):
|
||||
agent.api_mode = "anthropic_messages"
|
||||
agent.reasoning_config = None
|
||||
data_url = "data:image/png;base64,QUFBQQ=="
|
||||
|
||||
api_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "first"},
|
||||
{"type": "input_image", "image_url": data_url},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "second"},
|
||||
{"type": "input_image", "image_url": data_url},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
mock_vision = AsyncMock(return_value=json.dumps({"success": True, "analysis": "A small test image."}))
|
||||
with (
|
||||
patch("tools.vision_tools.vision_analyze_tool", new=mock_vision),
|
||||
patch("agent.anthropic_adapter.build_anthropic_kwargs") as mock_build,
|
||||
):
|
||||
mock_build.return_value = {"model": "claude-sonnet-4-20250514", "messages": [], "max_tokens": 4096}
|
||||
agent._build_api_kwargs(api_messages)
|
||||
|
||||
assert mock_vision.await_count == 1
|
||||
|
||||
|
||||
class TestFallbackAnthropicProvider:
|
||||
"""Bug fix: _try_activate_fallback had no case for anthropic provider."""
|
||||
|
||||
|
||||
@@ -153,6 +153,36 @@ class TestScheduleCronjob:
|
||||
assert job["provider"] == "custom"
|
||||
assert job["base_url"] == "http://127.0.0.1:4000/v1"
|
||||
|
||||
def test_thread_id_captured_in_origin(self, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "123456")
|
||||
monkeypatch.setenv("HERMES_SESSION_THREAD_ID", "42")
|
||||
import cron.jobs as _jobs
|
||||
created = json.loads(schedule_cronjob(
|
||||
prompt="Thread test",
|
||||
schedule="every 1h",
|
||||
deliver="origin",
|
||||
))
|
||||
assert created["success"] is True
|
||||
job_id = created["job_id"]
|
||||
job = _jobs.get_job(job_id)
|
||||
assert job["origin"]["thread_id"] == "42"
|
||||
|
||||
def test_thread_id_absent_when_not_set(self, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "123456")
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
import cron.jobs as _jobs
|
||||
created = json.loads(schedule_cronjob(
|
||||
prompt="No thread test",
|
||||
schedule="every 1h",
|
||||
deliver="origin",
|
||||
))
|
||||
assert created["success"] is True
|
||||
job_id = created["job_id"]
|
||||
job = _jobs.get_job(job_id)
|
||||
assert job["origin"].get("thread_id") is None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# list_cronjobs
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Tests for provider env var blocklist in LocalEnvironment.
|
||||
"""Tests for subprocess env sanitization in LocalEnvironment.
|
||||
|
||||
Verifies that Hermes-internal provider env vars (OPENAI_BASE_URL, etc.)
|
||||
are stripped from subprocess environments so external CLIs are not
|
||||
silently misrouted.
|
||||
Verifies that Hermes-managed provider, tool, and gateway env vars are
|
||||
stripped from subprocess environments so external CLIs are not silently
|
||||
misrouted or handed Hermes secrets.
|
||||
|
||||
See: https://github.com/NousResearch/hermes-agent/issues/1002
|
||||
See: https://github.com/NousResearch/hermes-agent/issues/1264
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -25,8 +26,7 @@ def _make_fake_popen(captured: dict):
|
||||
proc = MagicMock()
|
||||
proc.poll.return_value = 0
|
||||
proc.returncode = 0
|
||||
proc.stdout = iter([])
|
||||
proc.stdout.close = lambda: None
|
||||
proc.stdout = MagicMock(__iter__=lambda s: iter([]), __next__=lambda s: (_ for _ in ()).throw(StopIteration))
|
||||
proc.stdin = MagicMock()
|
||||
return proc
|
||||
return fake_popen
|
||||
@@ -110,6 +110,30 @@ class TestProviderEnvBlocklist:
|
||||
for var in extra_provider_vars:
|
||||
assert var not in result_env, f"{var} leaked into subprocess env"
|
||||
|
||||
def test_tool_and_gateway_vars_are_stripped(self):
|
||||
"""Tool and gateway secrets/config must not leak into subprocess env."""
|
||||
leaked_vars = {
|
||||
"TELEGRAM_BOT_TOKEN": "bot-token",
|
||||
"TELEGRAM_HOME_CHANNEL": "12345",
|
||||
"DISCORD_HOME_CHANNEL": "67890",
|
||||
"SLACK_APP_TOKEN": "xapp-secret",
|
||||
"WHATSAPP_ALLOWED_USERS": "+15555550123",
|
||||
"SIGNAL_ACCOUNT": "+15555550124",
|
||||
"HASS_TOKEN": "ha-secret",
|
||||
"EMAIL_PASSWORD": "email-secret",
|
||||
"FIRECRAWL_API_KEY": "fc-secret",
|
||||
"BROWSERBASE_PROJECT_ID": "bb-project",
|
||||
"ELEVENLABS_API_KEY": "el-secret",
|
||||
"GITHUB_TOKEN": "ghp_secret",
|
||||
"GH_TOKEN": "gh_alias_secret",
|
||||
"GATEWAY_ALLOW_ALL_USERS": "true",
|
||||
"GATEWAY_ALLOWED_USERS": "alice,bob",
|
||||
}
|
||||
result_env = _run_with_env(extra_os_env=leaked_vars)
|
||||
|
||||
for var in leaked_vars:
|
||||
assert var not in result_env, f"{var} leaked into subprocess env"
|
||||
|
||||
def test_safe_vars_are_preserved(self):
|
||||
"""Standard env vars (PATH, HOME, USER) must still be passed through."""
|
||||
result_env = _run_with_env()
|
||||
@@ -205,3 +229,56 @@ class TestBlocklistCoverage:
|
||||
"HELICONE_API_KEY",
|
||||
}
|
||||
assert extras.issubset(_HERMES_PROVIDER_ENV_BLOCKLIST)
|
||||
|
||||
def test_optional_tool_and_messaging_vars_are_in_blocklist(self):
|
||||
"""Tool/messaging vars from OPTIONAL_ENV_VARS should stay covered."""
|
||||
from hermes_cli.config import OPTIONAL_ENV_VARS
|
||||
|
||||
for name, metadata in OPTIONAL_ENV_VARS.items():
|
||||
category = metadata.get("category")
|
||||
if category in {"tool", "messaging"}:
|
||||
assert name in _HERMES_PROVIDER_ENV_BLOCKLIST, (
|
||||
f"Optional env var {name} (category={category}) missing from blocklist"
|
||||
)
|
||||
elif category == "setting" and metadata.get("password"):
|
||||
assert name in _HERMES_PROVIDER_ENV_BLOCKLIST, (
|
||||
f"Secret setting env var {name} missing from blocklist"
|
||||
)
|
||||
|
||||
def test_gateway_runtime_vars_are_in_blocklist(self):
|
||||
extras = {
|
||||
"TELEGRAM_HOME_CHANNEL",
|
||||
"TELEGRAM_HOME_CHANNEL_NAME",
|
||||
"DISCORD_HOME_CHANNEL",
|
||||
"DISCORD_HOME_CHANNEL_NAME",
|
||||
"DISCORD_REQUIRE_MENTION",
|
||||
"DISCORD_FREE_RESPONSE_CHANNELS",
|
||||
"DISCORD_AUTO_THREAD",
|
||||
"SLACK_HOME_CHANNEL",
|
||||
"SLACK_HOME_CHANNEL_NAME",
|
||||
"SLACK_ALLOWED_USERS",
|
||||
"WHATSAPP_ENABLED",
|
||||
"WHATSAPP_MODE",
|
||||
"WHATSAPP_ALLOWED_USERS",
|
||||
"SIGNAL_HTTP_URL",
|
||||
"SIGNAL_ACCOUNT",
|
||||
"SIGNAL_ALLOWED_USERS",
|
||||
"SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"SIGNAL_HOME_CHANNEL",
|
||||
"SIGNAL_HOME_CHANNEL_NAME",
|
||||
"SIGNAL_IGNORE_STORIES",
|
||||
"HASS_TOKEN",
|
||||
"HASS_URL",
|
||||
"EMAIL_ADDRESS",
|
||||
"EMAIL_PASSWORD",
|
||||
"EMAIL_IMAP_HOST",
|
||||
"EMAIL_SMTP_HOST",
|
||||
"EMAIL_HOME_ADDRESS",
|
||||
"EMAIL_HOME_ADDRESS_NAME",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
"GH_TOKEN",
|
||||
"GITHUB_APP_ID",
|
||||
"GITHUB_APP_PRIVATE_KEY_PATH",
|
||||
"GITHUB_APP_INSTALLATION_ID",
|
||||
}
|
||||
assert extras.issubset(_HERMES_PROVIDER_ENV_BLOCKLIST)
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
"""Tests for the local persistent shell backend."""
|
||||
|
||||
import glob as glob_mod
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.local import LocalEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
|
||||
|
||||
class TestLocalConfig:
|
||||
def test_local_persistent_default_false(self, monkeypatch):
|
||||
monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is False
|
||||
|
||||
def test_local_persistent_true(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
def test_local_persistent_yes(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
|
||||
class TestMergeOutput:
|
||||
def test_stdout_only(self):
|
||||
assert PersistentShellMixin._merge_output("out", "") == "out"
|
||||
|
||||
def test_stderr_only(self):
|
||||
assert PersistentShellMixin._merge_output("", "err") == "err"
|
||||
|
||||
def test_both(self):
|
||||
assert PersistentShellMixin._merge_output("out", "err") == "out\nerr"
|
||||
|
||||
def test_empty(self):
|
||||
assert PersistentShellMixin._merge_output("", "") == ""
|
||||
|
||||
def test_strips_trailing_newlines(self):
|
||||
assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"
|
||||
|
||||
|
||||
class TestLocalOneShotRegression:
|
||||
def test_echo(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("echo hello")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello" in r["output"]
|
||||
env.cleanup()
|
||||
|
||||
def test_exit_code(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("exit 42")
|
||||
assert r["returncode"] == 42
|
||||
env.cleanup()
|
||||
|
||||
def test_state_does_not_persist(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
env.execute("export HERMES_ONESHOT_LOCAL=yes")
|
||||
r = env.execute("echo $HERMES_ONESHOT_LOCAL")
|
||||
assert r["output"].strip() == ""
|
||||
env.cleanup()
|
||||
|
||||
|
||||
class TestLocalPersistent:
|
||||
@pytest.fixture
|
||||
def env(self):
|
||||
e = LocalEnvironment(persistent=True)
|
||||
yield e
|
||||
e.cleanup()
|
||||
|
||||
def test_echo(self, env):
|
||||
r = env.execute("echo hello-persistent")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello-persistent" in r["output"]
|
||||
|
||||
def test_env_var_persists(self, env):
|
||||
env.execute("export HERMES_LOCAL_PERSIST_TEST=works")
|
||||
r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST")
|
||||
assert r["output"].strip() == "works"
|
||||
|
||||
def test_cwd_persists(self, env):
|
||||
env.execute("cd /tmp")
|
||||
r = env.execute("pwd")
|
||||
assert r["output"].strip() == "/tmp"
|
||||
|
||||
def test_exit_code(self, env):
|
||||
r = env.execute("(exit 42)")
|
||||
assert r["returncode"] == 42
|
||||
|
||||
def test_stderr(self, env):
|
||||
r = env.execute("echo oops >&2")
|
||||
assert r["returncode"] == 0
|
||||
assert "oops" in r["output"]
|
||||
|
||||
def test_multiline_output(self, env):
|
||||
r = env.execute("echo a; echo b; echo c")
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert lines == ["a", "b", "c"]
|
||||
|
||||
def test_timeout_then_recovery(self, env):
|
||||
r = env.execute("sleep 999", timeout=2)
|
||||
assert r["returncode"] in (124, 130)
|
||||
r = env.execute("echo alive")
|
||||
assert r["returncode"] == 0
|
||||
assert "alive" in r["output"]
|
||||
|
||||
def test_large_output(self, env):
|
||||
r = env.execute("seq 1 1000")
|
||||
assert r["returncode"] == 0
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert len(lines) == 1000
|
||||
assert lines[0] == "1"
|
||||
assert lines[-1] == "1000"
|
||||
|
||||
def test_shell_variable_persists(self, env):
|
||||
env.execute("MY_LOCAL_VAR=hello123")
|
||||
r = env.execute("echo $MY_LOCAL_VAR")
|
||||
assert r["output"].strip() == "hello123"
|
||||
|
||||
def test_cleanup_removes_temp_files(self, env):
|
||||
env.execute("echo warmup")
|
||||
prefix = env._temp_prefix
|
||||
assert len(glob_mod.glob(f"{prefix}-*")) > 0
|
||||
env.cleanup()
|
||||
remaining = glob_mod.glob(f"{prefix}-*")
|
||||
assert remaining == []
|
||||
|
||||
def test_state_does_not_leak_between_instances(self):
|
||||
env1 = LocalEnvironment(persistent=True)
|
||||
env2 = LocalEnvironment(persistent=True)
|
||||
try:
|
||||
env1.execute("export LEAK_TEST=from_env1")
|
||||
r = env2.execute("echo $LEAK_TEST")
|
||||
assert r["output"].strip() == ""
|
||||
finally:
|
||||
env1.cleanup()
|
||||
env2.cleanup()
|
||||
|
||||
def test_special_characters_in_command(self, env):
|
||||
r = env.execute("echo 'hello world'")
|
||||
assert r["output"].strip() == "hello world"
|
||||
|
||||
def test_pipe_command(self, env):
|
||||
r = env.execute("echo hello | tr 'h' 'H'")
|
||||
assert r["output"].strip() == "Hello"
|
||||
|
||||
def test_multiple_commands_semicolon(self, env):
|
||||
r = env.execute("X=42; echo $X")
|
||||
assert r["output"].strip() == "42"
|
||||
@@ -1,11 +1,13 @@
|
||||
"""Tests for tools/process_registry.py — ProcessRegistry query methods, pruning, checkpoint."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tools.environments.local import _HERMES_PROVIDER_ENV_FORCE_PREFIX
|
||||
from tools.process_registry import (
|
||||
ProcessRegistry,
|
||||
ProcessSession,
|
||||
@@ -213,6 +215,54 @@ class TestPruning:
|
||||
assert total <= MAX_PROCESSES
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Spawn env sanitization
|
||||
# =========================================================================
|
||||
|
||||
class TestSpawnEnvSanitization:
|
||||
def test_spawn_local_strips_blocked_vars_from_background_env(self, registry):
|
||||
captured = {}
|
||||
|
||||
def fake_popen(cmd, **kwargs):
|
||||
captured["env"] = kwargs["env"]
|
||||
proc = MagicMock()
|
||||
proc.pid = 4321
|
||||
proc.stdout = iter([])
|
||||
proc.stdin = MagicMock()
|
||||
proc.poll.return_value = None
|
||||
return proc
|
||||
|
||||
fake_thread = MagicMock()
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"PATH": "/usr/bin:/bin",
|
||||
"HOME": "/home/user",
|
||||
"USER": "tester",
|
||||
"TELEGRAM_BOT_TOKEN": "bot-secret",
|
||||
"FIRECRAWL_API_KEY": "fc-secret",
|
||||
}, clear=True), \
|
||||
patch("tools.process_registry._find_shell", return_value="/bin/bash"), \
|
||||
patch("subprocess.Popen", side_effect=fake_popen), \
|
||||
patch("threading.Thread", return_value=fake_thread), \
|
||||
patch.object(registry, "_write_checkpoint"):
|
||||
registry.spawn_local(
|
||||
"echo hello",
|
||||
cwd="/tmp",
|
||||
env_vars={
|
||||
"MY_CUSTOM_VAR": "keep-me",
|
||||
"TELEGRAM_BOT_TOKEN": "drop-me",
|
||||
f"{_HERMES_PROVIDER_ENV_FORCE_PREFIX}TELEGRAM_BOT_TOKEN": "forced-bot-token",
|
||||
},
|
||||
)
|
||||
|
||||
env = captured["env"]
|
||||
assert env["MY_CUSTOM_VAR"] == "keep-me"
|
||||
assert env["TELEGRAM_BOT_TOKEN"] == "forced-bot-token"
|
||||
assert "FIRECRAWL_API_KEY" not in env
|
||||
assert f"{_HERMES_PROVIDER_ENV_FORCE_PREFIX}TELEGRAM_BOT_TOKEN" not in env
|
||||
assert env["PYTHONUNBUFFERED"] == "1"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Checkpoint
|
||||
# =========================================================================
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from tools.skills_hub import ClawHubSource
|
||||
from tools.skills_hub import ClawHubSource, SkillMeta
|
||||
|
||||
|
||||
class _MockResponse:
|
||||
@@ -22,21 +22,31 @@ class TestClawHubSource(unittest.TestCase):
|
||||
|
||||
@patch("tools.skills_hub._write_index_cache")
|
||||
@patch("tools.skills_hub._read_index_cache", return_value=None)
|
||||
@patch.object(ClawHubSource, "_load_catalog_index", return_value=[])
|
||||
@patch("tools.skills_hub.httpx.get")
|
||||
def test_search_uses_new_endpoint_and_parses_items(self, mock_get, _mock_read_cache, _mock_write_cache):
|
||||
mock_get.return_value = _MockResponse(
|
||||
status_code=200,
|
||||
json_data={
|
||||
"items": [
|
||||
{
|
||||
"slug": "caldav-calendar",
|
||||
"displayName": "CalDAV Calendar",
|
||||
"summary": "Calendar integration",
|
||||
"tags": ["calendar", "productivity"],
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
def test_search_uses_listing_endpoint_as_fallback(
|
||||
self, mock_get, _mock_load_catalog, _mock_read_cache, _mock_write_cache
|
||||
):
|
||||
def side_effect(url, *args, **kwargs):
|
||||
if url.endswith("/skills"):
|
||||
return _MockResponse(
|
||||
status_code=200,
|
||||
json_data={
|
||||
"items": [
|
||||
{
|
||||
"slug": "caldav-calendar",
|
||||
"displayName": "CalDAV Calendar",
|
||||
"summary": "Calendar integration",
|
||||
"tags": ["calendar", "productivity"],
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
if url.endswith("/skills/caldav"):
|
||||
return _MockResponse(status_code=404, json_data={})
|
||||
return _MockResponse(status_code=404, json_data={})
|
||||
|
||||
mock_get.side_effect = side_effect
|
||||
|
||||
results = self.src.search("caldav", limit=5)
|
||||
|
||||
@@ -45,11 +55,112 @@ class TestClawHubSource(unittest.TestCase):
|
||||
self.assertEqual(results[0].name, "CalDAV Calendar")
|
||||
self.assertEqual(results[0].description, "Calendar integration")
|
||||
|
||||
mock_get.assert_called_once()
|
||||
args, kwargs = mock_get.call_args
|
||||
self.assertGreaterEqual(mock_get.call_count, 2)
|
||||
args, kwargs = mock_get.call_args_list[0]
|
||||
self.assertTrue(args[0].endswith("/skills"))
|
||||
self.assertEqual(kwargs["params"], {"search": "caldav", "limit": 5})
|
||||
|
||||
@patch("tools.skills_hub._write_index_cache")
|
||||
@patch("tools.skills_hub._read_index_cache", return_value=None)
|
||||
@patch.object(
|
||||
ClawHubSource,
|
||||
"_load_catalog_index",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("tools.skills_hub.httpx.get")
|
||||
def test_search_falls_back_to_exact_slug_when_search_results_are_irrelevant(
|
||||
self, mock_get, _mock_load_catalog, _mock_read_cache, _mock_write_cache
|
||||
):
|
||||
def side_effect(url, *args, **kwargs):
|
||||
if url.endswith("/skills"):
|
||||
return _MockResponse(
|
||||
status_code=200,
|
||||
json_data={
|
||||
"items": [
|
||||
{
|
||||
"slug": "apple-music-dj",
|
||||
"displayName": "Apple Music DJ",
|
||||
"summary": "Unrelated result",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
if url.endswith("/skills/self-improving-agent"):
|
||||
return _MockResponse(
|
||||
status_code=200,
|
||||
json_data={
|
||||
"skill": {
|
||||
"slug": "self-improving-agent",
|
||||
"displayName": "self-improving-agent",
|
||||
"summary": "Captures learnings and errors for continuous improvement.",
|
||||
"tags": {"latest": "3.0.2", "automation": "3.0.2"},
|
||||
},
|
||||
"latestVersion": {"version": "3.0.2"},
|
||||
},
|
||||
)
|
||||
return _MockResponse(status_code=404, json_data={})
|
||||
|
||||
mock_get.side_effect = side_effect
|
||||
|
||||
results = self.src.search("self-improving-agent", limit=5)
|
||||
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].identifier, "self-improving-agent")
|
||||
self.assertEqual(results[0].name, "self-improving-agent")
|
||||
self.assertIn("continuous improvement", results[0].description)
|
||||
|
||||
@patch("tools.skills_hub.httpx.get")
|
||||
def test_search_repairs_poisoned_cache_with_exact_slug_lookup(self, mock_get):
|
||||
mock_get.return_value = _MockResponse(
|
||||
status_code=200,
|
||||
json_data={
|
||||
"skill": {
|
||||
"slug": "self-improving-agent",
|
||||
"displayName": "self-improving-agent",
|
||||
"summary": "Captures learnings and errors for continuous improvement.",
|
||||
"tags": {"latest": "3.0.2", "automation": "3.0.2"},
|
||||
},
|
||||
"latestVersion": {"version": "3.0.2"},
|
||||
},
|
||||
)
|
||||
|
||||
poisoned = [
|
||||
SkillMeta(
|
||||
name="Apple Music DJ",
|
||||
description="Unrelated cached result",
|
||||
source="clawhub",
|
||||
identifier="apple-music-dj",
|
||||
trust_level="community",
|
||||
tags=[],
|
||||
)
|
||||
]
|
||||
results = self.src._finalize_search_results("self-improving-agent", poisoned, 5)
|
||||
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].identifier, "self-improving-agent")
|
||||
mock_get.assert_called_once()
|
||||
self.assertTrue(mock_get.call_args.args[0].endswith("/skills/self-improving-agent"))
|
||||
|
||||
@patch.object(
|
||||
ClawHubSource,
|
||||
"_exact_slug_meta",
|
||||
return_value=SkillMeta(
|
||||
name="self-improving-agent",
|
||||
description="Captures learnings and errors for continuous improvement.",
|
||||
source="clawhub",
|
||||
identifier="self-improving-agent",
|
||||
trust_level="community",
|
||||
tags=["automation"],
|
||||
),
|
||||
)
|
||||
def test_search_matches_space_separated_query_to_hyphenated_slug(
|
||||
self, _mock_exact_slug
|
||||
):
|
||||
results = self.src.search("self improving", limit=5)
|
||||
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].identifier, "self-improving-agent")
|
||||
|
||||
@patch("tools.skills_hub.httpx.get")
|
||||
def test_inspect_maps_display_name_and_summary(self, mock_get):
|
||||
mock_get.return_value = _MockResponse(
|
||||
@@ -69,6 +180,29 @@ class TestClawHubSource(unittest.TestCase):
|
||||
self.assertEqual(meta.description, "Calendar integration")
|
||||
self.assertEqual(meta.identifier, "caldav-calendar")
|
||||
|
||||
@patch("tools.skills_hub.httpx.get")
|
||||
def test_inspect_handles_nested_skill_payload(self, mock_get):
|
||||
mock_get.return_value = _MockResponse(
|
||||
status_code=200,
|
||||
json_data={
|
||||
"skill": {
|
||||
"slug": "self-improving-agent",
|
||||
"displayName": "self-improving-agent",
|
||||
"summary": "Captures learnings and errors for continuous improvement.",
|
||||
"tags": {"latest": "3.0.2", "automation": "3.0.2"},
|
||||
},
|
||||
"latestVersion": {"version": "3.0.2"},
|
||||
},
|
||||
)
|
||||
|
||||
meta = self.src.inspect("self-improving-agent")
|
||||
|
||||
self.assertIsNotNone(meta)
|
||||
self.assertEqual(meta.name, "self-improving-agent")
|
||||
self.assertIn("continuous improvement", meta.description)
|
||||
self.assertEqual(meta.identifier, "self-improving-agent")
|
||||
self.assertEqual(meta.tags, ["automation"])
|
||||
|
||||
@patch("tools.skills_hub.httpx.get")
|
||||
def test_fetch_resolves_latest_version_and_downloads_raw_files(self, mock_get):
|
||||
def side_effect(url, *args, **kwargs):
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Tests for the SSH remote execution environment backend."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
|
||||
_SSH_HOST = os.getenv("TERMINAL_SSH_HOST", "")
|
||||
_SSH_USER = os.getenv("TERMINAL_SSH_USER", "")
|
||||
_SSH_PORT = int(os.getenv("TERMINAL_SSH_PORT", "22"))
|
||||
_SSH_KEY = os.getenv("TERMINAL_SSH_KEY", "")
|
||||
|
||||
_has_ssh = bool(_SSH_HOST and _SSH_USER)
|
||||
|
||||
requires_ssh = pytest.mark.skipif(
|
||||
not _has_ssh,
|
||||
reason="TERMINAL_SSH_HOST / TERMINAL_SSH_USER not set",
|
||||
)
|
||||
|
||||
|
||||
def _run(command, task_id="ssh_test", **kwargs):
|
||||
from tools.terminal_tool import terminal_tool
|
||||
return json.loads(terminal_tool(command, task_id=task_id, **kwargs))
|
||||
|
||||
|
||||
def _cleanup(task_id="ssh_test"):
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
cleanup_vm(task_id)
|
||||
|
||||
|
||||
class TestBuildSSHCommand:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_connection(self, monkeypatch):
|
||||
monkeypatch.setattr("tools.environments.ssh.subprocess.run",
|
||||
lambda *a, **k: subprocess.CompletedProcess([], 0))
|
||||
monkeypatch.setattr("tools.environments.ssh.subprocess.Popen",
|
||||
lambda *a, **k: MagicMock(stdout=iter([]),
|
||||
stderr=iter([]),
|
||||
stdin=MagicMock()))
|
||||
monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None)
|
||||
|
||||
def test_base_flags(self):
|
||||
env = SSHEnvironment(host="h", user="u")
|
||||
cmd = " ".join(env._build_ssh_command())
|
||||
for flag in ("ControlMaster=auto", "ControlPersist=300",
|
||||
"BatchMode=yes", "StrictHostKeyChecking=accept-new"):
|
||||
assert flag in cmd
|
||||
|
||||
def test_custom_port(self):
|
||||
env = SSHEnvironment(host="h", user="u", port=2222)
|
||||
cmd = env._build_ssh_command()
|
||||
assert "-p" in cmd and "2222" in cmd
|
||||
|
||||
def test_key_path(self):
|
||||
env = SSHEnvironment(host="h", user="u", key_path="/k")
|
||||
cmd = env._build_ssh_command()
|
||||
assert "-i" in cmd and "/k" in cmd
|
||||
|
||||
def test_user_host_suffix(self):
|
||||
env = SSHEnvironment(host="h", user="u")
|
||||
assert env._build_ssh_command()[-1] == "u@h"
|
||||
|
||||
|
||||
class TestTerminalToolConfig:
|
||||
def test_ssh_persistent_default_false(self, monkeypatch):
|
||||
monkeypatch.delenv("TERMINAL_SSH_PERSISTENT", raising=False)
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["ssh_persistent"] is False
|
||||
|
||||
def test_ssh_persistent_true(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["ssh_persistent"] is True
|
||||
|
||||
|
||||
def _setup_ssh_env(monkeypatch, persistent: bool):
|
||||
monkeypatch.setenv("TERMINAL_ENV", "ssh")
|
||||
monkeypatch.setenv("TERMINAL_SSH_HOST", _SSH_HOST)
|
||||
monkeypatch.setenv("TERMINAL_SSH_USER", _SSH_USER)
|
||||
monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true" if persistent else "false")
|
||||
if _SSH_PORT != 22:
|
||||
monkeypatch.setenv("TERMINAL_SSH_PORT", str(_SSH_PORT))
|
||||
if _SSH_KEY:
|
||||
monkeypatch.setenv("TERMINAL_SSH_KEY", _SSH_KEY)
|
||||
|
||||
|
||||
@requires_ssh
|
||||
class TestOneShotSSH:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup(self, monkeypatch):
|
||||
_setup_ssh_env(monkeypatch, persistent=False)
|
||||
yield
|
||||
_cleanup()
|
||||
|
||||
def test_echo(self):
|
||||
r = _run("echo hello")
|
||||
assert r["exit_code"] == 0
|
||||
assert "hello" in r["output"]
|
||||
|
||||
def test_exit_code(self):
|
||||
r = _run("exit 42")
|
||||
assert r["exit_code"] == 42
|
||||
|
||||
def test_state_does_not_persist(self):
|
||||
_run("export HERMES_ONESHOT_TEST=yes")
|
||||
r = _run("echo $HERMES_ONESHOT_TEST")
|
||||
assert r["output"].strip() == ""
|
||||
|
||||
|
||||
@requires_ssh
|
||||
class TestPersistentSSH:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup(self, monkeypatch):
|
||||
_setup_ssh_env(monkeypatch, persistent=True)
|
||||
yield
|
||||
_cleanup()
|
||||
|
||||
def test_echo(self):
|
||||
r = _run("echo hello-persistent")
|
||||
assert r["exit_code"] == 0
|
||||
assert "hello-persistent" in r["output"]
|
||||
|
||||
def test_env_var_persists(self):
|
||||
_run("export HERMES_PERSIST_TEST=works")
|
||||
r = _run("echo $HERMES_PERSIST_TEST")
|
||||
assert r["output"].strip() == "works"
|
||||
|
||||
def test_cwd_persists(self):
|
||||
_run("cd /tmp")
|
||||
r = _run("pwd")
|
||||
assert r["output"].strip() == "/tmp"
|
||||
|
||||
def test_exit_code(self):
|
||||
r = _run("(exit 42)")
|
||||
assert r["exit_code"] == 42
|
||||
|
||||
def test_stderr(self):
|
||||
r = _run("echo oops >&2")
|
||||
assert r["exit_code"] == 0
|
||||
assert "oops" in r["output"]
|
||||
|
||||
def test_multiline_output(self):
|
||||
r = _run("echo a; echo b; echo c")
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert lines == ["a", "b", "c"]
|
||||
|
||||
def test_timeout_then_recovery(self):
|
||||
r = _run("sleep 999", timeout=2)
|
||||
assert r["exit_code"] == 124
|
||||
r = _run("echo alive")
|
||||
assert r["exit_code"] == 0
|
||||
assert "alive" in r["output"]
|
||||
|
||||
def test_large_output(self):
|
||||
r = _run("seq 1 1000")
|
||||
assert r["exit_code"] == 0
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert len(lines) == 1000
|
||||
assert lines[0] == "1"
|
||||
assert lines[-1] == "1000"
|
||||
@@ -72,6 +72,7 @@ def _origin_from_env() -> Optional[Dict[str, str]]:
|
||||
"platform": origin_platform,
|
||||
"chat_id": origin_chat_id,
|
||||
"chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"),
|
||||
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
+224
-119
@@ -1,5 +1,6 @@
|
||||
"""Local execution environment with interrupt support and non-blocking I/O."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
@@ -11,6 +12,8 @@ import time
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
# Unique marker to isolate real command output from shell init/exit noise.
|
||||
# printf (no trailing newline) keeps the boundaries clean for splitting.
|
||||
@@ -27,11 +30,12 @@ _HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_"
|
||||
|
||||
|
||||
def _build_provider_env_blocklist() -> frozenset:
|
||||
"""Derive the blocklist from the provider registry + known extras.
|
||||
"""Derive the blocklist from provider, tool, and gateway config.
|
||||
|
||||
Automatically picks up api_key_env_vars and base_url_env_var from
|
||||
every registered provider, so adding a new provider to auth.py is
|
||||
enough — no manual list to keep in sync.
|
||||
every registered provider, plus tool/messaging env vars from the
|
||||
optional config registry, so new Hermes-managed secrets are blocked
|
||||
in subprocesses without having to maintain multiple static lists.
|
||||
"""
|
||||
blocked: set[str] = set()
|
||||
|
||||
@@ -44,7 +48,18 @@ def _build_provider_env_blocklist() -> frozenset:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Vars not in the registry but still Hermes-internal / conflict-prone
|
||||
try:
|
||||
from hermes_cli.config import OPTIONAL_ENV_VARS
|
||||
for name, metadata in OPTIONAL_ENV_VARS.items():
|
||||
category = metadata.get("category")
|
||||
if category in {"tool", "messaging"}:
|
||||
blocked.add(name)
|
||||
elif category == "setting" and metadata.get("password"):
|
||||
blocked.add(name)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Vars not covered above but still Hermes-internal / conflict-prone.
|
||||
blocked.update({
|
||||
"OPENAI_BASE_URL",
|
||||
"OPENAI_API_KEY",
|
||||
@@ -67,6 +82,41 @@ def _build_provider_env_blocklist() -> frozenset:
|
||||
"FIREWORKS_API_KEY", # Fireworks AI
|
||||
"XAI_API_KEY", # xAI (Grok)
|
||||
"HELICONE_API_KEY", # LLM Observability proxy
|
||||
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
|
||||
"TELEGRAM_HOME_CHANNEL",
|
||||
"TELEGRAM_HOME_CHANNEL_NAME",
|
||||
"DISCORD_HOME_CHANNEL",
|
||||
"DISCORD_HOME_CHANNEL_NAME",
|
||||
"DISCORD_REQUIRE_MENTION",
|
||||
"DISCORD_FREE_RESPONSE_CHANNELS",
|
||||
"DISCORD_AUTO_THREAD",
|
||||
"SLACK_HOME_CHANNEL",
|
||||
"SLACK_HOME_CHANNEL_NAME",
|
||||
"SLACK_ALLOWED_USERS",
|
||||
"WHATSAPP_ENABLED",
|
||||
"WHATSAPP_MODE",
|
||||
"WHATSAPP_ALLOWED_USERS",
|
||||
"SIGNAL_HTTP_URL",
|
||||
"SIGNAL_ACCOUNT",
|
||||
"SIGNAL_ALLOWED_USERS",
|
||||
"SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"SIGNAL_HOME_CHANNEL",
|
||||
"SIGNAL_HOME_CHANNEL_NAME",
|
||||
"SIGNAL_IGNORE_STORIES",
|
||||
"HASS_TOKEN",
|
||||
"HASS_URL",
|
||||
"EMAIL_ADDRESS",
|
||||
"EMAIL_PASSWORD",
|
||||
"EMAIL_IMAP_HOST",
|
||||
"EMAIL_SMTP_HOST",
|
||||
"EMAIL_HOME_ADDRESS",
|
||||
"EMAIL_HOME_ADDRESS_NAME",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
# Skills Hub / GitHub app auth paths and aliases.
|
||||
"GH_TOKEN",
|
||||
"GITHUB_APP_ID",
|
||||
"GITHUB_APP_PRIVATE_KEY_PATH",
|
||||
"GITHUB_APP_INSTALLATION_ID",
|
||||
})
|
||||
return frozenset(blocked)
|
||||
|
||||
@@ -74,6 +124,30 @@ def _build_provider_env_blocklist() -> frozenset:
|
||||
_HERMES_PROVIDER_ENV_BLOCKLIST = _build_provider_env_blocklist()
|
||||
|
||||
|
||||
def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict:
|
||||
"""Filter Hermes-managed secrets from a subprocess environment.
|
||||
|
||||
`_HERMES_FORCE_<VAR>` entries in ``extra_env`` opt a blocked variable back in
|
||||
intentionally for callers that truly need it.
|
||||
"""
|
||||
sanitized: dict[str, str] = {}
|
||||
|
||||
for key, value in (base_env or {}).items():
|
||||
if key.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
|
||||
continue
|
||||
if key not in _HERMES_PROVIDER_ENV_BLOCKLIST:
|
||||
sanitized[key] = value
|
||||
|
||||
for key, value in (extra_env or {}).items():
|
||||
if key.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
|
||||
real_key = key[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
|
||||
sanitized[real_key] = value
|
||||
elif key not in _HERMES_PROVIDER_ENV_BLOCKLIST:
|
||||
sanitized[key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def _find_bash() -> str:
|
||||
"""Find bash for command execution.
|
||||
|
||||
@@ -173,6 +247,25 @@ def _clean_shell_noise(output: str) -> str:
|
||||
return result
|
||||
|
||||
|
||||
_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
|
||||
|
||||
def _make_run_env(env: dict) -> dict:
|
||||
"""Build a run environment with a sane PATH and provider-var stripping."""
|
||||
merged = dict(os.environ | env)
|
||||
run_env = {}
|
||||
for k, v in merged.items():
|
||||
if k.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
|
||||
real_key = k[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
|
||||
run_env[real_key] = v
|
||||
elif k not in _HERMES_PROVIDER_ENV_BLOCKLIST:
|
||||
run_env[k] = v
|
||||
existing_path = run_env.get("PATH", "")
|
||||
if "/usr/bin" not in existing_path.split(":"):
|
||||
run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
|
||||
return run_env
|
||||
|
||||
|
||||
def _extract_fenced_output(raw: str) -> str:
|
||||
"""Extract real command output from between fence markers.
|
||||
|
||||
@@ -197,7 +290,7 @@ def _extract_fenced_output(raw: str) -> str:
|
||||
return raw[start:last]
|
||||
|
||||
|
||||
class LocalEnvironment(BaseEnvironment):
|
||||
class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
"""Run commands directly on the host machine.
|
||||
|
||||
Features:
|
||||
@@ -206,24 +299,66 @@ class LocalEnvironment(BaseEnvironment):
|
||||
- stdin_data support for piping content (bypasses ARG_MAX limits)
|
||||
- sudo -S transform via SUDO_PASSWORD env var
|
||||
- Uses interactive login shell so full user env is available
|
||||
- Optional persistent shell mode (cwd/env vars survive across calls)
|
||||
"""
|
||||
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None,
|
||||
persistent: bool = False):
|
||||
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
||||
self.persistent = persistent
|
||||
if self.persistent:
|
||||
self._init_persistent_shell()
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-local-{self._session_id}"
|
||||
|
||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
||||
user_shell = _find_bash()
|
||||
run_env = _make_run_env(self.env)
|
||||
return subprocess.Popen(
|
||||
[user_shell, "-l"],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
env=run_env,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
||||
results = []
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
with open(path) as f:
|
||||
results.append(f.read())
|
||||
else:
|
||||
results.append("")
|
||||
return results
|
||||
|
||||
def _kill_shell_children(self):
|
||||
if self._shell_pid is None:
|
||||
return
|
||||
try:
|
||||
subprocess.run(
|
||||
["pkill", "-P", str(self._shell_pid)],
|
||||
capture_output=True, timeout=5,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
def _cleanup_temp_files(self):
|
||||
for f in glob.glob(f"{self._temp_prefix}-*"):
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
|
||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd or os.getcwd()
|
||||
effective_timeout = timeout or self.timeout
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
# Merge the sudo password (if any) with caller-supplied stdin_data.
|
||||
# sudo -S reads exactly one line (the password) then passes the rest
|
||||
# of stdin to the child, so prepending is safe even when stdin_data
|
||||
# is also present.
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
@@ -231,117 +366,87 @@ class LocalEnvironment(BaseEnvironment):
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
try:
|
||||
# The fence wrapper uses bash syntax (semicolons, $?, printf).
|
||||
# Always use bash for the wrapper — NOT $SHELL which could be
|
||||
# fish, zsh, or another shell with incompatible syntax.
|
||||
# The -lic flags source rc files so tools like nvm/pyenv work.
|
||||
user_shell = _find_bash()
|
||||
# Wrap with output fences so we can later extract the real
|
||||
# command output and discard shell init/exit noise.
|
||||
fenced_cmd = (
|
||||
f"printf '{_OUTPUT_FENCE}';"
|
||||
f" {exec_command};"
|
||||
f" __hermes_rc=$?;"
|
||||
f" printf '{_OUTPUT_FENCE}';"
|
||||
f" exit $__hermes_rc"
|
||||
)
|
||||
# Ensure PATH always includes standard dirs — systemd services
|
||||
# and some terminal multiplexers inherit a minimal PATH.
|
||||
_SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
# Strip Hermes-internal provider vars so external CLIs
|
||||
# (e.g. codex) are not silently misrouted. Callers that
|
||||
# truly need a blocked var can opt in by prefixing the key
|
||||
# with _HERMES_FORCE_ in self.env (e.g. _HERMES_FORCE_OPENAI_API_KEY).
|
||||
merged = dict(os.environ | self.env)
|
||||
run_env = {}
|
||||
for k, v in merged.items():
|
||||
if k.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
|
||||
real_key = k[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
|
||||
run_env[real_key] = v
|
||||
elif k not in _HERMES_PROVIDER_ENV_BLOCKLIST:
|
||||
run_env[k] = v
|
||||
existing_path = run_env.get("PATH", "")
|
||||
if "/usr/bin" not in existing_path.split(":"):
|
||||
run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
|
||||
user_shell = _find_bash()
|
||||
fenced_cmd = (
|
||||
f"printf '{_OUTPUT_FENCE}';"
|
||||
f" {exec_command};"
|
||||
f" __hermes_rc=$?;"
|
||||
f" printf '{_OUTPUT_FENCE}';"
|
||||
f" exit $__hermes_rc"
|
||||
)
|
||||
run_env = _make_run_env(self.env)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
[user_shell, "-lic", fenced_cmd],
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
env=run_env,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
proc = subprocess.Popen(
|
||||
[user_shell, "-lic", fenced_cmd],
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
env=run_env,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
if effective_stdin is not None:
|
||||
def _write_stdin():
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
||||
|
||||
_output_chunks: list[str] = []
|
||||
|
||||
def _drain_stdout():
|
||||
if effective_stdin is not None:
|
||||
def _write_stdin():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except ValueError:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
proc.stdout.close()
|
||||
except Exception:
|
||||
pass
|
||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
||||
|
||||
reader = threading.Thread(target=_drain_stdout, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
_output_chunks: list[str] = []
|
||||
|
||||
while proc.poll() is None:
|
||||
if _interrupt_event.is_set():
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
try:
|
||||
proc.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
def _drain_stdout():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except ValueError:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
proc.stdout.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader.join(timeout=5)
|
||||
output = _extract_fenced_output("".join(_output_chunks))
|
||||
return {"output": output, "returncode": proc.returncode}
|
||||
reader = threading.Thread(target=_drain_stdout, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
except Exception as e:
|
||||
return {"output": f"Execution error: {str(e)}", "returncode": 1}
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
try:
|
||||
proc.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
reader.join(timeout=5)
|
||||
output = _extract_fenced_output("".join(_output_chunks))
|
||||
return {"output": output, "returncode": proc.returncode}
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells."""
|
||||
|
||||
import logging
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PersistentShellMixin:
|
||||
"""Mixin that adds persistent shell capability to any BaseEnvironment.
|
||||
|
||||
Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``,
|
||||
``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``.
|
||||
"""
|
||||
|
||||
persistent: bool
|
||||
|
||||
@abstractmethod
|
||||
def _spawn_shell_process(self) -> subprocess.Popen: ...
|
||||
|
||||
@abstractmethod
|
||||
def _read_temp_files(self, *paths: str) -> list[str]: ...
|
||||
|
||||
@abstractmethod
|
||||
def _kill_shell_children(self): ...
|
||||
|
||||
@abstractmethod
|
||||
def _execute_oneshot(self, command: str, cwd: str, *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict: ...
|
||||
|
||||
@abstractmethod
|
||||
def _cleanup_temp_files(self): ...
|
||||
|
||||
_session_id: str = ""
|
||||
_poll_interval: float = 0.01
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-persistent-{self._session_id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_persistent_shell(self):
|
||||
self._shell_lock = threading.Lock()
|
||||
self._shell_proc: subprocess.Popen | None = None
|
||||
self._shell_alive: bool = False
|
||||
self._shell_pid: int | None = None
|
||||
|
||||
self._session_id = uuid.uuid4().hex[:12]
|
||||
p = self._temp_prefix
|
||||
self._pshell_stdout = f"{p}-stdout"
|
||||
self._pshell_stderr = f"{p}-stderr"
|
||||
self._pshell_status = f"{p}-status"
|
||||
self._pshell_cwd = f"{p}-cwd"
|
||||
self._pshell_pid_file = f"{p}-pid"
|
||||
|
||||
self._shell_proc = self._spawn_shell_process()
|
||||
self._shell_alive = True
|
||||
|
||||
self._drain_thread = threading.Thread(
|
||||
target=self._drain_shell_output, daemon=True,
|
||||
)
|
||||
self._drain_thread.start()
|
||||
|
||||
init_script = (
|
||||
f"export TERM=${{TERM:-dumb}}\n"
|
||||
f"touch {self._pshell_stdout} {self._pshell_stderr} "
|
||||
f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n"
|
||||
f"echo $$ > {self._pshell_pid_file}\n"
|
||||
f"pwd > {self._pshell_cwd}\n"
|
||||
)
|
||||
self._send_to_shell(init_script)
|
||||
|
||||
deadline = time.monotonic() + 3.0
|
||||
while time.monotonic() < deadline:
|
||||
pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip()
|
||||
if pid_str.isdigit():
|
||||
self._shell_pid = int(pid_str)
|
||||
break
|
||||
time.sleep(0.05)
|
||||
else:
|
||||
logger.warning("Could not read persistent shell PID")
|
||||
self._shell_pid = None
|
||||
|
||||
if self._shell_pid:
|
||||
logger.info(
|
||||
"Persistent shell started (session=%s, pid=%d)",
|
||||
self._session_id, self._shell_pid,
|
||||
)
|
||||
|
||||
reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip()
|
||||
if reported_cwd:
|
||||
self.cwd = reported_cwd
|
||||
|
||||
def _cleanup_persistent_shell(self):
|
||||
if self._shell_proc is None:
|
||||
return
|
||||
|
||||
if self._session_id:
|
||||
self._cleanup_temp_files()
|
||||
|
||||
try:
|
||||
self._shell_proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._shell_proc.terminate()
|
||||
self._shell_proc.wait(timeout=3)
|
||||
except subprocess.TimeoutExpired:
|
||||
self._shell_proc.kill()
|
||||
|
||||
self._shell_alive = False
|
||||
self._shell_proc = None
|
||||
|
||||
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
|
||||
self._drain_thread.join(timeout=1.0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# execute() / cleanup() — shared dispatcher, subclasses inherit
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if self.persistent:
|
||||
return self._execute_persistent(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
if self.persistent:
|
||||
self._cleanup_persistent_shell()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shell I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _drain_shell_output(self):
|
||||
try:
|
||||
for _ in self._shell_proc.stdout:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
self._shell_alive = False
|
||||
|
||||
def _send_to_shell(self, text: str):
|
||||
if not self._shell_alive or self._shell_proc is None:
|
||||
return
|
||||
try:
|
||||
self._shell_proc.stdin.write(text)
|
||||
self._shell_proc.stdin.flush()
|
||||
except (BrokenPipeError, OSError):
|
||||
self._shell_alive = False
|
||||
|
||||
def _read_persistent_output(self) -> tuple[str, int, str]:
|
||||
stdout, stderr, status_raw, cwd = self._read_temp_files(
|
||||
self._pshell_stdout, self._pshell_stderr,
|
||||
self._pshell_status, self._pshell_cwd,
|
||||
)
|
||||
output = self._merge_output(stdout, stderr)
|
||||
status = status_raw.strip()
|
||||
if ":" in status:
|
||||
status = status.split(":", 1)[1]
|
||||
try:
|
||||
exit_code = int(status.strip())
|
||||
except ValueError:
|
||||
exit_code = 1
|
||||
return output, exit_code, cwd.strip()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _execute_persistent(self, command: str, cwd: str, *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if not self._shell_alive:
|
||||
logger.info("Persistent shell died, restarting...")
|
||||
self._init_persistent_shell()
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
effective_timeout = timeout or self.timeout
|
||||
if stdin_data or sudo_stdin:
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
with self._shell_lock:
|
||||
return self._execute_persistent_locked(
|
||||
exec_command, cwd, effective_timeout,
|
||||
)
|
||||
|
||||
def _execute_persistent_locked(self, command: str, cwd: str,
|
||||
timeout: int) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
cmd_id = uuid.uuid4().hex[:8]
|
||||
truncate = (
|
||||
f": > {self._pshell_stdout}\n"
|
||||
f": > {self._pshell_stderr}\n"
|
||||
f": > {self._pshell_status}\n"
|
||||
)
|
||||
self._send_to_shell(truncate)
|
||||
escaped = command.replace("'", "'\\''")
|
||||
|
||||
ipc_script = (
|
||||
f"cd {shlex.quote(work_dir)}\n"
|
||||
f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n"
|
||||
f"__EC=$?\n"
|
||||
f"pwd > {self._pshell_cwd}\n"
|
||||
f"echo {cmd_id}:$__EC > {self._pshell_status}\n"
|
||||
)
|
||||
self._send_to_shell(ipc_script)
|
||||
deadline = time.monotonic() + timeout
|
||||
poll_interval = self._poll_interval
|
||||
|
||||
while True:
|
||||
if is_interrupted():
|
||||
self._kill_shell_children()
|
||||
output, _, _ = self._read_persistent_output()
|
||||
return {
|
||||
"output": output + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
|
||||
if time.monotonic() > deadline:
|
||||
self._kill_shell_children()
|
||||
output, _, _ = self._read_persistent_output()
|
||||
if output:
|
||||
return {
|
||||
"output": output + f"\n[Command timed out after {timeout}s]",
|
||||
"returncode": 124,
|
||||
}
|
||||
return self._timeout_result(timeout)
|
||||
|
||||
if not self._shell_alive:
|
||||
return {
|
||||
"output": "Persistent shell died during execution",
|
||||
"returncode": 1,
|
||||
}
|
||||
|
||||
status_content = self._read_temp_files(self._pshell_status)[0].strip()
|
||||
if status_content.startswith(cmd_id + ":"):
|
||||
break
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
output, exit_code, new_cwd = self._read_persistent_output()
|
||||
if new_cwd:
|
||||
self.cwd = new_cwd
|
||||
return {"output": output, "returncode": exit_code}
|
||||
|
||||
@staticmethod
|
||||
def _merge_output(stdout: str, stderr: str) -> str:
|
||||
parts = []
|
||||
if stdout.strip():
|
||||
parts.append(stdout.rstrip("\n"))
|
||||
if stderr.strip():
|
||||
parts.append(stderr.rstrip("\n"))
|
||||
return "\n".join(parts)
|
||||
+125
-58
@@ -8,12 +8,13 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHEnvironment(BaseEnvironment):
|
||||
class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
"""Run commands on a remote machine over SSH.
|
||||
|
||||
Uses SSH ControlMaster for connection persistence so subsequent
|
||||
@@ -22,22 +23,33 @@ class SSHEnvironment(BaseEnvironment):
|
||||
|
||||
Foreground commands are interruptible: the local ssh process is killed
|
||||
and a remote kill is attempted over the ControlMaster socket.
|
||||
|
||||
When ``persistent=True``, a single long-lived bash shell is kept alive
|
||||
over SSH and state (cwd, env vars, shell variables) persists across
|
||||
``execute()`` calls. Output capture uses file-based IPC on the remote
|
||||
host (stdout/stderr/exit-code written to temp files, polled via fast
|
||||
ControlMaster one-shot reads).
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, user: str, cwd: str = "~",
|
||||
timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||
timeout: int = 60, port: int = 22, key_path: str = "",
|
||||
persistent: bool = False):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
self.host = host
|
||||
self.user = user
|
||||
self.port = port
|
||||
self.key_path = key_path
|
||||
self.persistent = persistent
|
||||
|
||||
self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
|
||||
self.control_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock"
|
||||
self._establish_connection()
|
||||
|
||||
def _build_ssh_command(self, extra_args: list = None) -> list:
|
||||
if self.persistent:
|
||||
self._init_persistent_shell()
|
||||
|
||||
def _build_ssh_command(self, extra_args: list | None = None) -> list:
|
||||
cmd = ["ssh"]
|
||||
cmd.extend(["-o", f"ControlPath={self.control_socket}"])
|
||||
cmd.extend(["-o", "ControlMaster=auto"])
|
||||
@@ -65,15 +77,76 @@ class SSHEnvironment(BaseEnvironment):
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
_poll_interval: float = 0.15
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-ssh-{self._session_id}"
|
||||
|
||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append("bash -l")
|
||||
return subprocess.Popen(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
||||
if len(paths) == 1:
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"cat {paths[0]} 2>/dev/null")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
return [result.stdout]
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return [""]
|
||||
|
||||
delim = f"__HERMES_SEP_{self._session_id}__"
|
||||
script = "; ".join(
|
||||
f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths
|
||||
)
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(script)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
parts = result.stdout.split(delim + "\n")
|
||||
return [parts[i] if i < len(parts) else "" for i in range(len(paths))]
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return [""] * len(paths)
|
||||
|
||||
def _kill_shell_children(self):
|
||||
if self._shell_pid is None:
|
||||
return
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true")
|
||||
try:
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
|
||||
def _cleanup_temp_files(self):
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"rm -f {self._temp_prefix}-*")
|
||||
try:
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
|
||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
wrapped = f'cd {work_dir} && {exec_command}'
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
# Merge sudo password (if any) with caller-supplied stdin_data.
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
@@ -82,66 +155,60 @@ class SSHEnvironment(BaseEnvironment):
|
||||
effective_stdin = stdin_data
|
||||
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.extend(["bash", "-c", wrapped])
|
||||
cmd.append(wrapped)
|
||||
|
||||
try:
|
||||
kwargs = self._build_run_kwargs(timeout, effective_stdin)
|
||||
# Remove timeout from kwargs -- we handle it in the poll loop
|
||||
kwargs.pop("timeout", None)
|
||||
kwargs = self._build_run_kwargs(timeout, effective_stdin)
|
||||
kwargs.pop("timeout", None)
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
||||
_output_chunks = []
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if effective_stdin:
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
|
||||
except Exception as e:
|
||||
return {"output": f"SSH execution error: {str(e)}", "returncode": 1}
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
|
||||
@@ -101,12 +101,31 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
||||
"container_persistent": config.get("container_persistent", True),
|
||||
"docker_volumes": config.get("docker_volumes", []),
|
||||
}
|
||||
|
||||
ssh_config = None
|
||||
if env_type == "ssh":
|
||||
ssh_config = {
|
||||
"host": config.get("ssh_host", ""),
|
||||
"user": config.get("ssh_user", ""),
|
||||
"port": config.get("ssh_port", 22),
|
||||
"key": config.get("ssh_key", ""),
|
||||
"persistent": config.get("ssh_persistent", False),
|
||||
}
|
||||
|
||||
local_config = None
|
||||
if env_type == "local":
|
||||
local_config = {
|
||||
"persistent": config.get("local_persistent", False),
|
||||
}
|
||||
|
||||
terminal_env = _create_environment(
|
||||
env_type=env_type,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
timeout=config["timeout"],
|
||||
ssh_config=ssh_config,
|
||||
container_config=container_config,
|
||||
local_config=local_config,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ import time
|
||||
import uuid
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from tools.environments.local import _find_shell, _HERMES_PROVIDER_ENV_BLOCKLIST
|
||||
from tools.environments.local import _find_shell, _sanitize_subprocess_env
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -155,9 +155,7 @@ class ProcessRegistry:
|
||||
else:
|
||||
from ptyprocess import PtyProcess as _PtyProcessCls
|
||||
user_shell = _find_shell()
|
||||
pty_env = {k: v for k, v in os.environ.items()
|
||||
if k not in _HERMES_PROVIDER_ENV_BLOCKLIST}
|
||||
pty_env.update(env_vars or {})
|
||||
pty_env = _sanitize_subprocess_env(os.environ, env_vars)
|
||||
pty_env["PYTHONUNBUFFERED"] = "1"
|
||||
pty_proc = _PtyProcessCls.spawn(
|
||||
[user_shell, "-lic", command],
|
||||
@@ -198,9 +196,7 @@ class ProcessRegistry:
|
||||
# Force unbuffered output for Python scripts so progress is visible
|
||||
# during background execution (libraries like tqdm/datasets buffer when
|
||||
# stdout is a pipe, hiding output from process(action="poll")).
|
||||
bg_env = {k: v for k, v in os.environ.items()
|
||||
if k not in _HERMES_PROVIDER_ENV_BLOCKLIST}
|
||||
bg_env.update(env_vars or {})
|
||||
bg_env = _sanitize_subprocess_env(os.environ, env_vars)
|
||||
bg_env["PYTHONUNBUFFERED"] = "1"
|
||||
proc = subprocess.Popen(
|
||||
[user_shell, "-lic", command],
|
||||
|
||||
+241
-13
@@ -1156,11 +1156,176 @@ class ClawHubSource(SkillSource):
|
||||
def trust_level_for(self, identifier: str) -> str:
|
||||
return "community"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tags(tags: Any) -> List[str]:
|
||||
if isinstance(tags, list):
|
||||
return [str(t) for t in tags]
|
||||
if isinstance(tags, dict):
|
||||
return [str(k) for k in tags.keys() if str(k) != "latest"]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _coerce_skill_payload(data: Any) -> Optional[Dict[str, Any]]:
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
nested = data.get("skill")
|
||||
if isinstance(nested, dict):
|
||||
merged = dict(nested)
|
||||
latest_version = data.get("latestVersion")
|
||||
if latest_version is not None and "latestVersion" not in merged:
|
||||
merged["latestVersion"] = latest_version
|
||||
return merged
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _query_terms(query: str) -> List[str]:
|
||||
return [term for term in re.split(r"[^a-z0-9]+", query.lower()) if term]
|
||||
|
||||
@classmethod
|
||||
def _search_score(cls, query: str, meta: SkillMeta) -> int:
|
||||
query_norm = query.strip().lower()
|
||||
if not query_norm:
|
||||
return 1
|
||||
|
||||
identifier = (meta.identifier or "").lower()
|
||||
name = (meta.name or "").lower()
|
||||
description = (meta.description or "").lower()
|
||||
normalized_identifier = " ".join(cls._query_terms(identifier))
|
||||
normalized_name = " ".join(cls._query_terms(name))
|
||||
query_terms = cls._query_terms(query_norm)
|
||||
identifier_terms = cls._query_terms(identifier)
|
||||
name_terms = cls._query_terms(name)
|
||||
score = 0
|
||||
|
||||
if query_norm == identifier:
|
||||
score += 140
|
||||
if query_norm == name:
|
||||
score += 130
|
||||
if normalized_identifier == query_norm:
|
||||
score += 125
|
||||
if normalized_name == query_norm:
|
||||
score += 120
|
||||
if normalized_identifier.startswith(query_norm):
|
||||
score += 95
|
||||
if normalized_name.startswith(query_norm):
|
||||
score += 90
|
||||
if query_terms and identifier_terms[: len(query_terms)] == query_terms:
|
||||
score += 70
|
||||
if query_terms and name_terms[: len(query_terms)] == query_terms:
|
||||
score += 65
|
||||
if query_norm in identifier:
|
||||
score += 40
|
||||
if query_norm in name:
|
||||
score += 35
|
||||
if query_norm in description:
|
||||
score += 10
|
||||
|
||||
for term in query_terms:
|
||||
if term in identifier_terms:
|
||||
score += 15
|
||||
if term in name_terms:
|
||||
score += 12
|
||||
if term in description:
|
||||
score += 3
|
||||
|
||||
return score
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_results(results: List[SkillMeta]) -> List[SkillMeta]:
|
||||
seen: set[str] = set()
|
||||
deduped: List[SkillMeta] = []
|
||||
for result in results:
|
||||
key = (result.identifier or result.name).lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(result)
|
||||
return deduped
|
||||
|
||||
def _exact_slug_meta(self, query: str) -> Optional[SkillMeta]:
|
||||
slug = query.strip().split("/")[-1]
|
||||
query_terms = self._query_terms(query)
|
||||
candidates: List[str] = []
|
||||
|
||||
if slug and re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9._-]*", slug):
|
||||
candidates.append(slug)
|
||||
|
||||
if query_terms:
|
||||
base_slug = "-".join(query_terms)
|
||||
if len(query_terms) >= 2:
|
||||
candidates.extend([
|
||||
f"{base_slug}-agent",
|
||||
f"{base_slug}-skill",
|
||||
f"{base_slug}-tool",
|
||||
f"{base_slug}-assistant",
|
||||
f"{base_slug}-playbook",
|
||||
base_slug,
|
||||
])
|
||||
else:
|
||||
candidates.append(base_slug)
|
||||
|
||||
seen: set[str] = set()
|
||||
for candidate in candidates:
|
||||
if candidate in seen:
|
||||
continue
|
||||
seen.add(candidate)
|
||||
meta = self.inspect(candidate)
|
||||
if meta:
|
||||
return meta
|
||||
|
||||
return None
|
||||
|
||||
def _finalize_search_results(self, query: str, results: List[SkillMeta], limit: int) -> List[SkillMeta]:
|
||||
query_norm = query.strip()
|
||||
if not query_norm:
|
||||
return self._dedupe_results(results)[:limit]
|
||||
|
||||
filtered = [meta for meta in results if self._search_score(query_norm, meta) > 0]
|
||||
filtered.sort(
|
||||
key=lambda meta: (
|
||||
-self._search_score(query_norm, meta),
|
||||
meta.name.lower(),
|
||||
meta.identifier.lower(),
|
||||
)
|
||||
)
|
||||
filtered = self._dedupe_results(filtered)
|
||||
|
||||
exact = self._exact_slug_meta(query_norm)
|
||||
if exact:
|
||||
filtered = [meta for meta in filtered if self._search_score(query_norm, meta) >= 20]
|
||||
filtered = self._dedupe_results([exact] + filtered)
|
||||
|
||||
if filtered:
|
||||
return filtered[:limit]
|
||||
|
||||
if re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9._/-]*", query_norm):
|
||||
return []
|
||||
|
||||
return self._dedupe_results(results)[:limit]
|
||||
|
||||
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
|
||||
cache_key = f"clawhub_search_{hashlib.md5(query.encode()).hexdigest()}"
|
||||
query = query.strip()
|
||||
|
||||
if query:
|
||||
query_terms = self._query_terms(query)
|
||||
if len(query_terms) >= 2:
|
||||
direct = self._exact_slug_meta(query)
|
||||
if direct:
|
||||
return [direct]
|
||||
|
||||
results = self._search_catalog(query, limit=limit)
|
||||
if results:
|
||||
return results
|
||||
|
||||
# Empty query or catalog fallback failure: use the lightweight listing API.
|
||||
cache_key = f"clawhub_search_listing_v1_{hashlib.md5(query.encode()).hexdigest()}_{limit}"
|
||||
cached = _read_index_cache(cache_key)
|
||||
if cached is not None:
|
||||
return [SkillMeta(**s) for s in cached][:limit]
|
||||
return self._finalize_search_results(
|
||||
query,
|
||||
[SkillMeta(**s) for s in cached],
|
||||
limit,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = httpx.get(
|
||||
@@ -1185,20 +1350,19 @@ class ClawHubSource(SkillSource):
|
||||
continue
|
||||
display_name = item.get("displayName") or item.get("name") or slug
|
||||
summary = item.get("summary") or item.get("description") or ""
|
||||
tags = item.get("tags", [])
|
||||
if not isinstance(tags, list):
|
||||
tags = []
|
||||
tags = self._normalize_tags(item.get("tags", []))
|
||||
results.append(SkillMeta(
|
||||
name=display_name,
|
||||
description=summary,
|
||||
source="clawhub",
|
||||
identifier=slug,
|
||||
trust_level="community",
|
||||
tags=[str(t) for t in tags],
|
||||
tags=tags,
|
||||
))
|
||||
|
||||
_write_index_cache(cache_key, [_skill_meta_to_dict(s) for s in results])
|
||||
return results
|
||||
final_results = self._finalize_search_results(query, results, limit)
|
||||
_write_index_cache(cache_key, [_skill_meta_to_dict(s) for s in final_results])
|
||||
return final_results
|
||||
|
||||
def fetch(self, identifier: str) -> Optional[SkillBundle]:
|
||||
slug = identifier.split("/")[-1]
|
||||
@@ -1244,13 +1408,11 @@ class ClawHubSource(SkillSource):
|
||||
|
||||
def inspect(self, identifier: str) -> Optional[SkillMeta]:
|
||||
slug = identifier.split("/")[-1]
|
||||
data = self._get_json(f"{self.BASE_URL}/skills/{slug}")
|
||||
data = self._coerce_skill_payload(self._get_json(f"{self.BASE_URL}/skills/{slug}"))
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
tags = data.get("tags", [])
|
||||
if not isinstance(tags, list):
|
||||
tags = []
|
||||
tags = self._normalize_tags(data.get("tags", []))
|
||||
|
||||
return SkillMeta(
|
||||
name=data.get("displayName") or data.get("name") or data.get("slug") or slug,
|
||||
@@ -1258,9 +1420,75 @@ class ClawHubSource(SkillSource):
|
||||
source="clawhub",
|
||||
identifier=data.get("slug") or slug,
|
||||
trust_level="community",
|
||||
tags=[str(t) for t in tags],
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
def _search_catalog(self, query: str, limit: int = 10) -> List[SkillMeta]:
|
||||
cache_key = f"clawhub_search_catalog_v1_{hashlib.md5(f'{query}|{limit}'.encode()).hexdigest()}"
|
||||
cached = _read_index_cache(cache_key)
|
||||
if cached is not None:
|
||||
return [SkillMeta(**s) for s in cached][:limit]
|
||||
|
||||
catalog = self._load_catalog_index()
|
||||
if not catalog:
|
||||
return []
|
||||
|
||||
results = self._finalize_search_results(query, catalog, limit)
|
||||
_write_index_cache(cache_key, [_skill_meta_to_dict(s) for s in results])
|
||||
return results
|
||||
|
||||
def _load_catalog_index(self) -> List[SkillMeta]:
|
||||
cache_key = "clawhub_catalog_v1"
|
||||
cached = _read_index_cache(cache_key)
|
||||
if cached is not None:
|
||||
return [SkillMeta(**s) for s in cached]
|
||||
|
||||
cursor: Optional[str] = None
|
||||
results: List[SkillMeta] = []
|
||||
seen: set[str] = set()
|
||||
max_pages = 50
|
||||
|
||||
for _ in range(max_pages):
|
||||
params: Dict[str, Any] = {"limit": 200}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
try:
|
||||
resp = httpx.get(f"{self.BASE_URL}/skills", params=params, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
break
|
||||
data = resp.json()
|
||||
except (httpx.HTTPError, json.JSONDecodeError):
|
||||
break
|
||||
|
||||
items = data.get("items", []) if isinstance(data, dict) else []
|
||||
if not isinstance(items, list) or not items:
|
||||
break
|
||||
|
||||
for item in items:
|
||||
slug = item.get("slug")
|
||||
if not isinstance(slug, str) or not slug or slug in seen:
|
||||
continue
|
||||
seen.add(slug)
|
||||
display_name = item.get("displayName") or item.get("name") or slug
|
||||
summary = item.get("summary") or item.get("description") or ""
|
||||
tags = self._normalize_tags(item.get("tags", []))
|
||||
results.append(SkillMeta(
|
||||
name=display_name,
|
||||
description=summary,
|
||||
source="clawhub",
|
||||
identifier=slug,
|
||||
trust_level="community",
|
||||
tags=tags,
|
||||
))
|
||||
|
||||
cursor = data.get("nextCursor") if isinstance(data, dict) else None
|
||||
if not isinstance(cursor, str) or not cursor:
|
||||
break
|
||||
|
||||
_write_index_cache(cache_key, [_skill_meta_to_dict(s) for s in results])
|
||||
return results
|
||||
|
||||
def _get_json(self, url: str, timeout: int = 20) -> Optional[Any]:
|
||||
try:
|
||||
resp = httpx.get(url, timeout=timeout)
|
||||
|
||||
+17
-1
@@ -471,6 +471,8 @@ def _get_env_config() -> Dict[str, Any]:
|
||||
# is running inside the container/remote).
|
||||
if env_type == "local":
|
||||
default_cwd = os.getcwd()
|
||||
elif env_type == "ssh":
|
||||
default_cwd = "~"
|
||||
else:
|
||||
default_cwd = "/root"
|
||||
|
||||
@@ -503,6 +505,8 @@ def _get_env_config() -> Dict[str, Any]:
|
||||
"ssh_user": os.getenv("TERMINAL_SSH_USER", ""),
|
||||
"ssh_port": _parse_env_var("TERMINAL_SSH_PORT", "22"),
|
||||
"ssh_key": os.getenv("TERMINAL_SSH_KEY", ""),
|
||||
"ssh_persistent": os.getenv("TERMINAL_SSH_PERSISTENT", "false").lower() in ("true", "1", "yes"),
|
||||
"local_persistent": os.getenv("TERMINAL_LOCAL_PERSISTENT", "false").lower() in ("true", "1", "yes"),
|
||||
# Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh)
|
||||
"container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"),
|
||||
"container_memory": _parse_env_var("TERMINAL_CONTAINER_MEMORY", "5120"), # MB (default 5GB)
|
||||
@@ -514,6 +518,7 @@ def _get_env_config() -> Dict[str, Any]:
|
||||
|
||||
def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
||||
ssh_config: dict = None, container_config: dict = None,
|
||||
local_config: dict = None,
|
||||
task_id: str = "default"):
|
||||
"""
|
||||
Create an execution environment from mini-swe-agent.
|
||||
@@ -538,7 +543,9 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
||||
volumes = cc.get("docker_volumes", [])
|
||||
|
||||
if env_type == "local":
|
||||
return _LocalEnvironment(cwd=cwd, timeout=timeout)
|
||||
lc = local_config or {}
|
||||
return _LocalEnvironment(cwd=cwd, timeout=timeout,
|
||||
persistent=lc.get("persistent", False))
|
||||
|
||||
elif env_type == "docker":
|
||||
return _DockerEnvironment(
|
||||
@@ -594,6 +601,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
||||
key_path=ssh_config.get("key", ""),
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
persistent=ssh_config.get("persistent", False),
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -923,6 +931,7 @@ def terminal_tool(
|
||||
"user": config.get("ssh_user", ""),
|
||||
"port": config.get("ssh_port", 22),
|
||||
"key": config.get("ssh_key", ""),
|
||||
"persistent": config.get("ssh_persistent", False),
|
||||
}
|
||||
|
||||
container_config = None
|
||||
@@ -935,6 +944,12 @@ def terminal_tool(
|
||||
"docker_volumes": config.get("docker_volumes", []),
|
||||
}
|
||||
|
||||
local_config = None
|
||||
if env_type == "local":
|
||||
local_config = {
|
||||
"persistent": config.get("local_persistent", False),
|
||||
}
|
||||
|
||||
new_env = _create_environment(
|
||||
env_type=env_type,
|
||||
image=image,
|
||||
@@ -942,6 +957,7 @@ def terminal_tool(
|
||||
timeout=effective_timeout,
|
||||
ssh_config=ssh_config,
|
||||
container_config=container_config,
|
||||
local_config=local_config,
|
||||
task_id=effective_task_id,
|
||||
)
|
||||
except ImportError as e:
|
||||
|
||||
@@ -130,7 +130,41 @@ When an auxiliary task is configured with provider `main`, Hermes resolves that
|
||||
|
||||
## Fallback models
|
||||
|
||||
Hermes also supports a configured fallback model/provider, allowing runtime failover in supported error paths.
|
||||
Hermes supports a configured fallback model/provider pair, allowing runtime failover when the primary model encounters errors.
|
||||
|
||||
### How it works internally
|
||||
|
||||
1. **Storage**: `AIAgent.__init__` stores the `fallback_model` dict and sets `_fallback_activated = False`.
|
||||
|
||||
2. **Trigger points**: `_try_activate_fallback()` is called from three places in the main retry loop in `run_agent.py`:
|
||||
- After max retries on invalid API responses (None choices, missing content)
|
||||
- On non-retryable client errors (HTTP 401, 403, 404)
|
||||
- After max retries on transient errors (HTTP 429, 500, 502, 503)
|
||||
|
||||
3. **Activation flow** (`_try_activate_fallback`):
|
||||
- Returns `False` immediately if already activated or not configured
|
||||
- Calls `resolve_provider_client()` from `auxiliary_client.py` to build a new client with proper auth
|
||||
- Determines `api_mode`: `codex_responses` for openai-codex, `anthropic_messages` for anthropic, `chat_completions` for everything else
|
||||
- Swaps in-place: `self.model`, `self.provider`, `self.base_url`, `self.api_mode`, `self.client`, `self._client_kwargs`
|
||||
- For anthropic fallback: builds a native Anthropic client instead of OpenAI-compatible
|
||||
- Re-evaluates prompt caching (enabled for Claude models on OpenRouter)
|
||||
- Sets `_fallback_activated = True` — prevents firing again
|
||||
- Resets retry count to 0 and continues the loop
|
||||
|
||||
4. **Config flow**:
|
||||
- CLI: `cli.py` reads `CLI_CONFIG["fallback_model"]` → passes to `AIAgent(fallback_model=...)`
|
||||
- Gateway: `gateway/run.py._load_fallback_model()` reads `config.yaml` → passes to `AIAgent`
|
||||
- Validation: both `provider` and `model` keys must be non-empty, or fallback is disabled
|
||||
|
||||
### What does NOT support fallback
|
||||
|
||||
- **Subagent delegation** (`tools/delegate_tool.py`): subagents inherit the parent's provider but not the fallback config
|
||||
- **Cron jobs** (`cron/`): run with a fixed provider, no fallback mechanism
|
||||
- **Auxiliary tasks**: use their own independent provider auto-detection chain (see Auxiliary model routing above)
|
||||
|
||||
### Test coverage
|
||||
|
||||
See `tests/test_fallback_model.py` for comprehensive tests covering all supported providers, one-shot semantics, and edge cases.
|
||||
|
||||
## Related docs
|
||||
|
||||
|
||||
@@ -164,6 +164,7 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
|
||||
| `HERMES_QUIET` | Suppress non-essential output (`true`/`false`) |
|
||||
| `HERMES_API_TIMEOUT` | LLM API call timeout in seconds (default: `900`) |
|
||||
| `HERMES_EXEC_ASK` | Enable execution approval prompts in gateway mode (`true`/`false`) |
|
||||
| `HERMES_BACKGROUND_NOTIFICATIONS` | Background process notification mode in gateway: `all` (default), `result`, `error`, `off` |
|
||||
|
||||
## Session Settings
|
||||
|
||||
@@ -197,6 +198,18 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
|
||||
|
||||
For task-specific direct endpoints, Hermes uses the task's configured API key or `OPENAI_API_KEY`. It does not reuse `OPENROUTER_API_KEY` for those custom endpoints.
|
||||
|
||||
## Fallback Model (config.yaml only)
|
||||
|
||||
The primary model fallback is configured exclusively through `config.yaml` — there are no environment variables for it. Add a `fallback_model` section with `provider` and `model` keys to enable automatic failover when your main model encounters errors.
|
||||
|
||||
```yaml
|
||||
fallback_model:
|
||||
provider: openrouter
|
||||
model: anthropic/claude-sonnet-4
|
||||
```
|
||||
|
||||
See [Fallback Providers](/docs/user-guide/features/fallback-providers) for full details.
|
||||
|
||||
## Provider Routing (config.yaml only)
|
||||
|
||||
These go in `~/.hermes/config.yaml` under the `provider_routing` section:
|
||||
|
||||
@@ -31,7 +31,7 @@ Type `/` in the CLI to open the autocomplete menu. Built-in commands are case-in
|
||||
| `/title` | Set a title for the current session (usage: /title My Session Name) |
|
||||
| `/compress` | Manually compress conversation context (flush memories + summarize) |
|
||||
| `/rollback` | List or restore filesystem checkpoints (usage: /rollback [number]) |
|
||||
| `/background` | Run a prompt in the background (usage: /background <prompt>) |
|
||||
| `/background <prompt>` | Run a prompt in a separate background session. The agent processes your prompt independently — your current session stays free for other work. Results appear as a panel when the task finishes. See [CLI Background Sessions](/docs/user-guide/cli#background-sessions). |
|
||||
| `/plan [request]` | Load the bundled `plan` skill to write a markdown plan instead of executing the work. Plans are saved under `.hermes/plans/` relative to the active workspace/backend working directory. |
|
||||
|
||||
### Configuration
|
||||
@@ -109,7 +109,7 @@ The messaging gateway supports the following built-in commands inside Telegram,
|
||||
| `/reasoning [level\|show\|hide]` | Change reasoning effort or toggle reasoning display. |
|
||||
| `/voice [on\|off\|tts\|join\|channel\|leave\|status]` | Control spoken replies in chat. `join`/`channel`/`leave` manage Discord voice-channel mode. |
|
||||
| `/rollback [number]` | List or restore filesystem checkpoints. |
|
||||
| `/background <prompt>` | Run a prompt in a separate background session. |
|
||||
| `/background <prompt>` | Run a prompt in a separate background session. Results are delivered back to the same chat when the task finishes. See [Messaging Background Sessions](/docs/user-guide/messaging/#background-sessions). |
|
||||
| `/plan [request]` | Load the bundled `plan` skill to write a markdown plan instead of executing the work. Plans are saved under `.hermes/plans/` relative to the active workspace/backend working directory. |
|
||||
| `/reload-mcp` | Reload MCP servers from config. |
|
||||
| `/update` | Update Hermes Agent to the latest version. |
|
||||
@@ -119,6 +119,6 @@ The messaging gateway supports the following built-in commands inside Telegram,
|
||||
## Notes
|
||||
|
||||
- `/skin`, `/tools`, `/toolsets`, `/config`, `/prompt`, `/cron`, `/skills`, `/platforms`, `/paste`, and `/verbose` are **CLI-only** commands.
|
||||
- `/status`, `/stop`, `/sethome`, `/resume`, `/background`, and `/update` are **messaging-only** commands.
|
||||
- `/voice`, `/reload-mcp`, and `/rollback` work in **both** the CLI and the messaging gateway.
|
||||
- `/status`, `/stop`, `/sethome`, `/resume`, and `/update` are **messaging-only** commands.
|
||||
- `/background`, `/voice`, `/reload-mcp`, and `/rollback` work in **both** the CLI and the messaging gateway.
|
||||
- `/voice join`, `/voice channel`, and `/voice leave` are only meaningful on Discord.
|
||||
|
||||
@@ -259,6 +259,55 @@ compression:
|
||||
|
||||
When compression triggers, middle turns are summarized while the first 3 and last 4 turns are always preserved.
|
||||
|
||||
## Background Sessions
|
||||
|
||||
Run a prompt in a separate background session while continuing to use the CLI for other work:
|
||||
|
||||
```
|
||||
/background Analyze the logs in /var/log and summarize any errors from today
|
||||
```
|
||||
|
||||
Hermes immediately confirms the task and gives you back the prompt:
|
||||
|
||||
```
|
||||
🔄 Background task #1 started: "Analyze the logs in /var/log and summarize..."
|
||||
Task ID: bg_143022_a1b2c3
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
Each `/background` prompt spawns a **completely separate agent session** in a daemon thread:
|
||||
|
||||
- **Isolated conversation** — the background agent has no knowledge of your current session's history. It receives only the prompt you provide.
|
||||
- **Same configuration** — the background agent inherits your model, provider, toolsets, reasoning settings, and fallback model from the current session.
|
||||
- **Non-blocking** — your foreground session stays fully interactive. You can chat, run commands, or even start more background tasks.
|
||||
- **Multiple tasks** — you can run several background tasks simultaneously. Each gets a numbered ID.
|
||||
|
||||
### Results
|
||||
|
||||
When a background task finishes, the result appears as a panel in your terminal:
|
||||
|
||||
```
|
||||
╭─ ⚕ Hermes (background #1) ──────────────────────────────────╮
|
||||
│ Found 3 errors in syslog from today: │
|
||||
│ 1. OOM killer invoked at 03:22 — killed process nginx │
|
||||
│ 2. Disk I/O error on /dev/sda1 at 07:15 │
|
||||
│ 3. Failed SSH login attempts from 192.168.1.50 at 14:30 │
|
||||
╰──────────────────────────────────────────────────────────────╯
|
||||
```
|
||||
|
||||
If the task fails, you'll see an error notification instead. If `display.bell_on_complete` is enabled in your config, the terminal bell rings when the task finishes.
|
||||
|
||||
### Use Cases
|
||||
|
||||
- **Long-running research** — "/background research the latest developments in quantum error correction" while you work on code
|
||||
- **File processing** — "/background analyze all Python files in this repo and list any security issues" while you continue a conversation
|
||||
- **Parallel investigations** — start multiple background tasks to explore different angles simultaneously
|
||||
|
||||
:::info
|
||||
Background sessions do not appear in your main conversation history. They are standalone sessions with their own task ID (e.g., `bg_143022_a1b2c3`).
|
||||
:::
|
||||
|
||||
## Quiet Mode
|
||||
|
||||
By default, the CLI runs in quiet mode which:
|
||||
|
||||
@@ -421,6 +421,26 @@ provider_routing:
|
||||
|
||||
**Shortcuts:** Append `:nitro` to any model name for throughput sorting (e.g., `anthropic/claude-sonnet-4:nitro`), or `:floor` for price sorting.
|
||||
|
||||
## Fallback Model
|
||||
|
||||
Configure a backup provider:model that Hermes switches to automatically when your primary model fails (rate limits, server errors, auth failures):
|
||||
|
||||
```yaml
|
||||
fallback_model:
|
||||
provider: openrouter # required
|
||||
model: anthropic/claude-sonnet-4 # required
|
||||
# base_url: http://localhost:8000/v1 # optional, for custom endpoints
|
||||
# api_key_env: MY_CUSTOM_KEY # optional, env var name for custom endpoint API key
|
||||
```
|
||||
|
||||
When activated, the fallback swaps the model and provider mid-session without losing your conversation. It fires **at most once** per session.
|
||||
|
||||
Supported providers: `openrouter`, `nous`, `openai-codex`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `custom`.
|
||||
|
||||
:::tip
|
||||
Fallback is configured exclusively through `config.yaml` — there are no environment variables for it. For full details on when it triggers, supported providers, and how it interacts with auxiliary tasks and delegation, see [Fallback Providers](/docs/user-guide/features/fallback-providers).
|
||||
:::
|
||||
|
||||
## Terminal Backend Configuration
|
||||
|
||||
Configure which environment the agent uses for terminal commands:
|
||||
@@ -733,6 +753,7 @@ display:
|
||||
resume_display: full # full (show previous messages on resume) | minimal (one-liner only)
|
||||
bell_on_complete: false # Play terminal bell when agent finishes (great for long tasks)
|
||||
show_reasoning: false # Show model reasoning/thinking above each response (toggle with /reasoning show|hide)
|
||||
background_process_notifications: all # all | result | error | off (gateway only)
|
||||
```
|
||||
|
||||
| Mode | What you see |
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
---
|
||||
title: Fallback Providers
|
||||
description: Configure automatic failover to backup LLM providers when your primary model is unavailable.
|
||||
sidebar_label: Fallback Providers
|
||||
sidebar_position: 8
|
||||
---
|
||||
|
||||
# Fallback Providers
|
||||
|
||||
Hermes Agent has two separate fallback systems that keep your sessions running when providers hit issues:
|
||||
|
||||
1. **Primary model fallback** — automatically switches to a backup provider:model when your main model fails
|
||||
2. **Auxiliary task fallback** — independent provider resolution for side tasks like vision, compression, and web extraction
|
||||
|
||||
Both are optional and work independently.
|
||||
|
||||
## Primary Model Fallback
|
||||
|
||||
When your main LLM provider encounters errors — rate limits, server overload, auth failures, connection drops — Hermes can automatically switch to a backup provider:model pair mid-session without losing your conversation.
|
||||
|
||||
### Configuration
|
||||
|
||||
Add a `fallback_model` section to `~/.hermes/config.yaml`:
|
||||
|
||||
```yaml
|
||||
fallback_model:
|
||||
provider: openrouter
|
||||
model: anthropic/claude-sonnet-4
|
||||
```
|
||||
|
||||
Both `provider` and `model` are **required**. If either is missing, the fallback is disabled.
|
||||
|
||||
### Supported Providers
|
||||
|
||||
| Provider | Value | Requirements |
|
||||
|----------|-------|-------------|
|
||||
| OpenRouter | `openrouter` | `OPENROUTER_API_KEY` |
|
||||
| Nous Portal | `nous` | `hermes login` (OAuth) |
|
||||
| OpenAI Codex | `openai-codex` | `hermes model` (ChatGPT OAuth) |
|
||||
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` or Claude Code credentials |
|
||||
| z.ai / GLM | `zai` | `GLM_API_KEY` |
|
||||
| Kimi / Moonshot | `kimi-coding` | `KIMI_API_KEY` |
|
||||
| MiniMax | `minimax` | `MINIMAX_API_KEY` |
|
||||
| MiniMax (China) | `minimax-cn` | `MINIMAX_CN_API_KEY` |
|
||||
| Custom endpoint | `custom` | `base_url` + `api_key_env` (see below) |
|
||||
|
||||
### Custom Endpoint Fallback
|
||||
|
||||
For a custom OpenAI-compatible endpoint, add `base_url` and optionally `api_key_env`:
|
||||
|
||||
```yaml
|
||||
fallback_model:
|
||||
provider: custom
|
||||
model: my-local-model
|
||||
base_url: http://localhost:8000/v1
|
||||
api_key_env: MY_LOCAL_KEY # env var name containing the API key
|
||||
```
|
||||
|
||||
### When Fallback Triggers
|
||||
|
||||
The fallback activates automatically when the primary model fails with:
|
||||
|
||||
- **Rate limits** (HTTP 429) — after exhausting retry attempts
|
||||
- **Server errors** (HTTP 500, 502, 503) — after exhausting retry attempts
|
||||
- **Auth failures** (HTTP 401, 403) — immediately (no point retrying)
|
||||
- **Not found** (HTTP 404) — immediately
|
||||
- **Invalid responses** — when the API returns malformed or empty responses repeatedly
|
||||
|
||||
When triggered, Hermes:
|
||||
|
||||
1. Resolves credentials for the fallback provider
|
||||
2. Builds a new API client
|
||||
3. Swaps the model, provider, and client in-place
|
||||
4. Resets the retry counter and continues the conversation
|
||||
|
||||
The switch is seamless — your conversation history, tool calls, and context are preserved. The agent continues from exactly where it left off, just using a different model.
|
||||
|
||||
:::info One-Shot
|
||||
Fallback activates **at most once** per session. If the fallback provider also fails, normal error handling takes over (retries, then error message). This prevents cascading failover loops.
|
||||
:::
|
||||
|
||||
### Examples
|
||||
|
||||
**OpenRouter as fallback for Anthropic native:**
|
||||
```yaml
|
||||
model:
|
||||
provider: anthropic
|
||||
default: claude-sonnet-4-6
|
||||
|
||||
fallback_model:
|
||||
provider: openrouter
|
||||
model: anthropic/claude-sonnet-4
|
||||
```
|
||||
|
||||
**Nous Portal as fallback for OpenRouter:**
|
||||
```yaml
|
||||
model:
|
||||
provider: openrouter
|
||||
default: anthropic/claude-opus-4
|
||||
|
||||
fallback_model:
|
||||
provider: nous
|
||||
model: nous-hermes-3
|
||||
```
|
||||
|
||||
**Local model as fallback for cloud:**
|
||||
```yaml
|
||||
fallback_model:
|
||||
provider: custom
|
||||
model: llama-3.1-70b
|
||||
base_url: http://localhost:8000/v1
|
||||
api_key_env: LOCAL_API_KEY
|
||||
```
|
||||
|
||||
**Codex OAuth as fallback:**
|
||||
```yaml
|
||||
fallback_model:
|
||||
provider: openai-codex
|
||||
model: gpt-5.3-codex
|
||||
```
|
||||
|
||||
### Where Fallback Works
|
||||
|
||||
| Context | Fallback Supported |
|
||||
|---------|-------------------|
|
||||
| CLI sessions | ✔ |
|
||||
| Messaging gateway (Telegram, Discord, etc.) | ✔ |
|
||||
| Subagent delegation | ✘ (subagents do not inherit fallback config) |
|
||||
| Cron jobs | ✘ (run with a fixed provider) |
|
||||
| Auxiliary tasks (vision, compression) | ✘ (use their own provider chain — see below) |
|
||||
|
||||
:::tip
|
||||
There are no environment variables for `fallback_model` — it is configured exclusively through `config.yaml`. This is intentional: fallback configuration is a deliberate choice, not something a stale shell export should override.
|
||||
:::
|
||||
|
||||
---
|
||||
|
||||
## Auxiliary Task Fallback
|
||||
|
||||
Hermes uses separate lightweight models for side tasks. Each task has its own provider resolution chain that acts as a built-in fallback system.
|
||||
|
||||
### Tasks with Independent Provider Resolution
|
||||
|
||||
| Task | What It Does | Config Key |
|
||||
|------|-------------|-----------|
|
||||
| Vision | Image analysis, browser screenshots | `auxiliary.vision` |
|
||||
| Web Extract | Web page summarization | `auxiliary.web_extract` |
|
||||
| Compression | Context compression summaries | `auxiliary.compression` or `compression.summary_provider` |
|
||||
| Session Search | Past session summarization | `auxiliary.session_search` |
|
||||
| Skills Hub | Skill search and discovery | `auxiliary.skills_hub` |
|
||||
| MCP | MCP helper operations | `auxiliary.mcp` |
|
||||
| Memory Flush | Memory consolidation | `auxiliary.flush_memories` |
|
||||
|
||||
### Auto-Detection Chain
|
||||
|
||||
When a task's provider is set to `"auto"` (the default), Hermes tries providers in order until one works:
|
||||
|
||||
**For text tasks (compression, web extract, etc.):**
|
||||
|
||||
```text
|
||||
OpenRouter → Nous Portal → Custom endpoint → Codex OAuth →
|
||||
API-key providers (z.ai, Kimi, MiniMax, Anthropic) → give up
|
||||
```
|
||||
|
||||
**For vision tasks:**
|
||||
|
||||
```text
|
||||
Main provider (if vision-capable) → OpenRouter → Nous Portal →
|
||||
Codex OAuth → Anthropic → Custom endpoint → give up
|
||||
```
|
||||
|
||||
If the resolved provider fails at call time, Hermes also has an internal retry: if the provider is not OpenRouter and no explicit `base_url` is set, it tries OpenRouter as a last-resort fallback.
|
||||
|
||||
### Configuring Auxiliary Providers
|
||||
|
||||
Each task can be configured independently in `config.yaml`:
|
||||
|
||||
```yaml
|
||||
auxiliary:
|
||||
vision:
|
||||
provider: "auto" # auto | openrouter | nous | codex | main | anthropic
|
||||
model: "" # e.g. "openai/gpt-4o"
|
||||
base_url: "" # direct endpoint (takes precedence over provider)
|
||||
api_key: "" # API key for base_url
|
||||
|
||||
web_extract:
|
||||
provider: "auto"
|
||||
model: ""
|
||||
|
||||
compression:
|
||||
provider: "auto"
|
||||
model: ""
|
||||
|
||||
session_search:
|
||||
provider: "auto"
|
||||
model: ""
|
||||
|
||||
skills_hub:
|
||||
provider: "auto"
|
||||
model: ""
|
||||
|
||||
mcp:
|
||||
provider: "auto"
|
||||
model: ""
|
||||
|
||||
flush_memories:
|
||||
provider: "auto"
|
||||
model: ""
|
||||
```
|
||||
|
||||
Or via environment variables:
|
||||
|
||||
```bash
|
||||
AUXILIARY_VISION_PROVIDER=openrouter
|
||||
AUXILIARY_VISION_MODEL=openai/gpt-4o
|
||||
AUXILIARY_WEB_EXTRACT_PROVIDER=nous
|
||||
CONTEXT_COMPRESSION_PROVIDER=main
|
||||
CONTEXT_COMPRESSION_MODEL=google/gemini-3-flash-preview
|
||||
```
|
||||
|
||||
### Provider Options for Auxiliary Tasks
|
||||
|
||||
| Provider | Description | Requirements |
|
||||
|----------|-------------|-------------|
|
||||
| `"auto"` | Try providers in order until one works (default) | At least one provider configured |
|
||||
| `"openrouter"` | Force OpenRouter | `OPENROUTER_API_KEY` |
|
||||
| `"nous"` | Force Nous Portal | `hermes login` |
|
||||
| `"codex"` | Force Codex OAuth | `hermes model` → Codex |
|
||||
| `"main"` | Use whatever provider the main agent uses | Active main provider configured |
|
||||
| `"anthropic"` | Force Anthropic native | `ANTHROPIC_API_KEY` or Claude Code credentials |
|
||||
|
||||
### Direct Endpoint Override
|
||||
|
||||
For any auxiliary task, setting `base_url` bypasses provider resolution entirely and sends requests directly to that endpoint:
|
||||
|
||||
```yaml
|
||||
auxiliary:
|
||||
vision:
|
||||
base_url: "http://localhost:1234/v1"
|
||||
api_key: "local-key"
|
||||
model: "qwen2.5-vl"
|
||||
```
|
||||
|
||||
`base_url` takes precedence over `provider`. Hermes uses the configured `api_key` for authentication, falling back to `OPENAI_API_KEY` if not set. It does **not** reuse `OPENROUTER_API_KEY` for custom endpoints.
|
||||
|
||||
---
|
||||
|
||||
## Context Compression Fallback
|
||||
|
||||
Context compression has a legacy configuration path in addition to the auxiliary system:
|
||||
|
||||
```yaml
|
||||
compression:
|
||||
summary_provider: "auto" # auto | openrouter | nous | main
|
||||
summary_model: "google/gemini-3-flash-preview"
|
||||
```
|
||||
|
||||
This is equivalent to configuring `auxiliary.compression.provider` and `auxiliary.compression.model`. If both are set, the `auxiliary.compression` values take precedence.
|
||||
|
||||
If no provider is available for compression, Hermes drops middle conversation turns without generating a summary rather than failing the session.
|
||||
|
||||
---
|
||||
|
||||
## Delegation Provider Override
|
||||
|
||||
Subagents spawned by `delegate_task` do **not** use the primary fallback model. However, they can be routed to a different provider:model pair for cost optimization:
|
||||
|
||||
```yaml
|
||||
delegation:
|
||||
provider: "openrouter" # override provider for all subagents
|
||||
model: "google/gemini-3-flash-preview" # override model
|
||||
# base_url: "http://localhost:1234/v1" # or use a direct endpoint
|
||||
# api_key: "local-key"
|
||||
```
|
||||
|
||||
See [Subagent Delegation](/docs/user-guide/features/delegation) for full configuration details.
|
||||
|
||||
---
|
||||
|
||||
## Cron Job Providers
|
||||
|
||||
Cron jobs run with whatever provider is configured at execution time. They do not support a fallback model. To use a different provider for cron jobs, configure `provider` and `model` overrides on the cron job itself:
|
||||
|
||||
```python
|
||||
cronjob(
|
||||
action="create",
|
||||
schedule="every 2h",
|
||||
prompt="Check server status",
|
||||
provider="openrouter",
|
||||
model="google/gemini-3-flash-preview"
|
||||
)
|
||||
```
|
||||
|
||||
See [Scheduled Tasks (Cron)](/docs/user-guide/features/cron) for full configuration details.
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
| Feature | Fallback Mechanism | Config Location |
|
||||
|---------|-------------------|----------------|
|
||||
| Main agent model | `fallback_model` in config.yaml — one-shot failover on errors | `fallback_model:` (top-level) |
|
||||
| Vision | Auto-detection chain + internal OpenRouter retry | `auxiliary.vision` |
|
||||
| Web extraction | Auto-detection chain + internal OpenRouter retry | `auxiliary.web_extract` |
|
||||
| Context compression | Auto-detection chain, degrades to no-summary if unavailable | `auxiliary.compression` or `compression.summary_provider` |
|
||||
| Session search | Auto-detection chain | `auxiliary.session_search` |
|
||||
| Skills hub | Auto-detection chain | `auxiliary.skills_hub` |
|
||||
| MCP helpers | Auto-detection chain | `auxiliary.mcp` |
|
||||
| Memory flush | Auto-detection chain | `auxiliary.flush_memories` |
|
||||
| Delegation | Provider override only (no automatic fallback) | `delegation.provider` / `delegation.model` |
|
||||
| Cron jobs | Per-job provider override only (no automatic fallback) | Per-job `provider` / `model` |
|
||||
@@ -194,3 +194,7 @@ provider_routing:
|
||||
## Default Behavior
|
||||
|
||||
When no `provider_routing` section is configured (the default), OpenRouter uses its own default routing logic, which generally balances cost and availability automatically.
|
||||
|
||||
:::tip Provider Routing vs. Fallback Models
|
||||
Provider routing controls which **sub-providers within OpenRouter** handle your requests. For automatic failover to an entirely different provider when your primary model fails, see [Fallback Providers](/docs/user-guide/features/fallback-providers).
|
||||
:::
|
||||
|
||||
@@ -8,6 +8,21 @@ description: "Set up Hermes Agent as a Discord bot"
|
||||
|
||||
Hermes Agent integrates with Discord as a bot, letting you chat with your AI assistant through direct messages or server channels. The bot receives your messages, processes them through the Hermes Agent pipeline (including tool use, memory, and reasoning), and responds in real time. It supports text, voice messages, file attachments, and slash commands.
|
||||
|
||||
Before setup, here's the part most people want to know: how Hermes behaves once it's in your server.
|
||||
|
||||
## How Hermes Behaves
|
||||
|
||||
| Context | Behavior |
|
||||
|---------|----------|
|
||||
| **DMs** | Hermes responds to every message. No `@mention` needed. |
|
||||
| **Server channels** | By default, Hermes only responds when you `@mention` it. If you post in a channel without mentioning it, Hermes ignores the message. |
|
||||
| **Free-response channels** | You can make specific channels mention-free with `DISCORD_FREE_RESPONSE_CHANNELS`, or disable mentions globally with `DISCORD_REQUIRE_MENTION=false`. |
|
||||
| **Threads** | Hermes replies in the same thread. Mention rules still apply unless that thread or its parent channel is configured as free-response. |
|
||||
|
||||
:::tip
|
||||
If you want a normal shared bot channel where people can talk to Hermes without tagging it every time, add that channel to `DISCORD_FREE_RESPONSE_CHANNELS`.
|
||||
:::
|
||||
|
||||
This guide walks you through the full setup process — from creating your bot on Discord's Developer Portal to sending your first message.
|
||||
|
||||
## Step 1: Create a Discord Application
|
||||
@@ -200,12 +215,6 @@ DISCORD_HOME_CHANNEL_NAME="#bot-updates"
|
||||
|
||||
Replace the ID with the actual channel ID (right-click → Copy Channel ID with Developer Mode on).
|
||||
|
||||
## Bot Behavior
|
||||
|
||||
- **Server channels**: By default the bot requires an `@mention` before it responds in server channels. You can disable that globally with `DISCORD_REQUIRE_MENTION=false` or allow specific channels to be mention-free via `DISCORD_FREE_RESPONSE_CHANNELS`.
|
||||
- **Direct messages**: DMs always work, even without the Message Content Intent enabled (Discord exempts DMs from this requirement). However, you should still enable the intent for server channel support.
|
||||
- **Conversations**: Each channel or DM maintains its own conversation context.
|
||||
|
||||
## Voice Messages
|
||||
|
||||
Hermes Agent supports Discord voice messages:
|
||||
|
||||
@@ -181,6 +181,63 @@ When enabled, the bot sends status messages as it works:
|
||||
🐍 execute_code...
|
||||
```
|
||||
|
||||
## Background Sessions
|
||||
|
||||
Run a prompt in a separate background session so the agent works on it independently while your main chat stays responsive:
|
||||
|
||||
```
|
||||
/background Check all servers in the cluster and report any that are down
|
||||
```
|
||||
|
||||
Hermes confirms immediately:
|
||||
|
||||
```
|
||||
🔄 Background task started: "Check all servers in the cluster..."
|
||||
Task ID: bg_143022_a1b2c3
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
Each `/background` prompt spawns a **separate agent instance** that runs asynchronously:
|
||||
|
||||
- **Isolated session** — the background agent has its own session with its own conversation history. It has no knowledge of your current chat context and receives only the prompt you provide.
|
||||
- **Same configuration** — inherits your model, provider, toolsets, reasoning settings, and provider routing from the current gateway setup.
|
||||
- **Non-blocking** — your main chat stays fully interactive. Send messages, run other commands, or start more background tasks while it works.
|
||||
- **Result delivery** — when the task finishes, the result is sent back to the **same chat or channel** where you issued the command, prefixed with "✅ Background task complete". If it fails, you'll see "❌ Background task failed" with the error.
|
||||
|
||||
### Background Process Notifications
|
||||
|
||||
When the agent running a background session uses `terminal(background=true)` to start long-running processes (servers, builds, etc.), the gateway can push status updates to your chat. Control this with `display.background_process_notifications` in `~/.hermes/config.yaml`:
|
||||
|
||||
```yaml
|
||||
display:
|
||||
background_process_notifications: all # all | result | error | off
|
||||
```
|
||||
|
||||
| Mode | What you receive |
|
||||
|------|-----------------|
|
||||
| `all` | Running-output updates **and** the final completion message (default) |
|
||||
| `result` | Only the final completion message (regardless of exit code) |
|
||||
| `error` | Only the final message when the exit code is non-zero |
|
||||
| `off` | No process watcher messages at all |
|
||||
|
||||
You can also set this via environment variable:
|
||||
|
||||
```bash
|
||||
HERMES_BACKGROUND_NOTIFICATIONS=result
|
||||
```
|
||||
|
||||
### Use Cases
|
||||
|
||||
- **Server monitoring** — "/background Check the health of all services and alert me if anything is down"
|
||||
- **Long builds** — "/background Build and deploy the staging environment" while you continue chatting
|
||||
- **Research tasks** — "/background Research competitor pricing and summarize in a table"
|
||||
- **File operations** — "/background Organize the photos in ~/Downloads by date into folders"
|
||||
|
||||
:::tip
|
||||
Background tasks on messaging platforms are fire-and-forget — you don't need to wait or check on them. Results arrive in the same chat automatically when the task finishes.
|
||||
:::
|
||||
|
||||
## Service Management
|
||||
|
||||
### Linux (systemd)
|
||||
|
||||
@@ -193,8 +193,8 @@ Understanding how Hermes behaves in different contexts:
|
||||
| Context | Behavior |
|
||||
|---------|----------|
|
||||
| **DMs** | Bot responds to every message — no @mention needed |
|
||||
| **Channels** | Bot **only responds when @mentioned** (e.g., `@Hermes Agent what time is it?`) |
|
||||
| **Threads** | Bot replies in threads when the triggering message is in a thread |
|
||||
| **Channels** | Bot **only responds when @mentioned** (e.g., `@Hermes Agent what time is it?`). In channels, Hermes replies in a thread attached to that message. |
|
||||
| **Threads** | If you @mention Hermes inside an existing thread, it replies in that same thread. |
|
||||
|
||||
:::tip
|
||||
In channels, always @mention the bot. Simply typing a message without mentioning it will be ignored.
|
||||
|
||||
@@ -91,6 +91,7 @@ const sidebars: SidebarsConfig = {
|
||||
'user-guide/features/mcp',
|
||||
'user-guide/features/honcho',
|
||||
'user-guide/features/provider-routing',
|
||||
'user-guide/features/fallback-providers',
|
||||
],
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user