Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7a26669826 | |||
| 469cd16fe0 | |||
| b1a66d55b4 | |||
| 0d41fb0827 | |||
| 4aef055805 | |||
| f3006ebef9 | |||
| 99ff375f7a | |||
| 125e5ef089 | |||
| 4a630c2071 | |||
| 7b18eeee9b |
@@ -19,6 +19,9 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install system dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ripgrep
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PurePosixPath, PureWindowsPath
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
@@ -148,6 +148,62 @@ MODAL_INCOMPATIBLE_TASKS = {
|
||||
# Tar extraction helper
|
||||
# =============================================================================
|
||||
|
||||
def _normalize_tar_member_parts(member_name: str) -> list:
|
||||
"""Return safe path components for a tar member or raise ValueError."""
|
||||
normalized_name = member_name.replace("\\", "/")
|
||||
posix_path = PurePosixPath(normalized_name)
|
||||
windows_path = PureWindowsPath(member_name)
|
||||
|
||||
if (
|
||||
not normalized_name
|
||||
or posix_path.is_absolute()
|
||||
or windows_path.is_absolute()
|
||||
or windows_path.drive
|
||||
):
|
||||
raise ValueError(f"Unsafe archive member path: {member_name}")
|
||||
|
||||
parts = [part for part in posix_path.parts if part not in ("", ".")]
|
||||
if not parts or any(part == ".." for part in parts):
|
||||
raise ValueError(f"Unsafe archive member path: {member_name}")
|
||||
return parts
|
||||
|
||||
|
||||
def _safe_extract_tar(tar: tarfile.TarFile, target_dir: Path) -> None:
|
||||
"""Extract a tar archive without allowing traversal or link entries."""
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
target_root = target_dir.resolve()
|
||||
|
||||
for member in tar.getmembers():
|
||||
parts = _normalize_tar_member_parts(member.name)
|
||||
target = target_dir.joinpath(*parts)
|
||||
target_real = target.resolve(strict=False)
|
||||
|
||||
try:
|
||||
target_real.relative_to(target_root)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Unsafe archive member path: {member.name}") from exc
|
||||
|
||||
if member.isdir():
|
||||
target_real.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
|
||||
if not member.isfile():
|
||||
raise ValueError(f"Unsupported archive member type: {member.name}")
|
||||
|
||||
target_real.parent.mkdir(parents=True, exist_ok=True)
|
||||
extracted = tar.extractfile(member)
|
||||
if extracted is None:
|
||||
raise ValueError(f"Cannot read archive member: {member.name}")
|
||||
|
||||
with extracted, open(target_real, "wb") as dst:
|
||||
shutil.copyfileobj(extracted, dst)
|
||||
|
||||
try:
|
||||
os.chmod(target_real, member.mode & 0o777)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
"""Extract a base64-encoded tar.gz archive into target_dir."""
|
||||
if not b64_data:
|
||||
@@ -155,7 +211,7 @@ def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
raw = base64.b64decode(b64_data)
|
||||
buf = io.BytesIO(raw)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
tar.extractall(path=str(target_dir))
|
||||
_safe_extract_tar(tar, target_dir)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -20,6 +20,7 @@ Requires:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -370,7 +371,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:].strip()
|
||||
if token == self._api_key:
|
||||
if hmac.compare_digest(token, self._api_key):
|
||||
return None # Auth OK
|
||||
|
||||
return web.json_response(
|
||||
|
||||
@@ -124,7 +124,14 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL targets a private/internal network (SSRF protection).
|
||||
"""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging as _logging
|
||||
@@ -232,7 +239,14 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached audio file as a string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL targets a private/internal network (SSRF protection).
|
||||
"""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging as _logging
|
||||
@@ -1105,6 +1119,22 @@ class BasePlatformAdapter(ABC):
|
||||
logger.error("[%s] Fallback send also failed: %s", self.name, fallback_result.error)
|
||||
return fallback_result
|
||||
|
||||
@staticmethod
|
||||
def _merge_caption(existing_text: Optional[str], new_text: str) -> str:
|
||||
"""Merge a new caption into existing text, avoiding duplicates.
|
||||
|
||||
Uses line-by-line exact match (not substring) to prevent false positives
|
||||
where a shorter caption is silently dropped because it appears as a
|
||||
substring of a longer one (e.g. "Meeting" inside "Meeting agenda").
|
||||
Whitespace is normalised for comparison.
|
||||
"""
|
||||
if not existing_text:
|
||||
return new_text
|
||||
existing_captions = [c.strip() for c in existing_text.split("\n\n")]
|
||||
if new_text.strip() not in existing_captions:
|
||||
return f"{existing_text}\n\n{new_text}".strip()
|
||||
return existing_text
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
@@ -1164,10 +1194,7 @@ class BasePlatformAdapter(ABC):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
else:
|
||||
self._pending_messages[session_key] = event
|
||||
return # Don't interrupt now - will run after current task completes
|
||||
|
||||
@@ -55,6 +55,7 @@ from gateway.platforms.base import (
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
from tools.url_safety import is_safe_url
|
||||
|
||||
|
||||
def _clean_discord_id(entry: str) -> str:
|
||||
@@ -1285,6 +1286,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not is_safe_url(image_url):
|
||||
logger.warning("[%s] Blocked unsafe image URL during Discord send_image", self.name)
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
@@ -2065,10 +2065,7 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text.split("\n\n"):
|
||||
existing.text = f"{existing.text}\n\n{event.text}"
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
existing.timestamp = event.timestamp
|
||||
if event.message_id:
|
||||
existing.message_id = event.message_id
|
||||
@@ -2112,6 +2109,10 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
default_ext: str,
|
||||
preferred_name: str,
|
||||
) -> tuple[str, str]:
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(file_url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {file_url[:80]}")
|
||||
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
|
||||
@@ -586,6 +586,11 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Download an image URL and upload it to Matrix."""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(image_url):
|
||||
logger.warning("Matrix: blocked unsafe image URL (SSRF protection)")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
try:
|
||||
# Try aiohttp first (always available), fall back to httpx
|
||||
try:
|
||||
|
||||
@@ -407,6 +407,11 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
kind: str = "file",
|
||||
) -> SendResult:
|
||||
"""Download a URL and upload it as a file attachment."""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
logger.warning("Mattermost: blocked unsafe URL (SSRF protection)")
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
|
||||
@@ -595,6 +595,11 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(image_url):
|
||||
logger.warning("[Slack] Blocked unsafe image URL (SSRF protection)")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
|
||||
@@ -1632,7 +1632,12 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(image_url):
|
||||
logger.warning("[%s] Blocked unsafe image URL (SSRF protection)", self.name)
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs (up to ~5MB)
|
||||
_photo_thread = metadata.get("thread_id") if metadata else None
|
||||
@@ -2222,10 +2227,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
|
||||
prior_task = self._pending_photo_batch_tasks.get(batch_key)
|
||||
if prior_task and not prior_task.done():
|
||||
@@ -2415,11 +2417,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if existing.text:
|
||||
if event.text not in existing.text.split("\n\n"):
|
||||
existing.text = f"{existing.text}\n\n{event.text}"
|
||||
else:
|
||||
existing.text = event.text
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
|
||||
prior_task = self._media_group_tasks.get(media_group_id)
|
||||
if prior_task:
|
||||
|
||||
@@ -76,8 +76,17 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
self._routes: Dict[str, dict] = dict(self._static_routes)
|
||||
self._runner = None
|
||||
|
||||
# Delivery info keyed by session chat_id — consumed by send()
|
||||
# Delivery info keyed by session chat_id.
|
||||
#
|
||||
# Read by every send() invocation for the chat_id (status messages
|
||||
# AND the final response). Cleaned up via TTL on each POST so the
|
||||
# dict stays bounded — see _prune_delivery_info(). Do NOT pop on
|
||||
# send(), or interim status messages (e.g. fallback notifications,
|
||||
# context-pressure warnings) will consume the entry before the
|
||||
# final response arrives, causing the response to silently fall
|
||||
# back to the "log" deliver type.
|
||||
self._delivery_info: Dict[str, dict] = {}
|
||||
self._delivery_info_created: Dict[str, float] = {}
|
||||
|
||||
# Reference to gateway runner for cross-platform delivery (set externally)
|
||||
self.gateway_runner = None
|
||||
@@ -160,10 +169,14 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
) -> SendResult:
|
||||
"""Deliver the agent's response to the configured destination.
|
||||
|
||||
chat_id is ``webhook:{route}:{delivery_id}`` — we pop the delivery
|
||||
info stored during webhook receipt so it doesn't leak memory.
|
||||
chat_id is ``webhook:{route}:{delivery_id}``. The delivery info
|
||||
stored during webhook receipt is read with ``.get()`` (not popped)
|
||||
so that interim status messages emitted before the final response
|
||||
— fallback-model notifications, context-pressure warnings, etc. —
|
||||
do not consume the entry and silently downgrade the final response
|
||||
to the ``log`` deliver type. TTL cleanup happens on POST.
|
||||
"""
|
||||
delivery = self._delivery_info.pop(chat_id, {})
|
||||
delivery = self._delivery_info.get(chat_id, {})
|
||||
deliver_type = delivery.get("deliver", "log")
|
||||
|
||||
if deliver_type == "log":
|
||||
@@ -190,6 +203,23 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
success=False, error=f"Unknown deliver type: {deliver_type}"
|
||||
)
|
||||
|
||||
def _prune_delivery_info(self, now: float) -> None:
|
||||
"""Drop delivery_info entries older than the idempotency TTL.
|
||||
|
||||
Mirrors the cleanup pattern used for ``_seen_deliveries``. Called
|
||||
on each POST so the dict size is bounded by ``rate_limit * TTL``
|
||||
even if many webhooks fire and never receive a final response.
|
||||
"""
|
||||
cutoff = now - self._idempotency_ttl
|
||||
stale = [
|
||||
k
|
||||
for k, t in self._delivery_info_created.items()
|
||||
if t < cutoff
|
||||
]
|
||||
for k in stale:
|
||||
self._delivery_info.pop(k, None)
|
||||
self._delivery_info_created.pop(k, None)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
return {"name": chat_id, "type": "webhook"}
|
||||
|
||||
@@ -382,7 +412,9 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
# same route get independent agent runs (not queued/interrupted).
|
||||
session_chat_id = f"webhook:{route_name}:{delivery_id}"
|
||||
|
||||
# Store delivery info for send() — consumed (popped) on delivery
|
||||
# Store delivery info for send(). Read by every send() invocation
|
||||
# for this chat_id (interim status messages and the final response),
|
||||
# so we do NOT pop on send. TTL-based cleanup keeps the dict bounded.
|
||||
deliver_config = {
|
||||
"deliver": route_config.get("deliver", "log"),
|
||||
"deliver_extra": self._render_delivery_extra(
|
||||
@@ -391,6 +423,8 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
"payload": payload,
|
||||
}
|
||||
self._delivery_info[session_chat_id] = deliver_config
|
||||
self._delivery_info_created[session_chat_id] = now
|
||||
self._prune_delivery_info(now)
|
||||
|
||||
# Build source and event
|
||||
source = self.build_source(
|
||||
|
||||
@@ -910,6 +910,10 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
url: str,
|
||||
max_bytes: int,
|
||||
) -> Tuple[bytes, Dict[str, str]]:
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {url[:80]}")
|
||||
|
||||
if not HTTPX_AVAILABLE:
|
||||
raise RuntimeError("httpx is required for WeCom media download")
|
||||
|
||||
|
||||
+28
-15
@@ -1987,10 +1987,7 @@ class GatewayRunner:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
else:
|
||||
@@ -3345,25 +3342,36 @@ class GatewayRunner:
|
||||
"""Handle /status command."""
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
|
||||
|
||||
connected_platforms = [p.value for p in self.adapters.keys()]
|
||||
|
||||
|
||||
# Check if there's an active agent
|
||||
session_key = session_entry.session_key
|
||||
is_running = session_key in self._running_agents
|
||||
|
||||
|
||||
title = None
|
||||
if self._session_db:
|
||||
try:
|
||||
title = self._session_db.get_session_title(session_entry.session_id)
|
||||
except Exception:
|
||||
title = None
|
||||
|
||||
lines = [
|
||||
"📊 **Hermes Gateway Status**",
|
||||
"",
|
||||
f"**Session ID:** `{session_entry.session_id[:12]}...`",
|
||||
f"**Session ID:** `{session_entry.session_id}`",
|
||||
]
|
||||
if title:
|
||||
lines.append(f"**Title:** {title}")
|
||||
lines.extend([
|
||||
f"**Created:** {session_entry.created_at.strftime('%Y-%m-%d %H:%M')}",
|
||||
f"**Last Activity:** {session_entry.updated_at.strftime('%Y-%m-%d %H:%M')}",
|
||||
f"**Tokens:** {session_entry.total_tokens:,}",
|
||||
f"**Agent Running:** {'Yes ⚡' if is_running else 'No'}",
|
||||
"",
|
||||
f"**Connected Platforms:** {', '.join(connected_platforms)}",
|
||||
]
|
||||
|
||||
])
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_stop_command(self, event: MessageEvent) -> str:
|
||||
@@ -4913,8 +4921,8 @@ class GatewayRunner:
|
||||
cycle = ["off", "new", "all", "verbose"]
|
||||
descriptions = {
|
||||
"off": "⚙️ Tool progress: **OFF** — no tool activity shown.",
|
||||
"new": "⚙️ Tool progress: **NEW** — shown when tool changes (short previews).",
|
||||
"all": "⚙️ Tool progress: **ALL** — every tool call shown (short previews).",
|
||||
"new": "⚙️ Tool progress: **NEW** — shown when tool changes (preview length: `display.tool_preview_length`, default 40).",
|
||||
"all": "⚙️ Tool progress: **ALL** — every tool call shown (preview length: `display.tool_preview_length`, default 40).",
|
||||
"verbose": "⚙️ Tool progress: **VERBOSE** — every tool call with full arguments.",
|
||||
}
|
||||
|
||||
@@ -6327,10 +6335,15 @@ class GatewayRunner:
|
||||
progress_queue.put(msg)
|
||||
return
|
||||
|
||||
# "all" / "new" modes: short preview, always truncated (40 chars)
|
||||
# "all" / "new" modes: short preview, respects tool_preview_length
|
||||
# config (defaults to 40 chars when unset to keep gateway messages
|
||||
# compact — unlike CLI spinners, these persist as permanent messages).
|
||||
if preview:
|
||||
if len(preview) > 40:
|
||||
preview = preview[:37] + "..."
|
||||
from agent.display import get_tool_preview_max_len
|
||||
_pl = get_tool_preview_max_len()
|
||||
_cap = _pl if _pl > 0 else 40
|
||||
if len(preview) > _cap:
|
||||
preview = preview[:_cap - 3] + "..."
|
||||
msg = f"{emoji} {tool_name}: \"{preview}\""
|
||||
else:
|
||||
msg = f"{emoji} {tool_name}..."
|
||||
|
||||
+4
-15
@@ -37,7 +37,7 @@ from typing import Any, Dict, List, Optional
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
from hermes_cli.config import get_hermes_home, get_config_path
|
||||
from hermes_cli.config import get_hermes_home, get_config_path, read_raw_config
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -2214,14 +2214,7 @@ def _update_config_for_provider(
|
||||
config_path = get_config_path()
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config: Dict[str, Any] = {}
|
||||
if config_path.exists():
|
||||
try:
|
||||
loaded = yaml.safe_load(config_path.read_text()) or {}
|
||||
if isinstance(loaded, dict):
|
||||
config = loaded
|
||||
except Exception:
|
||||
config = {}
|
||||
config = read_raw_config()
|
||||
|
||||
current_model = config.get("model")
|
||||
if isinstance(current_model, dict):
|
||||
@@ -2258,12 +2251,8 @@ def _reset_config_provider() -> Path:
|
||||
if not config_path.exists():
|
||||
return config_path
|
||||
|
||||
try:
|
||||
config = yaml.safe_load(config_path.read_text()) or {}
|
||||
except Exception:
|
||||
return config_path
|
||||
|
||||
if not isinstance(config, dict):
|
||||
config = read_raw_config()
|
||||
if not config:
|
||||
return config_path
|
||||
|
||||
model = config.get("model")
|
||||
|
||||
@@ -293,14 +293,8 @@ def _resolve_config_gates() -> set[str]:
|
||||
if not gated:
|
||||
return set()
|
||||
try:
|
||||
import yaml
|
||||
from hermes_constants import get_hermes_home
|
||||
config_path = str(get_hermes_home() / "config.yaml")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
else:
|
||||
cfg = {}
|
||||
from hermes_cli.config import read_raw_config
|
||||
cfg = read_raw_config()
|
||||
except Exception:
|
||||
return set()
|
||||
result: set[str] = set()
|
||||
|
||||
@@ -17,7 +17,7 @@ Or manually:
|
||||
|
||||
```bash
|
||||
hermes config set memory.provider supermemory
|
||||
echo 'SUPERMEMORY_API_KEY=your-key-here' >> ~/.hermes/.env
|
||||
echo 'SUPERMEMORY_API_KEY=***' >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
## Config
|
||||
@@ -26,15 +26,23 @@ Config file: `$HERMES_HOME/supermemory.json`
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `container_tag` | `hermes` | Container tag used for search and writes |
|
||||
| `container_tag` | `hermes` | Container tag used for search and writes. Supports `{identity}` template for profile-scoped tags (e.g. `hermes-{identity}` → `hermes-coder`). |
|
||||
| `auto_recall` | `true` | Inject relevant memory context before turns |
|
||||
| `auto_capture` | `true` | Store cleaned user-assistant turns after each response |
|
||||
| `max_recall_results` | `10` | Max recalled items to format into context |
|
||||
| `profile_frequency` | `50` | Include profile facts on first turn and every N turns |
|
||||
| `capture_mode` | `all` | Skip tiny or trivial turns by default |
|
||||
| `search_mode` | `hybrid` | Search mode: `hybrid` (profile + memories), `memories` (memories only), `documents` (documents only) |
|
||||
| `entity_context` | built-in default | Extraction guidance passed to Supermemory |
|
||||
| `api_timeout` | `5.0` | Timeout for SDK and ingest requests |
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `SUPERMEMORY_API_KEY` | API key (required) |
|
||||
| `SUPERMEMORY_CONTAINER_TAG` | Override container tag (takes priority over config file) |
|
||||
|
||||
## Tools
|
||||
|
||||
| Tool | Description |
|
||||
@@ -52,3 +60,40 @@ When enabled, Hermes can:
|
||||
- store cleaned conversation turns after each completed response
|
||||
- ingest the full session on session end for richer graph updates
|
||||
- expose explicit tools for search, store, forget, and profile access
|
||||
|
||||
## Profile-Scoped Containers
|
||||
|
||||
Use `{identity}` in the `container_tag` to scope memories per Hermes profile:
|
||||
|
||||
```json
|
||||
{
|
||||
"container_tag": "hermes-{identity}"
|
||||
}
|
||||
```
|
||||
|
||||
For a profile named `coder`, this resolves to `hermes-coder`. The default profile resolves to `hermes-default`. Without `{identity}`, all profiles share the same container.
|
||||
|
||||
## Multi-Container Mode
|
||||
|
||||
For advanced setups (e.g. OpenClaw-style multi-workspace), you can enable custom container tags so the agent can read/write across multiple named containers:
|
||||
|
||||
```json
|
||||
{
|
||||
"container_tag": "hermes",
|
||||
"enable_custom_container_tags": true,
|
||||
"custom_containers": ["project-alpha", "project-beta", "shared-knowledge"],
|
||||
"custom_container_instructions": "Use project-alpha for coding tasks, project-beta for research, and shared-knowledge for team-wide facts."
|
||||
}
|
||||
```
|
||||
|
||||
When enabled:
|
||||
- `supermemory_search`, `supermemory_store`, `supermemory_forget`, and `supermemory_profile` accept an optional `container_tag` parameter
|
||||
- The tag must be in the whitelist: primary container + `custom_containers`
|
||||
- Automatic operations (turn sync, prefetch, memory write mirroring, session ingest) always use the **primary** container only
|
||||
- Custom container instructions are injected into the system prompt
|
||||
|
||||
## Support
|
||||
|
||||
- [Supermemory Discord](https://supermemory.link/discord)
|
||||
- [support@supermemory.com](mailto:support@supermemory.com)
|
||||
- [supermemory.ai](https://supermemory.ai)
|
||||
|
||||
@@ -26,6 +26,8 @@ _DEFAULT_CONTAINER_TAG = "hermes"
|
||||
_DEFAULT_MAX_RECALL_RESULTS = 10
|
||||
_DEFAULT_PROFILE_FREQUENCY = 50
|
||||
_DEFAULT_CAPTURE_MODE = "all"
|
||||
_DEFAULT_SEARCH_MODE = "hybrid"
|
||||
_VALID_SEARCH_MODES = ("hybrid", "memories", "documents")
|
||||
_DEFAULT_API_TIMEOUT = 5.0
|
||||
_MIN_CAPTURE_LENGTH = 10
|
||||
_MAX_ENTITY_CONTEXT_LENGTH = 1500
|
||||
@@ -59,8 +61,12 @@ def _default_config() -> dict:
|
||||
"max_recall_results": _DEFAULT_MAX_RECALL_RESULTS,
|
||||
"profile_frequency": _DEFAULT_PROFILE_FREQUENCY,
|
||||
"capture_mode": _DEFAULT_CAPTURE_MODE,
|
||||
"search_mode": _DEFAULT_SEARCH_MODE,
|
||||
"entity_context": _DEFAULT_ENTITY_CONTEXT,
|
||||
"api_timeout": _DEFAULT_API_TIMEOUT,
|
||||
"enable_custom_container_tags": False,
|
||||
"custom_containers": [],
|
||||
"custom_container_instructions": "",
|
||||
}
|
||||
|
||||
|
||||
@@ -100,7 +106,10 @@ def _load_supermemory_config(hermes_home: str) -> dict:
|
||||
except Exception:
|
||||
logger.debug("Failed to parse %s", config_path, exc_info=True)
|
||||
|
||||
config["container_tag"] = _sanitize_tag(str(config.get("container_tag", _DEFAULT_CONTAINER_TAG)))
|
||||
# Keep raw container_tag — template variables like {identity} are resolved
|
||||
# in initialize(), and _sanitize_tag runs AFTER resolution.
|
||||
raw_tag = str(config.get("container_tag", _DEFAULT_CONTAINER_TAG)).strip()
|
||||
config["container_tag"] = raw_tag if raw_tag else _DEFAULT_CONTAINER_TAG
|
||||
config["auto_recall"] = _as_bool(config.get("auto_recall"), True)
|
||||
config["auto_capture"] = _as_bool(config.get("auto_capture"), True)
|
||||
try:
|
||||
@@ -112,11 +121,23 @@ def _load_supermemory_config(hermes_home: str) -> dict:
|
||||
except Exception:
|
||||
config["profile_frequency"] = _DEFAULT_PROFILE_FREQUENCY
|
||||
config["capture_mode"] = "everything" if config.get("capture_mode") == "everything" else "all"
|
||||
raw_search_mode = str(config.get("search_mode", _DEFAULT_SEARCH_MODE)).strip().lower()
|
||||
config["search_mode"] = raw_search_mode if raw_search_mode in _VALID_SEARCH_MODES else _DEFAULT_SEARCH_MODE
|
||||
config["entity_context"] = _clamp_entity_context(str(config.get("entity_context", _DEFAULT_ENTITY_CONTEXT)))
|
||||
try:
|
||||
config["api_timeout"] = max(0.5, min(15.0, float(config.get("api_timeout", _DEFAULT_API_TIMEOUT))))
|
||||
except Exception:
|
||||
config["api_timeout"] = _DEFAULT_API_TIMEOUT
|
||||
|
||||
# Multi-container support
|
||||
config["enable_custom_container_tags"] = _as_bool(config.get("enable_custom_container_tags"), False)
|
||||
raw_containers = config.get("custom_containers", [])
|
||||
if isinstance(raw_containers, list):
|
||||
config["custom_containers"] = [_sanitize_tag(str(t)) for t in raw_containers if t]
|
||||
else:
|
||||
config["custom_containers"] = []
|
||||
config["custom_container_instructions"] = str(config.get("custom_container_instructions", "")).strip()
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@@ -240,28 +261,41 @@ def _is_trivial_message(text: str) -> bool:
|
||||
|
||||
|
||||
class _SupermemoryClient:
|
||||
def __init__(self, api_key: str, timeout: float, container_tag: str):
|
||||
def __init__(self, api_key: str, timeout: float, container_tag: str, search_mode: str = "hybrid"):
|
||||
from supermemory import Supermemory
|
||||
|
||||
self._api_key = api_key
|
||||
self._container_tag = container_tag
|
||||
self._search_mode = search_mode if search_mode in _VALID_SEARCH_MODES else _DEFAULT_SEARCH_MODE
|
||||
self._timeout = timeout
|
||||
self._client = Supermemory(api_key=api_key, timeout=timeout, max_retries=0)
|
||||
|
||||
def add_memory(self, content: str, metadata: Optional[dict] = None, *, entity_context: str = "") -> dict:
|
||||
kwargs = {
|
||||
def add_memory(self, content: str, metadata: Optional[dict] = None, *,
|
||||
entity_context: str = "", container_tag: Optional[str] = None,
|
||||
custom_id: Optional[str] = None) -> dict:
|
||||
tag = container_tag or self._container_tag
|
||||
kwargs: dict[str, Any] = {
|
||||
"content": content.strip(),
|
||||
"container_tags": [self._container_tag],
|
||||
"container_tags": [tag],
|
||||
}
|
||||
if metadata:
|
||||
kwargs["metadata"] = metadata
|
||||
if entity_context:
|
||||
kwargs["entity_context"] = _clamp_entity_context(entity_context)
|
||||
if custom_id:
|
||||
kwargs["custom_id"] = custom_id
|
||||
result = self._client.documents.add(**kwargs)
|
||||
return {"id": getattr(result, "id", "")}
|
||||
|
||||
def search_memories(self, query: str, *, limit: int = 5) -> list[dict]:
|
||||
response = self._client.search.memories(q=query, container_tag=self._container_tag, limit=limit)
|
||||
def search_memories(self, query: str, *, limit: int = 5,
|
||||
container_tag: Optional[str] = None,
|
||||
search_mode: Optional[str] = None) -> list[dict]:
|
||||
tag = container_tag or self._container_tag
|
||||
mode = search_mode or self._search_mode
|
||||
kwargs: dict[str, Any] = {"q": query, "container_tag": tag, "limit": limit}
|
||||
if mode in _VALID_SEARCH_MODES:
|
||||
kwargs["search_mode"] = mode
|
||||
response = self._client.search.memories(**kwargs)
|
||||
results = []
|
||||
for item in (getattr(response, "results", None) or []):
|
||||
results.append({
|
||||
@@ -273,8 +307,10 @@ class _SupermemoryClient:
|
||||
})
|
||||
return results
|
||||
|
||||
def get_profile(self, query: Optional[str] = None) -> dict:
|
||||
kwargs = {"container_tag": self._container_tag}
|
||||
def get_profile(self, query: Optional[str] = None, *,
|
||||
container_tag: Optional[str] = None) -> dict:
|
||||
tag = container_tag or self._container_tag
|
||||
kwargs: dict[str, Any] = {"container_tag": tag}
|
||||
if query:
|
||||
kwargs["q"] = query
|
||||
response = self._client.profile(**kwargs)
|
||||
@@ -296,18 +332,19 @@ class _SupermemoryClient:
|
||||
})
|
||||
return {"static": static, "dynamic": dynamic, "search_results": search_results}
|
||||
|
||||
def forget_memory(self, memory_id: str) -> None:
|
||||
self._client.memories.forget(container_tag=self._container_tag, id=memory_id)
|
||||
def forget_memory(self, memory_id: str, *, container_tag: Optional[str] = None) -> None:
|
||||
tag = container_tag or self._container_tag
|
||||
self._client.memories.forget(container_tag=tag, id=memory_id)
|
||||
|
||||
def forget_by_query(self, query: str) -> dict:
|
||||
results = self.search_memories(query, limit=5)
|
||||
def forget_by_query(self, query: str, *, container_tag: Optional[str] = None) -> dict:
|
||||
results = self.search_memories(query, limit=5, container_tag=container_tag)
|
||||
if not results:
|
||||
return {"success": False, "message": "No matching memory found to forget."}
|
||||
target = results[0]
|
||||
memory_id = target.get("id", "")
|
||||
if not memory_id:
|
||||
return {"success": False, "message": "Best matching memory has no id."}
|
||||
self.forget_memory(memory_id)
|
||||
self.forget_memory(memory_id, container_tag=container_tag)
|
||||
preview = (target.get("memory") or "")[:100]
|
||||
return {"success": True, "message": f'Forgot: "{preview}"', "id": memory_id}
|
||||
|
||||
@@ -398,11 +435,17 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
self._max_recall_results = _DEFAULT_MAX_RECALL_RESULTS
|
||||
self._profile_frequency = _DEFAULT_PROFILE_FREQUENCY
|
||||
self._capture_mode = _DEFAULT_CAPTURE_MODE
|
||||
self._search_mode = _DEFAULT_SEARCH_MODE
|
||||
self._entity_context = _DEFAULT_ENTITY_CONTEXT
|
||||
self._api_timeout = _DEFAULT_API_TIMEOUT
|
||||
self._hermes_home = ""
|
||||
self._write_enabled = True
|
||||
self._active = False
|
||||
# Multi-container support
|
||||
self._enable_custom_containers = False
|
||||
self._custom_containers: List[str] = []
|
||||
self._custom_container_instructions = ""
|
||||
self._allowed_containers: List[str] = []
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -419,16 +462,11 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
return False
|
||||
|
||||
def get_config_schema(self):
|
||||
# Only prompt for the API key during `hermes memory setup`.
|
||||
# All other options are documented for $HERMES_HOME/supermemory.json
|
||||
# or the SUPERMEMORY_CONTAINER_TAG env var.
|
||||
return [
|
||||
{"key": "api_key", "description": "Supermemory API key", "secret": True, "required": True, "env_var": "SUPERMEMORY_API_KEY", "url": "https://supermemory.ai"},
|
||||
{"key": "container_tag", "description": "Container tag for reads and writes", "default": _DEFAULT_CONTAINER_TAG},
|
||||
{"key": "auto_recall", "description": "Enable automatic recall before each turn", "default": "true", "choices": ["true", "false"]},
|
||||
{"key": "auto_capture", "description": "Enable automatic capture after each completed turn", "default": "true", "choices": ["true", "false"]},
|
||||
{"key": "max_recall_results", "description": "Maximum recalled items to inject", "default": str(_DEFAULT_MAX_RECALL_RESULTS)},
|
||||
{"key": "profile_frequency", "description": "Include profile facts on first turn and every N turns", "default": str(_DEFAULT_PROFILE_FREQUENCY)},
|
||||
{"key": "capture_mode", "description": "Capture mode", "default": _DEFAULT_CAPTURE_MODE, "choices": ["all", "everything"]},
|
||||
{"key": "entity_context", "description": "Extraction guidance passed to Supermemory", "default": _DEFAULT_ENTITY_CONTEXT},
|
||||
{"key": "api_timeout", "description": "Timeout in seconds for SDK and ingest calls", "default": str(_DEFAULT_API_TIMEOUT)},
|
||||
]
|
||||
|
||||
def save_config(self, values, hermes_home):
|
||||
@@ -446,14 +484,29 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
self._turn_count = 0
|
||||
self._config = _load_supermemory_config(self._hermes_home)
|
||||
self._api_key = os.environ.get("SUPERMEMORY_API_KEY", "")
|
||||
self._container_tag = self._config["container_tag"]
|
||||
|
||||
# Resolve container tag: env var > config > default.
|
||||
# Supports {identity} template for profile-scoped containers.
|
||||
env_tag = os.environ.get("SUPERMEMORY_CONTAINER_TAG", "").strip()
|
||||
raw_tag = env_tag or self._config["container_tag"]
|
||||
identity = kwargs.get("agent_identity", "default")
|
||||
self._container_tag = _sanitize_tag(raw_tag.replace("{identity}", identity))
|
||||
|
||||
self._auto_recall = self._config["auto_recall"]
|
||||
self._auto_capture = self._config["auto_capture"]
|
||||
self._max_recall_results = self._config["max_recall_results"]
|
||||
self._profile_frequency = self._config["profile_frequency"]
|
||||
self._capture_mode = self._config["capture_mode"]
|
||||
self._search_mode = self._config["search_mode"]
|
||||
self._entity_context = self._config["entity_context"]
|
||||
self._api_timeout = self._config["api_timeout"]
|
||||
|
||||
# Multi-container setup
|
||||
self._enable_custom_containers = self._config["enable_custom_container_tags"]
|
||||
self._custom_containers = self._config["custom_containers"]
|
||||
self._custom_container_instructions = self._config["custom_container_instructions"]
|
||||
self._allowed_containers = [self._container_tag] + list(self._custom_containers)
|
||||
|
||||
agent_context = kwargs.get("agent_context", "")
|
||||
self._write_enabled = agent_context not in ("cron", "flush", "subagent")
|
||||
self._active = bool(self._api_key)
|
||||
@@ -464,6 +517,7 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
api_key=self._api_key,
|
||||
timeout=self._api_timeout,
|
||||
container_tag=self._container_tag,
|
||||
search_mode=self._search_mode,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Supermemory initialization failed", exc_info=True)
|
||||
@@ -476,11 +530,18 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
def system_prompt_block(self) -> str:
|
||||
if not self._active:
|
||||
return ""
|
||||
return (
|
||||
"# Supermemory\n"
|
||||
f"Active. Container: {self._container_tag}.\n"
|
||||
"Use supermemory_search, supermemory_store, supermemory_forget, and supermemory_profile for explicit memory operations."
|
||||
)
|
||||
lines = [
|
||||
"# Supermemory",
|
||||
f"Active. Container: {self._container_tag}.",
|
||||
"Use supermemory_search, supermemory_store, supermemory_forget, and supermemory_profile for explicit memory operations.",
|
||||
]
|
||||
if self._enable_custom_containers and self._custom_containers:
|
||||
tags_str = ", ".join(self._allowed_containers)
|
||||
lines.append(f"\nMulti-container mode enabled. Available containers: {tags_str}.")
|
||||
lines.append("Pass an optional container_tag to supermemory_search, supermemory_store, supermemory_forget, and supermemory_profile to target a specific container.")
|
||||
if self._custom_container_instructions:
|
||||
lines.append(f"\n{self._custom_container_instructions}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
if not self._active or not self._auto_recall or not self._client or not query.strip():
|
||||
@@ -582,22 +643,62 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
thread.join(timeout=5.0)
|
||||
setattr(self, attr_name, None)
|
||||
|
||||
def _resolve_tool_container_tag(self, args: dict) -> Optional[str]:
|
||||
"""Validate and resolve container_tag from tool call args.
|
||||
|
||||
Returns None (use primary) if multi-container is disabled or no tag provided.
|
||||
Returns the validated tag if it's in the allowed list.
|
||||
Raises ValueError if the tag is not whitelisted.
|
||||
"""
|
||||
if not self._enable_custom_containers:
|
||||
return None
|
||||
tag = str(args.get("container_tag") or "").strip()
|
||||
if not tag:
|
||||
return None
|
||||
sanitized = _sanitize_tag(tag)
|
||||
if sanitized not in self._allowed_containers:
|
||||
raise ValueError(
|
||||
f"Container tag '{sanitized}' is not allowed. "
|
||||
f"Allowed: {', '.join(self._allowed_containers)}"
|
||||
)
|
||||
return sanitized
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [STORE_SCHEMA, SEARCH_SCHEMA, FORGET_SCHEMA, PROFILE_SCHEMA]
|
||||
if not self._enable_custom_containers:
|
||||
return [STORE_SCHEMA, SEARCH_SCHEMA, FORGET_SCHEMA, PROFILE_SCHEMA]
|
||||
|
||||
# When multi-container is enabled, add optional container_tag to relevant tools
|
||||
container_param = {
|
||||
"type": "string",
|
||||
"description": f"Optional container tag. Allowed: {', '.join(self._allowed_containers)}. Defaults to primary ({self._container_tag}).",
|
||||
}
|
||||
schemas = []
|
||||
for base in [STORE_SCHEMA, SEARCH_SCHEMA, FORGET_SCHEMA, PROFILE_SCHEMA]:
|
||||
schema = json.loads(json.dumps(base)) # deep copy
|
||||
schema["parameters"]["properties"]["container_tag"] = container_param
|
||||
schemas.append(schema)
|
||||
return schemas
|
||||
|
||||
def _tool_store(self, args: dict) -> str:
|
||||
content = str(args.get("content") or "").strip()
|
||||
if not content:
|
||||
return tool_error("content is required")
|
||||
try:
|
||||
tag = self._resolve_tool_container_tag(args)
|
||||
except ValueError as exc:
|
||||
return tool_error(str(exc))
|
||||
metadata = args.get("metadata") or {}
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
metadata.setdefault("type", _detect_category(content))
|
||||
metadata["source"] = "hermes_tool"
|
||||
try:
|
||||
result = self._client.add_memory(content, metadata=metadata, entity_context=self._entity_context)
|
||||
result = self._client.add_memory(content, metadata=metadata, entity_context=self._entity_context, container_tag=tag)
|
||||
preview = content[:80] + ("..." if len(content) > 80 else "")
|
||||
return json.dumps({"saved": True, "id": result.get("id", ""), "preview": preview})
|
||||
resp: dict[str, Any] = {"saved": True, "id": result.get("id", ""), "preview": preview}
|
||||
if tag:
|
||||
resp["container_tag"] = tag
|
||||
return json.dumps(resp)
|
||||
except Exception as exc:
|
||||
return tool_error(f"Failed to store memory: {exc}")
|
||||
|
||||
@@ -605,22 +706,29 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
query = str(args.get("query") or "").strip()
|
||||
if not query:
|
||||
return tool_error("query is required")
|
||||
try:
|
||||
tag = self._resolve_tool_container_tag(args)
|
||||
except ValueError as exc:
|
||||
return tool_error(str(exc))
|
||||
try:
|
||||
limit = max(1, min(20, int(args.get("limit", 5) or 5)))
|
||||
except Exception:
|
||||
limit = 5
|
||||
try:
|
||||
results = self._client.search_memories(query, limit=limit)
|
||||
results = self._client.search_memories(query, limit=limit, container_tag=tag)
|
||||
formatted = []
|
||||
for item in results:
|
||||
entry = {"id": item.get("id", ""), "content": item.get("memory", "")}
|
||||
entry: dict[str, Any] = {"id": item.get("id", ""), "content": item.get("memory", "")}
|
||||
if item.get("similarity") is not None:
|
||||
try:
|
||||
entry["similarity"] = round(float(item["similarity"]) * 100)
|
||||
except Exception:
|
||||
pass
|
||||
formatted.append(entry)
|
||||
return json.dumps({"results": formatted, "count": len(formatted)})
|
||||
resp: dict[str, Any] = {"results": formatted, "count": len(formatted)}
|
||||
if tag:
|
||||
resp["container_tag"] = tag
|
||||
return json.dumps(resp)
|
||||
except Exception as exc:
|
||||
return tool_error(f"Search failed: {exc}")
|
||||
|
||||
@@ -629,28 +737,39 @@ class SupermemoryMemoryProvider(MemoryProvider):
|
||||
query = str(args.get("query") or "").strip()
|
||||
if not memory_id and not query:
|
||||
return tool_error("Provide either id or query")
|
||||
try:
|
||||
tag = self._resolve_tool_container_tag(args)
|
||||
except ValueError as exc:
|
||||
return tool_error(str(exc))
|
||||
try:
|
||||
if memory_id:
|
||||
self._client.forget_memory(memory_id)
|
||||
self._client.forget_memory(memory_id, container_tag=tag)
|
||||
return json.dumps({"forgotten": True, "id": memory_id})
|
||||
return json.dumps(self._client.forget_by_query(query))
|
||||
return json.dumps(self._client.forget_by_query(query, container_tag=tag))
|
||||
except Exception as exc:
|
||||
return tool_error(f"Forget failed: {exc}")
|
||||
|
||||
def _tool_profile(self, args: dict) -> str:
|
||||
query = str(args.get("query") or "").strip() or None
|
||||
try:
|
||||
profile = self._client.get_profile(query=query)
|
||||
tag = self._resolve_tool_container_tag(args)
|
||||
except ValueError as exc:
|
||||
return tool_error(str(exc))
|
||||
try:
|
||||
profile = self._client.get_profile(query=query, container_tag=tag)
|
||||
sections = []
|
||||
if profile["static"]:
|
||||
sections.append("## User Profile (Persistent)\n" + "\n".join(f"- {item}" for item in profile["static"]))
|
||||
if profile["dynamic"]:
|
||||
sections.append("## Recent Context\n" + "\n".join(f"- {item}" for item in profile["dynamic"]))
|
||||
return json.dumps({
|
||||
resp: dict[str, Any] = {
|
||||
"profile": "\n\n".join(sections),
|
||||
"static_count": len(profile["static"]),
|
||||
"dynamic_count": len(profile["dynamic"]),
|
||||
})
|
||||
}
|
||||
if tag:
|
||||
resp["container_tag"] = tag
|
||||
return json.dumps(resp)
|
||||
except Exception as exc:
|
||||
return tool_error(f"Profile failed: {exc}")
|
||||
|
||||
|
||||
+24
-14
@@ -7391,20 +7391,30 @@ class AIAgent:
|
||||
response_invalid = True
|
||||
error_details.append("response.output is not a list")
|
||||
elif not output_items:
|
||||
# If we reach here, _run_codex_stream's backfill
|
||||
# from output_item.done events and text-delta
|
||||
# synthesis both failed to populate output.
|
||||
_resp_status = getattr(response, "status", None)
|
||||
_resp_incomplete = getattr(response, "incomplete_details", None)
|
||||
logging.warning(
|
||||
"Codex response.output is empty after stream backfill "
|
||||
"(status=%s, incomplete_details=%s, model=%s). %s",
|
||||
_resp_status, _resp_incomplete,
|
||||
getattr(response, "model", None),
|
||||
f"api_mode={self.api_mode} provider={self.provider}",
|
||||
)
|
||||
response_invalid = True
|
||||
error_details.append("response.output is empty")
|
||||
# Stream backfill may have failed, but
|
||||
# _normalize_codex_response can still recover
|
||||
# from response.output_text. Only mark invalid
|
||||
# when that fallback is also absent.
|
||||
_out_text = getattr(response, "output_text", None)
|
||||
_out_text_stripped = _out_text.strip() if isinstance(_out_text, str) else ""
|
||||
if _out_text_stripped:
|
||||
logger.debug(
|
||||
"Codex response.output is empty but output_text is present "
|
||||
"(%d chars); deferring to normalization.",
|
||||
len(_out_text_stripped),
|
||||
)
|
||||
else:
|
||||
_resp_status = getattr(response, "status", None)
|
||||
_resp_incomplete = getattr(response, "incomplete_details", None)
|
||||
logger.warning(
|
||||
"Codex response.output is empty after stream backfill "
|
||||
"(status=%s, incomplete_details=%s, model=%s). %s",
|
||||
_resp_status, _resp_incomplete,
|
||||
getattr(response, "model", None),
|
||||
f"api_mode={self.api_mode} provider={self.provider}",
|
||||
)
|
||||
response_invalid = True
|
||||
error_details.append("response.output is empty")
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
content_blocks = getattr(response, "content", None) if response is not None else None
|
||||
if response is None:
|
||||
|
||||
@@ -13,7 +13,7 @@ from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
|
||||
def _run_auxiliary_bridge(config_dict, monkeypatch):
|
||||
@@ -199,7 +199,7 @@ class TestGatewayBridgeCodeParity:
|
||||
|
||||
def test_gateway_has_auxiliary_bridge(self):
|
||||
"""The gateway config bridge must include auxiliary.* bridging."""
|
||||
gateway_path = Path(__file__).parent.parent / "gateway" / "run.py"
|
||||
gateway_path = Path(__file__).parent.parent.parent / "gateway" / "run.py"
|
||||
content = gateway_path.read_text()
|
||||
# Check for key patterns that indicate the bridge is present
|
||||
assert "AUXILIARY_VISION_PROVIDER" in content
|
||||
@@ -213,7 +213,7 @@ class TestGatewayBridgeCodeParity:
|
||||
|
||||
def test_gateway_no_compression_env_bridge(self):
|
||||
"""Gateway should NOT bridge compression config to env vars (config-only)."""
|
||||
gateway_path = Path(__file__).parent.parent / "gateway" / "run.py"
|
||||
gateway_path = Path(__file__).parent.parent.parent / "gateway" / "run.py"
|
||||
content = gateway_path.read_text()
|
||||
assert "CONTEXT_COMPRESSION_PROVIDER" not in content
|
||||
assert "CONTEXT_COMPRESSION_MODEL" not in content
|
||||
@@ -330,7 +330,7 @@ def test_model_flow_nous_prints_subscription_guidance_without_mutating_explicit_
|
||||
"hermes_cli.auth.fetch_nous_models",
|
||||
lambda *args, **kwargs: ["claude-opus-4-6"],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None: "claude-opus-4-6")
|
||||
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None, **kw: "claude-opus-4-6")
|
||||
monkeypatch.setattr("hermes_cli.auth._save_model_choice", lambda model: None)
|
||||
monkeypatch.setattr("hermes_cli.auth._update_config_for_provider", lambda provider, url: None)
|
||||
monkeypatch.setattr(
|
||||
@@ -368,7 +368,7 @@ def test_model_flow_nous_applies_managed_tts_default_when_unconfigured(monkeypat
|
||||
"hermes_cli.auth.fetch_nous_models",
|
||||
lambda *args, **kwargs: ["claude-opus-4-6"],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None: "claude-opus-4-6")
|
||||
monkeypatch.setattr("hermes_cli.auth._prompt_model_selection", lambda model_ids, current_model="", pricing=None, **kw: "claude-opus-4-6")
|
||||
monkeypatch.setattr("hermes_cli.auth._save_model_choice", lambda model: None)
|
||||
monkeypatch.setattr("hermes_cli.auth._update_config_for_provider", lambda provider, url: None)
|
||||
monkeypatch.setattr(
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Regression tests for CLI /retry history replacement semantics."""
|
||||
|
||||
from tests.test_cli_init import _make_cli
|
||||
from tests.cli.test_cli_init import _make_cli
|
||||
|
||||
|
||||
def test_retry_last_truncates_history_before_requeueing_message():
|
||||
@@ -0,0 +1,164 @@
|
||||
"""Security tests for Terminal-Bench 2 archive extraction."""
|
||||
|
||||
import base64
|
||||
import importlib
|
||||
import io
|
||||
import sys
|
||||
import tarfile
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = types.ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
def _load_terminalbench_module(monkeypatch):
|
||||
class _EvalHandlingEnum:
|
||||
STOP_TRAIN = "stop_train"
|
||||
|
||||
class _APIServerConfig:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
class _AgentResult:
|
||||
pass
|
||||
|
||||
class _HermesAgentLoop:
|
||||
pass
|
||||
|
||||
class _HermesAgentBaseEnv:
|
||||
pass
|
||||
|
||||
class _HermesAgentEnvConfig:
|
||||
pass
|
||||
|
||||
class _ToolContext:
|
||||
pass
|
||||
|
||||
stub_modules = {
|
||||
"atroposlib": _stub_module("atroposlib"),
|
||||
"atroposlib.envs": _stub_module("atroposlib.envs"),
|
||||
"atroposlib.envs.base": _stub_module(
|
||||
"atroposlib.envs.base",
|
||||
EvalHandlingEnum=_EvalHandlingEnum,
|
||||
),
|
||||
"atroposlib.envs.server_handling": _stub_module("atroposlib.envs.server_handling"),
|
||||
"atroposlib.envs.server_handling.server_manager": _stub_module(
|
||||
"atroposlib.envs.server_handling.server_manager",
|
||||
APIServerConfig=_APIServerConfig,
|
||||
),
|
||||
"environments.agent_loop": _stub_module(
|
||||
"environments.agent_loop",
|
||||
AgentResult=_AgentResult,
|
||||
HermesAgentLoop=_HermesAgentLoop,
|
||||
),
|
||||
"environments.hermes_base_env": _stub_module(
|
||||
"environments.hermes_base_env",
|
||||
HermesAgentBaseEnv=_HermesAgentBaseEnv,
|
||||
HermesAgentEnvConfig=_HermesAgentEnvConfig,
|
||||
),
|
||||
"environments.tool_context": _stub_module(
|
||||
"environments.tool_context",
|
||||
ToolContext=_ToolContext,
|
||||
),
|
||||
"tools.terminal_tool": _stub_module(
|
||||
"tools.terminal_tool",
|
||||
register_task_env_overrides=lambda *args, **kwargs: None,
|
||||
clear_task_env_overrides=lambda *args, **kwargs: None,
|
||||
cleanup_vm=lambda *args, **kwargs: None,
|
||||
),
|
||||
}
|
||||
|
||||
stub_modules["atroposlib"].envs = stub_modules["atroposlib.envs"]
|
||||
stub_modules["atroposlib.envs"].base = stub_modules["atroposlib.envs.base"]
|
||||
stub_modules["atroposlib.envs"].server_handling = stub_modules["atroposlib.envs.server_handling"]
|
||||
stub_modules["atroposlib.envs.server_handling"].server_manager = stub_modules[
|
||||
"atroposlib.envs.server_handling.server_manager"
|
||||
]
|
||||
|
||||
for name, module in stub_modules.items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
module_name = "environments.benchmarks.terminalbench_2.terminalbench2_env"
|
||||
sys.modules.pop(module_name, None)
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
def _build_tar_b64(entries):
|
||||
buf = io.BytesIO()
|
||||
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||
for entry in entries:
|
||||
kind = entry["kind"]
|
||||
info = tarfile.TarInfo(entry["name"])
|
||||
|
||||
if kind == "dir":
|
||||
info.type = tarfile.DIRTYPE
|
||||
tar.addfile(info)
|
||||
continue
|
||||
|
||||
if kind == "file":
|
||||
data = entry["data"].encode("utf-8")
|
||||
info.size = len(data)
|
||||
tar.addfile(info, io.BytesIO(data))
|
||||
continue
|
||||
|
||||
if kind == "symlink":
|
||||
info.type = tarfile.SYMTYPE
|
||||
info.linkname = entry["target"]
|
||||
tar.addfile(info)
|
||||
continue
|
||||
|
||||
raise ValueError(f"Unknown tar entry kind: {kind}")
|
||||
|
||||
return base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
|
||||
def test_extract_base64_tar_allows_safe_files(tmp_path, monkeypatch):
|
||||
module = _load_terminalbench_module(monkeypatch)
|
||||
archive = _build_tar_b64(
|
||||
[
|
||||
{"kind": "dir", "name": "nested"},
|
||||
{"kind": "file", "name": "nested/hello.txt", "data": "hello"},
|
||||
]
|
||||
)
|
||||
|
||||
target = tmp_path / "extract"
|
||||
module._extract_base64_tar(archive, target)
|
||||
|
||||
assert (target / "nested" / "hello.txt").read_text(encoding="utf-8") == "hello"
|
||||
|
||||
|
||||
def test_extract_base64_tar_rejects_path_traversal(tmp_path, monkeypatch):
|
||||
module = _load_terminalbench_module(monkeypatch)
|
||||
archive = _build_tar_b64(
|
||||
[
|
||||
{"kind": "file", "name": "../escape.txt", "data": "owned"},
|
||||
]
|
||||
)
|
||||
|
||||
target = tmp_path / "extract"
|
||||
with pytest.raises(ValueError, match="Unsafe archive member path"):
|
||||
module._extract_base64_tar(archive, target)
|
||||
|
||||
assert not (tmp_path / "escape.txt").exists()
|
||||
|
||||
|
||||
def test_extract_base64_tar_rejects_symlinks(tmp_path, monkeypatch):
|
||||
module = _load_terminalbench_module(monkeypatch)
|
||||
archive = _build_tar_b64(
|
||||
[
|
||||
{"kind": "symlink", "name": "link", "target": "../../escape.txt"},
|
||||
]
|
||||
)
|
||||
|
||||
target = tmp_path / "extract"
|
||||
with pytest.raises(ValueError, match="Unsupported archive member type"):
|
||||
module._extract_base64_tar(archive, target)
|
||||
|
||||
assert not (target / "link").exists()
|
||||
@@ -504,7 +504,8 @@ class TestMattermostFileUpload:
|
||||
self.adapter._session = MagicMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_downloads_and_uploads(self):
|
||||
@patch("tools.url_safety.is_safe_url", return_value=True)
|
||||
async def test_send_image_downloads_and_uploads(self, _mock_safe):
|
||||
"""send_image should download the URL, upload via /api/v4/files, then post."""
|
||||
# Mock the download (GET)
|
||||
mock_dl_resp = AsyncMock()
|
||||
|
||||
@@ -596,10 +596,11 @@ def _make_aiohttp_resp(status: int, content: bytes = b"file bytes",
|
||||
return resp
|
||||
|
||||
|
||||
@patch("tools.url_safety.is_safe_url", return_value=True)
|
||||
class TestMattermostSendUrlAsFile:
|
||||
"""Tests for MattermostAdapter._send_url_as_file"""
|
||||
|
||||
def test_success_on_first_attempt(self):
|
||||
def test_success_on_first_attempt(self, _mock_safe):
|
||||
"""200 on first attempt → file uploaded and post created."""
|
||||
adapter = _make_mm_adapter()
|
||||
resp = _make_aiohttp_resp(200)
|
||||
@@ -616,7 +617,7 @@ class TestMattermostSendUrlAsFile:
|
||||
adapter._upload_file.assert_called_once()
|
||||
adapter._api_post.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self):
|
||||
def test_retries_on_429_then_succeeds(self, _mock_safe):
|
||||
"""429 on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
@@ -637,7 +638,7 @@ class TestMattermostSendUrlAsFile:
|
||||
assert adapter._session.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_500_then_succeeds(self):
|
||||
def test_retries_on_500_then_succeeds(self, _mock_safe):
|
||||
"""5xx on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
@@ -655,7 +656,7 @@ class TestMattermostSendUrlAsFile:
|
||||
assert result.success
|
||||
assert adapter._session.get.call_count == 2
|
||||
|
||||
def test_falls_back_to_text_after_max_retries_on_5xx(self):
|
||||
def test_falls_back_to_text_after_max_retries_on_5xx(self, _mock_safe):
|
||||
"""Three consecutive 500s exhaust retries; falls back to send() with URL text."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
@@ -674,7 +675,7 @@ class TestMattermostSendUrlAsFile:
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_falls_back_on_client_error(self):
|
||||
def test_falls_back_on_client_error(self, _mock_safe):
|
||||
"""aiohttp.ClientError on every attempt falls back to send() with URL."""
|
||||
import aiohttp
|
||||
|
||||
@@ -699,7 +700,7 @@ class TestMattermostSendUrlAsFile:
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_non_retryable_404_falls_back_immediately(self):
|
||||
def test_non_retryable_404_falls_back_immediately(self, _mock_safe):
|
||||
"""404 is non-retryable (< 500, != 429); send() is called right away."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
|
||||
@@ -71,6 +71,24 @@ class FakeAgent:
|
||||
}
|
||||
|
||||
|
||||
class LongPreviewAgent:
|
||||
"""Agent that emits a tool call with a very long preview string."""
|
||||
LONG_CMD = "cd /home/teknium/.hermes/hermes-agent/.worktrees/hermes-d8860339 && source .venv/bin/activate && python -m pytest tests/gateway/test_run_progress_topics.py -n0 -q"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.tool_progress_callback("tool.started", "terminal", self.LONG_CMD, {})
|
||||
time.sleep(0.35)
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
def _make_runner(adapter):
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
GatewayRunner = gateway_run.GatewayRunner
|
||||
@@ -217,3 +235,102 @@ async def test_run_agent_progress_uses_event_message_id_for_slack_dm(monkeypatch
|
||||
assert adapter.sent
|
||||
assert adapter.sent[0]["metadata"] == {"thread_id": "1234567890.000001"}
|
||||
assert all(call["metadata"] == {"thread_id": "1234567890.000001"} for call in adapter.typing)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview truncation tests (all/new mode respects tool_preview_length)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_long_preview_helper(monkeypatch, tmp_path, preview_length=0):
|
||||
"""Shared setup for long-preview truncation tests.
|
||||
|
||||
Returns (adapter, result) after running the agent with LongPreviewAgent.
|
||||
``preview_length`` controls display.tool_preview_length in the config file
|
||||
that _run_agent reads — so the gateway picks it up the same way production does.
|
||||
"""
|
||||
import asyncio
|
||||
import yaml
|
||||
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = LongPreviewAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
# Write config.yaml so _run_agent picks up tool_preview_length
|
||||
config = {"display": {"tool_preview_length": preview_length}}
|
||||
(tmp_path / "config.yaml").write_text(yaml.dump(config), encoding="utf-8")
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_type="dm",
|
||||
thread_id=None,
|
||||
)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-trunc",
|
||||
session_key="agent:main:telegram:dm:12345",
|
||||
)
|
||||
)
|
||||
return adapter, result
|
||||
|
||||
|
||||
def test_all_mode_default_truncation_40_chars(monkeypatch, tmp_path):
|
||||
"""When tool_preview_length is 0 (default), all/new mode truncates to 40 chars."""
|
||||
adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=0)
|
||||
assert result["final_response"] == "done"
|
||||
assert adapter.sent
|
||||
content = adapter.sent[0]["content"]
|
||||
# The long command should be truncated — total preview <= 40 chars
|
||||
assert "..." in content
|
||||
# Extract the preview part between quotes
|
||||
import re
|
||||
match = re.search(r'"(.+)"', content)
|
||||
assert match, f"No quoted preview found in: {content}"
|
||||
preview_text = match.group(1)
|
||||
assert len(preview_text) <= 40, f"Preview too long ({len(preview_text)}): {preview_text}"
|
||||
|
||||
|
||||
def test_all_mode_respects_custom_preview_length(monkeypatch, tmp_path):
|
||||
"""When tool_preview_length is explicitly set (e.g. 120), all/new mode uses that."""
|
||||
adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=120)
|
||||
assert result["final_response"] == "done"
|
||||
assert adapter.sent
|
||||
content = adapter.sent[0]["content"]
|
||||
# With 120-char cap, the command (165 chars) should still be truncated but longer
|
||||
import re
|
||||
match = re.search(r'"(.+)"', content)
|
||||
assert match, f"No quoted preview found in: {content}"
|
||||
preview_text = match.group(1)
|
||||
# Should be longer than the 40-char default
|
||||
assert len(preview_text) > 40, f"Preview suspiciously short ({len(preview_text)}): {preview_text}"
|
||||
# But still capped at 120
|
||||
assert len(preview_text) <= 120, f"Preview too long ({len(preview_text)}): {preview_text}"
|
||||
|
||||
|
||||
def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path):
|
||||
"""Short previews (under the cap) are not truncated."""
|
||||
# Set a generous cap — the LongPreviewAgent's command is ~165 chars
|
||||
adapter, result = _run_long_preview_helper(monkeypatch, tmp_path, preview_length=200)
|
||||
assert result["final_response"] == "done"
|
||||
assert adapter.sent
|
||||
content = adapter.sent[0]["content"]
|
||||
# With a 200-char cap, the 165-char command should NOT be truncated
|
||||
assert "..." not in content, f"Preview was truncated when it shouldn't be: {content}"
|
||||
|
||||
@@ -51,7 +51,8 @@ def _make_runner(session_entry: SessionEntry):
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session_title.return_value = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
@@ -82,12 +83,34 @@ async def test_status_command_reports_running_agent_without_interrupt(monkeypatc
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
assert "**Session ID:** `sess-1`" in result
|
||||
assert "**Tokens:** 321" in result
|
||||
assert "**Agent Running:** Yes ⚡" in result
|
||||
assert "**Title:**" not in result
|
||||
running_agent.interrupt.assert_not_called()
|
||||
assert runner._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_command_includes_session_title_when_present():
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
total_tokens=321,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session_title.return_value = "My titled session"
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
assert "**Session ID:** `sess-1`" in result
|
||||
assert "**Title:** My titled session" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
@@ -33,8 +33,15 @@ def _ensure_telegram_mock():
|
||||
mod.constants.ChatType.GROUP = "group"
|
||||
mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
mod.constants.ChatType.CHANNEL = "channel"
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request", "telegram.error"):
|
||||
# Provide real exception classes so ``except (NetworkError, ...)`` in
|
||||
# connect() doesn't blow up under xdist when this mock leaks.
|
||||
mod.error.NetworkError = type("NetworkError", (OSError,), {})
|
||||
mod.error.TimedOut = type("TimedOut", (OSError,), {})
|
||||
mod.error.BadRequest = type("BadRequest", (Exception,), {})
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, mod)
|
||||
sys.modules.setdefault("telegram.error", mod.error)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Tests for TelegramPlatform._merge_caption caption deduplication logic."""
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
merge = TelegramAdapter._merge_caption
|
||||
|
||||
|
||||
class TestMergeCaptionBasic:
|
||||
def test_no_existing_text(self):
|
||||
assert merge(None, "Hello") == "Hello"
|
||||
|
||||
def test_empty_existing_text(self):
|
||||
assert merge("", "Hello") == "Hello"
|
||||
|
||||
def test_exact_duplicate_dropped(self):
|
||||
assert merge("Revenue", "Revenue") == "Revenue"
|
||||
|
||||
def test_different_captions_merged(self):
|
||||
result = merge("Q3 Results", "Q4 Projections")
|
||||
assert result == "Q3 Results\n\nQ4 Projections"
|
||||
|
||||
|
||||
class TestMergeCaptionSubstringBug:
|
||||
"""These are the exact scenarios that the old substring check got wrong."""
|
||||
|
||||
def test_shorter_caption_not_dropped_when_substring(self):
|
||||
# Bug: "Meeting" in "Meeting agenda" → True → caption was silently lost
|
||||
result = merge("Meeting agenda", "Meeting")
|
||||
assert result == "Meeting agenda\n\nMeeting"
|
||||
|
||||
def test_longer_caption_not_dropped_when_contains_existing(self):
|
||||
# "Revenue and Profit" contains "Revenue", but they are different captions
|
||||
result = merge("Revenue", "Revenue and Profit")
|
||||
assert result == "Revenue\n\nRevenue and Profit"
|
||||
|
||||
def test_prefix_caption_not_dropped(self):
|
||||
result = merge("Q3 Results - Revenue", "Q3 Results")
|
||||
assert result == "Q3 Results - Revenue\n\nQ3 Results"
|
||||
|
||||
|
||||
class TestMergeCaptionWhitespace:
|
||||
def test_trailing_space_treated_as_duplicate(self):
|
||||
assert merge("Revenue", "Revenue ") == "Revenue"
|
||||
|
||||
def test_leading_space_treated_as_duplicate(self):
|
||||
assert merge("Revenue", " Revenue") == "Revenue"
|
||||
|
||||
def test_whitespace_only_new_text_not_added(self):
|
||||
# strip() makes it empty string → falsy check in callers guards this,
|
||||
# but _merge_caption itself: strip matches "" which is not in list → would merge.
|
||||
# Callers already guard with `if event.text:` so this is an edge case.
|
||||
result = merge("Revenue", " ")
|
||||
# " ".strip() == "" → not in ["Revenue"] → gets merged (caller guards prevent this)
|
||||
assert "\n\n" in result or result == "Revenue"
|
||||
|
||||
|
||||
class TestMergeCaptionMultipleItems:
|
||||
def test_three_unique_captions_all_present(self):
|
||||
text = merge(None, "A")
|
||||
text = merge(text, "B")
|
||||
text = merge(text, "C")
|
||||
assert text == "A\n\nB\n\nC"
|
||||
|
||||
def test_duplicate_in_middle_dropped(self):
|
||||
text = merge(None, "A")
|
||||
text = merge(text, "B")
|
||||
text = merge(text, "A") # duplicate
|
||||
assert text == "A\n\nB"
|
||||
|
||||
def test_album_scenario_revenue_profit(self):
|
||||
# Album Item 1: "Revenue and Profit", Item 2: "Revenue"
|
||||
# Old bug: "Revenue" in ["Revenue and Profit"] → True → lost
|
||||
text = merge(None, "Revenue and Profit")
|
||||
text = merge(text, "Revenue")
|
||||
assert text == "Revenue and Profit\n\nRevenue"
|
||||
@@ -20,8 +20,16 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
# Provide real exception classes so ``except (NetworkError, ...)`` in
|
||||
# connect() doesn't blow up with "catching classes that do not inherit
|
||||
# from BaseException" when another xdist worker pollutes sys.modules.
|
||||
telegram_mod.error.NetworkError = type("NetworkError", (OSError,), {})
|
||||
telegram_mod.error.TimedOut = type("TimedOut", (OSError,), {})
|
||||
telegram_mod.error.BadRequest = type("BadRequest", (Exception,), {})
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
sys.modules.setdefault("telegram.error", telegram_mod.error)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
@@ -590,8 +590,15 @@ class TestSessionIsolation:
|
||||
class TestDeliveryCleanup:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delivery_info_cleaned_after_send(self):
|
||||
"""send() pops delivery_info so the entry doesn't leak memory."""
|
||||
async def test_delivery_info_survives_multiple_sends(self):
|
||||
"""send() must NOT pop delivery_info.
|
||||
|
||||
Interim status messages (fallback notifications, context-pressure
|
||||
warnings, etc.) flow through the same send() path as the final
|
||||
response. If the entry were popped on the first send, the final
|
||||
response would silently downgrade to the ``log`` deliver type.
|
||||
Regression test for that bug.
|
||||
"""
|
||||
adapter = _make_adapter()
|
||||
chat_id = "webhook:test:d-xyz"
|
||||
adapter._delivery_info[chat_id] = {
|
||||
@@ -599,10 +606,40 @@ class TestDeliveryCleanup:
|
||||
"deliver_extra": {},
|
||||
"payload": {"x": 1},
|
||||
}
|
||||
adapter._delivery_info_created[chat_id] = time.time()
|
||||
|
||||
result = await adapter.send(chat_id, "Agent response here")
|
||||
assert result.success is True
|
||||
assert chat_id not in adapter._delivery_info
|
||||
# First send (e.g. an interim status message)
|
||||
result1 = await adapter.send(chat_id, "Status: switching to fallback")
|
||||
assert result1.success is True
|
||||
# Entry must still be present so the final send can read it
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
# Second send (the final agent response)
|
||||
result2 = await adapter.send(chat_id, "Final agent response")
|
||||
assert result2.success is True
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delivery_info_pruned_via_ttl(self):
|
||||
"""Stale delivery_info entries are dropped on the next POST."""
|
||||
adapter = _make_adapter()
|
||||
adapter._idempotency_ttl = 60 # short TTL for the test
|
||||
now = time.time()
|
||||
|
||||
# Stale entry — older than TTL
|
||||
adapter._delivery_info["webhook:test:old"] = {"deliver": "log"}
|
||||
adapter._delivery_info_created["webhook:test:old"] = now - 120
|
||||
|
||||
# Fresh entry — should survive
|
||||
adapter._delivery_info["webhook:test:new"] = {"deliver": "log"}
|
||||
adapter._delivery_info_created["webhook:test:new"] = now - 5
|
||||
|
||||
adapter._prune_delivery_info(now)
|
||||
|
||||
assert "webhook:test:old" not in adapter._delivery_info
|
||||
assert "webhook:test:old" not in adapter._delivery_info_created
|
||||
assert "webhook:test:new" in adapter._delivery_info
|
||||
assert "webhook:test:new" in adapter._delivery_info_created
|
||||
|
||||
|
||||
# ===================================================================
|
||||
|
||||
@@ -259,8 +259,9 @@ class TestCrossPlatformDelivery:
|
||||
mock_tg_adapter.send.assert_awaited_once_with(
|
||||
"12345", "I've acknowledged the alert.", metadata=None
|
||||
)
|
||||
# Delivery info should be cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
# Delivery info is retained after send() so interim status messages
|
||||
# don't strand the final response (TTL-based cleanup happens on POST).
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
|
||||
# ===================================================================
|
||||
@@ -333,5 +334,6 @@ class TestGitHubCommentDelivery:
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
# Delivery info cleaned up
|
||||
assert chat_id not in adapter._delivery_info
|
||||
# Delivery info is retained after send() so interim status messages
|
||||
# don't strand the final response (TTL-based cleanup happens on POST).
|
||||
assert chat_id in adapter._delivery_info
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_version_string_no_v_prefix():
|
||||
assert not __version__.startswith("v"), f"__version__ should not start with 'v', got {__version__!r}"
|
||||
|
||||
|
||||
def test_check_for_updates_uses_cache(tmp_path):
|
||||
def test_check_for_updates_uses_cache(tmp_path, monkeypatch):
|
||||
"""When cache is fresh, check_for_updates should return cached value without calling git."""
|
||||
from hermes_cli.banner import check_for_updates
|
||||
|
||||
@@ -27,15 +27,15 @@ def test_check_for_updates_uses_cache(tmp_path):
|
||||
cache_file = tmp_path / ".update_check"
|
||||
cache_file.write_text(json.dumps({"ts": time.time(), "behind": 3}))
|
||||
|
||||
with patch("hermes_cli.banner.os.getenv", return_value=str(tmp_path)):
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
result = check_for_updates()
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
result = check_for_updates()
|
||||
|
||||
assert result == 3
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
def test_check_for_updates_expired_cache(tmp_path):
|
||||
def test_check_for_updates_expired_cache(tmp_path, monkeypatch):
|
||||
"""When cache is expired, check_for_updates should call git fetch."""
|
||||
from hermes_cli.banner import check_for_updates
|
||||
|
||||
@@ -49,15 +49,15 @@ def test_check_for_updates_expired_cache(tmp_path):
|
||||
|
||||
mock_result = MagicMock(returncode=0, stdout="5\n")
|
||||
|
||||
with patch("hermes_cli.banner.os.getenv", return_value=str(tmp_path)):
|
||||
with patch("hermes_cli.banner.subprocess.run", return_value=mock_result) as mock_run:
|
||||
result = check_for_updates()
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
with patch("hermes_cli.banner.subprocess.run", return_value=mock_result) as mock_run:
|
||||
result = check_for_updates()
|
||||
|
||||
assert result == 5
|
||||
assert mock_run.call_count == 2 # git fetch + git rev-list
|
||||
|
||||
|
||||
def test_check_for_updates_no_git_dir(tmp_path):
|
||||
def test_check_for_updates_no_git_dir(tmp_path, monkeypatch):
|
||||
"""Returns None when .git directory doesn't exist anywhere."""
|
||||
import hermes_cli.banner as banner
|
||||
|
||||
@@ -66,19 +66,15 @@ def test_check_for_updates_no_git_dir(tmp_path):
|
||||
fake_banner.parent.mkdir(parents=True, exist_ok=True)
|
||||
fake_banner.touch()
|
||||
|
||||
original = banner.__file__
|
||||
try:
|
||||
banner.__file__ = str(fake_banner)
|
||||
with patch("hermes_cli.banner.os.getenv", return_value=str(tmp_path)):
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
result = banner.check_for_updates()
|
||||
assert result is None
|
||||
mock_run.assert_not_called()
|
||||
finally:
|
||||
banner.__file__ = original
|
||||
monkeypatch.setattr(banner, "__file__", str(fake_banner))
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
result = banner.check_for_updates()
|
||||
assert result is None
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
def test_check_for_updates_fallback_to_project_root():
|
||||
def test_check_for_updates_fallback_to_project_root(tmp_path, monkeypatch):
|
||||
"""Dev install: falls back to Path(__file__).parent.parent when HERMES_HOME has no git repo."""
|
||||
import hermes_cli.banner as banner
|
||||
|
||||
@@ -87,14 +83,12 @@ def test_check_for_updates_fallback_to_project_root():
|
||||
pytest.skip("Not running from a git checkout")
|
||||
|
||||
# Point HERMES_HOME at a temp dir with no hermes-agent/.git
|
||||
import tempfile
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
with patch("hermes_cli.banner.os.getenv", return_value=td):
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="0\n")
|
||||
result = banner.check_for_updates()
|
||||
# Should have fallen back to project root and run git commands
|
||||
assert mock_run.call_count >= 1
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="0\n")
|
||||
result = banner.check_for_updates()
|
||||
# Should have fallen back to project root and run git commands
|
||||
assert mock_run.call_count >= 1
|
||||
|
||||
|
||||
def test_prefetch_non_blocking():
|
||||
|
||||
@@ -13,10 +13,11 @@ from plugins.memory.supermemory import (
|
||||
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, api_key: str, timeout: float, container_tag: str):
|
||||
def __init__(self, api_key: str, timeout: float, container_tag: str, search_mode: str = "hybrid"):
|
||||
self.api_key = api_key
|
||||
self.timeout = timeout
|
||||
self.container_tag = container_tag
|
||||
self.search_mode = search_mode
|
||||
self.add_calls = []
|
||||
self.search_results = []
|
||||
self.profile_response = {"static": [], "dynamic": [], "search_results": []}
|
||||
@@ -24,24 +25,27 @@ class FakeClient:
|
||||
self.forgotten_ids = []
|
||||
self.forget_by_query_response = {"success": True, "message": "Forgot"}
|
||||
|
||||
def add_memory(self, content, metadata=None, *, entity_context=""):
|
||||
def add_memory(self, content, metadata=None, *, entity_context="",
|
||||
container_tag=None, custom_id=None):
|
||||
self.add_calls.append({
|
||||
"content": content,
|
||||
"metadata": metadata,
|
||||
"entity_context": entity_context,
|
||||
"container_tag": container_tag,
|
||||
"custom_id": custom_id,
|
||||
})
|
||||
return {"id": "mem_123"}
|
||||
|
||||
def search_memories(self, query, *, limit=5):
|
||||
def search_memories(self, query, *, limit=5, container_tag=None, search_mode=None):
|
||||
return self.search_results
|
||||
|
||||
def get_profile(self, query=None):
|
||||
def get_profile(self, query=None, *, container_tag=None):
|
||||
return self.profile_response
|
||||
|
||||
def forget_memory(self, memory_id):
|
||||
def forget_memory(self, memory_id, *, container_tag=None):
|
||||
self.forgotten_ids.append(memory_id)
|
||||
|
||||
def forget_by_query(self, query):
|
||||
def forget_by_query(self, query, *, container_tag=None):
|
||||
return self.forget_by_query_response
|
||||
|
||||
def ingest_conversation(self, session_id, messages):
|
||||
@@ -82,7 +86,8 @@ def test_is_available_false_when_import_missing(monkeypatch):
|
||||
def test_load_and_save_config_round_trip(tmp_path):
|
||||
_save_supermemory_config({"container_tag": "demo-tag", "auto_capture": False}, str(tmp_path))
|
||||
cfg = _load_supermemory_config(str(tmp_path))
|
||||
assert cfg["container_tag"] == "demo_tag"
|
||||
# container_tag is kept raw — sanitization happens in initialize() after template resolution
|
||||
assert cfg["container_tag"] == "demo-tag"
|
||||
assert cfg["auto_capture"] is False
|
||||
assert cfg["auto_recall"] is True
|
||||
|
||||
@@ -176,7 +181,8 @@ def test_shutdown_joins_and_clears_threads(provider, monkeypatch):
|
||||
started = threading.Event()
|
||||
release = threading.Event()
|
||||
|
||||
def slow_add_memory(content, metadata=None, *, entity_context=""):
|
||||
def slow_add_memory(content, metadata=None, *, entity_context="",
|
||||
container_tag=None, custom_id=None):
|
||||
started.set()
|
||||
release.wait(timeout=1)
|
||||
provider._client.add_calls.append({
|
||||
@@ -255,3 +261,151 @@ def test_handle_tool_call_returns_error_when_unconfigured(monkeypatch):
|
||||
p = SupermemoryMemoryProvider()
|
||||
result = json.loads(p.handle_tool_call("supermemory_search", {"query": "x"}))
|
||||
assert "error" in result
|
||||
|
||||
|
||||
# -- Identity template tests --------------------------------------------------
|
||||
|
||||
|
||||
def test_identity_template_resolved_in_container_tag(monkeypatch, tmp_path):
|
||||
"""container_tag with {identity} resolves to profile-scoped tag."""
|
||||
monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
|
||||
monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
|
||||
_save_supermemory_config({"container_tag": "hermes-{identity}"}, str(tmp_path))
|
||||
p = SupermemoryMemoryProvider()
|
||||
p.initialize("s1", hermes_home=str(tmp_path), platform="cli", agent_identity="coder")
|
||||
assert p._container_tag == "hermes_coder"
|
||||
|
||||
|
||||
def test_identity_template_default_profile(monkeypatch, tmp_path):
|
||||
"""Without agent_identity kwarg, {identity} resolves to 'default'."""
|
||||
monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
|
||||
monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
|
||||
_save_supermemory_config({"container_tag": "hermes-{identity}"}, str(tmp_path))
|
||||
p = SupermemoryMemoryProvider()
|
||||
p.initialize("s1", hermes_home=str(tmp_path), platform="cli")
|
||||
assert p._container_tag == "hermes_default"
|
||||
|
||||
|
||||
def test_container_tag_env_var_override(monkeypatch, tmp_path):
|
||||
"""SUPERMEMORY_CONTAINER_TAG env var overrides config."""
|
||||
monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
|
||||
monkeypatch.setenv("SUPERMEMORY_CONTAINER_TAG", "env-override")
|
||||
monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
|
||||
p = SupermemoryMemoryProvider()
|
||||
p.initialize("s1", hermes_home=str(tmp_path), platform="cli")
|
||||
assert p._container_tag == "env_override"
|
||||
|
||||
|
||||
# -- Search mode tests --------------------------------------------------------
|
||||
|
||||
|
||||
def test_search_mode_config_passed_to_client(monkeypatch, tmp_path):
|
||||
"""search_mode from config is passed to _SupermemoryClient."""
|
||||
monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
|
||||
monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
|
||||
_save_supermemory_config({"search_mode": "memories"}, str(tmp_path))
|
||||
p = SupermemoryMemoryProvider()
|
||||
p.initialize("s1", hermes_home=str(tmp_path), platform="cli")
|
||||
assert p._search_mode == "memories"
|
||||
assert p._client.search_mode == "memories"
|
||||
|
||||
|
||||
def test_invalid_search_mode_falls_back_to_default(monkeypatch, tmp_path):
|
||||
"""Invalid search_mode falls back to 'hybrid'."""
|
||||
monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
|
||||
monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
|
||||
_save_supermemory_config({"search_mode": "invalid_mode"}, str(tmp_path))
|
||||
p = SupermemoryMemoryProvider()
|
||||
p.initialize("s1", hermes_home=str(tmp_path), platform="cli")
|
||||
assert p._search_mode == "hybrid"
|
||||
|
||||
|
||||
# -- Multi-container tests ----------------------------------------------------
|
||||
|
||||
|
||||
def test_multi_container_disabled_by_default(provider):
|
||||
"""Multi-container is off by default; schemas have no container_tag param."""
|
||||
assert provider._enable_custom_containers is False
|
||||
schemas = provider.get_tool_schemas()
|
||||
for s in schemas:
|
||||
assert "container_tag" not in s["parameters"]["properties"]
|
||||
|
||||
|
||||
def test_multi_container_enabled_adds_schema_param(monkeypatch, tmp_path):
|
||||
"""When enabled, tool schemas include container_tag parameter."""
|
||||
monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
|
||||
monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
|
||||
_save_supermemory_config({
|
||||
"enable_custom_container_tags": True,
|
||||
"custom_containers": ["project-alpha", "shared"],
|
||||
}, str(tmp_path))
|
||||
p = SupermemoryMemoryProvider()
|
||||
p.initialize("s1", hermes_home=str(tmp_path), platform="cli")
|
||||
assert p._enable_custom_containers is True
|
||||
assert p._allowed_containers == ["hermes", "project_alpha", "shared"]
|
||||
schemas = p.get_tool_schemas()
|
||||
for s in schemas:
|
||||
assert "container_tag" in s["parameters"]["properties"]
|
||||
|
||||
|
||||
def test_multi_container_tool_store_with_custom_tag(monkeypatch, tmp_path):
|
||||
"""supermemory_store uses the resolved container_tag when multi-container is enabled."""
|
||||
monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
|
||||
monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
|
||||
_save_supermemory_config({
|
||||
"enable_custom_container_tags": True,
|
||||
"custom_containers": ["project-alpha"],
|
||||
}, str(tmp_path))
|
||||
p = SupermemoryMemoryProvider()
|
||||
p.initialize("s1", hermes_home=str(tmp_path), platform="cli")
|
||||
result = json.loads(p.handle_tool_call("supermemory_store", {
|
||||
"content": "test memory",
|
||||
"container_tag": "project-alpha",
|
||||
}))
|
||||
assert result["saved"] is True
|
||||
assert result["container_tag"] == "project_alpha"
|
||||
assert p._client.add_calls[-1]["container_tag"] == "project_alpha"
|
||||
|
||||
|
||||
def test_multi_container_rejects_unlisted_tag(monkeypatch, tmp_path):
|
||||
"""Tool calls with a non-whitelisted container_tag return an error."""
|
||||
monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
|
||||
monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
|
||||
_save_supermemory_config({
|
||||
"enable_custom_container_tags": True,
|
||||
"custom_containers": ["allowed-tag"],
|
||||
}, str(tmp_path))
|
||||
p = SupermemoryMemoryProvider()
|
||||
p.initialize("s1", hermes_home=str(tmp_path), platform="cli")
|
||||
result = json.loads(p.handle_tool_call("supermemory_store", {
|
||||
"content": "test",
|
||||
"container_tag": "forbidden-tag",
|
||||
}))
|
||||
assert "error" in result
|
||||
assert "not allowed" in result["error"]
|
||||
|
||||
|
||||
def test_multi_container_system_prompt_includes_instructions(monkeypatch, tmp_path):
|
||||
"""system_prompt_block includes container list and instructions when multi-container is enabled."""
|
||||
monkeypatch.setenv("SUPERMEMORY_API_KEY", "test-key")
|
||||
monkeypatch.setattr("plugins.memory.supermemory._SupermemoryClient", FakeClient)
|
||||
_save_supermemory_config({
|
||||
"enable_custom_container_tags": True,
|
||||
"custom_containers": ["docs"],
|
||||
"custom_container_instructions": "Use docs for documentation context.",
|
||||
}, str(tmp_path))
|
||||
p = SupermemoryMemoryProvider()
|
||||
p.initialize("s1", hermes_home=str(tmp_path), platform="cli")
|
||||
block = p.system_prompt_block()
|
||||
assert "Multi-container mode enabled" in block
|
||||
assert "docs" in block
|
||||
assert "Use docs for documentation context." in block
|
||||
|
||||
|
||||
def test_get_config_schema_minimal():
|
||||
"""get_config_schema only returns the API key field."""
|
||||
p = SupermemoryMemoryProvider()
|
||||
schema = p.get_config_schema()
|
||||
assert len(schema) == 1
|
||||
assert schema[0]["key"] == "api_key"
|
||||
assert schema[0]["secret"] is True
|
||||
|
||||
@@ -16,7 +16,7 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
# Ensure repo root is importable
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
||||
|
||||
try:
|
||||
from environments.agent_loop import (
|
||||
+1
-1
@@ -31,7 +31,7 @@ import pytest
|
||||
# pytestmark removed — tests skip gracefully via OPENROUTER_API_KEY check on line 59
|
||||
|
||||
# Ensure repo root is importable
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
@@ -30,7 +30,7 @@ import pytest
|
||||
import requests
|
||||
|
||||
# Ensure repo root is importable
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user