Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b0b9ef0c86 |
@@ -16,8 +16,13 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: test (${{ matrix.group }}/4)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
group: [1, 2, 3, 4]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
@@ -37,10 +42,11 @@ jobs:
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
|
||||
- name: Run tests
|
||||
- name: Run tests (shard ${{ matrix.group }}/4)
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python -m pytest tests/ -q --ignore=tests/integration --ignore=tests/e2e --tb=short -n auto
|
||||
python -m pytest tests/ -q --ignore=tests/integration --ignore=tests/e2e --tb=short \
|
||||
--splits 4 --group ${{ matrix.group }}
|
||||
env:
|
||||
# Ensure tests don't accidentally call real APIs
|
||||
OPENROUTER_API_KEY: ""
|
||||
|
||||
@@ -458,45 +458,13 @@ def profile_env(tmp_path, monkeypatch):
|
||||
|
||||
## Testing
|
||||
|
||||
**ALWAYS use `scripts/run_tests.sh`** — do not call `pytest` directly. The script enforces
|
||||
hermetic environment parity with CI (unset credential vars, TZ=UTC, LANG=C.UTF-8,
|
||||
4 xdist workers matching GHA ubuntu-latest). Direct `pytest` on a 16+ core
|
||||
developer machine with API keys set diverges from CI in ways that have caused
|
||||
multiple "works locally, fails in CI" incidents (and the reverse).
|
||||
|
||||
```bash
|
||||
scripts/run_tests.sh # full suite, CI-parity
|
||||
scripts/run_tests.sh tests/gateway/ # one directory
|
||||
scripts/run_tests.sh tests/agent/test_foo.py::test_x # one test
|
||||
scripts/run_tests.sh -v --tb=long # pass-through pytest flags
|
||||
```
|
||||
|
||||
### Why the wrapper (and why the old "just call pytest" doesn't work)
|
||||
|
||||
Five real sources of local-vs-CI drift the script closes:
|
||||
|
||||
| | Without wrapper | With wrapper |
|
||||
|---|---|---|
|
||||
| Provider API keys | Whatever is in your env (auto-detects pool) | All `*_API_KEY`/`*_TOKEN`/etc. unset |
|
||||
| HOME / `~/.hermes/` | Your real config+auth.json | Temp dir per test |
|
||||
| Timezone | Local TZ (PDT etc.) | UTC |
|
||||
| Locale | Whatever is set | C.UTF-8 |
|
||||
| xdist workers | `-n auto` = all cores (20+ on a workstation) | `-n 4` matching CI |
|
||||
|
||||
`tests/conftest.py` also enforces points 1-4 as an autouse fixture so ANY pytest
|
||||
invocation (including IDE integrations) gets hermetic behavior — but the wrapper
|
||||
is belt-and-suspenders.
|
||||
|
||||
### Running without the wrapper (only if you must)
|
||||
|
||||
If you can't use the wrapper (e.g. on Windows or inside an IDE that shells
|
||||
pytest directly), at minimum activate the venv and pass `-n 4`:
|
||||
|
||||
```bash
|
||||
source venv/bin/activate
|
||||
python -m pytest tests/ -q -n 4
|
||||
python -m pytest tests/ -q # Full suite (~3000 tests, ~3 min)
|
||||
python -m pytest tests/test_model_tools.py -q # Toolset resolution
|
||||
python -m pytest tests/test_cli_init.py -q # CLI config loading
|
||||
python -m pytest tests/gateway/ -q # Gateway tests
|
||||
python -m pytest tests/tools/ -q # Tool-level tests
|
||||
```
|
||||
|
||||
Worker count above 4 will surface test-ordering flakes that CI never sees.
|
||||
|
||||
Always run the full suite before pushing changes.
|
||||
|
||||
@@ -1208,19 +1208,6 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
||||
logger.debug("Qwen OAuth token seed failed: %s", exc)
|
||||
|
||||
elif provider == "openai-codex":
|
||||
# Respect user suppression — `hermes auth remove openai-codex` marks
|
||||
# the device_code source as suppressed so it won't be re-seeded from
|
||||
# either the Hermes auth store or ~/.codex/auth.json. Without this
|
||||
# gate the removal is instantly undone on the next load_pool() call.
|
||||
codex_suppressed = False
|
||||
try:
|
||||
from hermes_cli.auth import is_source_suppressed
|
||||
codex_suppressed = is_source_suppressed(provider, "device_code")
|
||||
except ImportError:
|
||||
pass
|
||||
if codex_suppressed:
|
||||
return changed, active_sources
|
||||
|
||||
state = _load_provider_state(auth_store, "openai-codex")
|
||||
tokens = state.get("tokens") if isinstance(state, dict) else None
|
||||
# Fallback: import from Codex CLI (~/.codex/auth.json) if Hermes auth
|
||||
|
||||
@@ -4514,34 +4514,6 @@ class HermesCLI:
|
||||
self._restore_modal_input_snapshot()
|
||||
self._invalidate(min_interval=0.0)
|
||||
|
||||
@staticmethod
|
||||
def _compute_model_picker_viewport(
|
||||
selected: int,
|
||||
scroll_offset: int,
|
||||
n: int,
|
||||
term_rows: int,
|
||||
reserved_below: int = 6,
|
||||
panel_chrome: int = 6,
|
||||
min_visible: int = 3,
|
||||
) -> tuple[int, int]:
|
||||
"""Resolve (scroll_offset, visible) for the /model picker viewport.
|
||||
|
||||
``reserved_below`` matches the approval / clarify panels — input area,
|
||||
status bar, and separators below the panel. ``panel_chrome`` covers
|
||||
this panel's own borders + blanks + hint row. The remaining rows hold
|
||||
the scrollable list, with the offset slid to keep ``selected`` on screen.
|
||||
"""
|
||||
max_visible = max(min_visible, term_rows - reserved_below - panel_chrome)
|
||||
if n <= max_visible:
|
||||
return 0, n
|
||||
visible = max_visible
|
||||
if selected < scroll_offset:
|
||||
scroll_offset = selected
|
||||
elif selected >= scroll_offset + visible:
|
||||
scroll_offset = selected - visible + 1
|
||||
scroll_offset = max(0, min(scroll_offset, n - visible))
|
||||
return scroll_offset, visible
|
||||
|
||||
def _apply_model_switch_result(self, result, persist_global: bool) -> None:
|
||||
if not result.success:
|
||||
_cprint(f" ✗ {result.error_message}")
|
||||
@@ -8556,7 +8528,6 @@ class HermesCLI:
|
||||
# --- /model picker modal ---
|
||||
if self._model_picker_state:
|
||||
self._handle_model_picker_selection()
|
||||
event.app.current_buffer.reset()
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
@@ -8722,13 +8693,6 @@ class HermesCLI:
|
||||
state["selected"] = min(max_idx, state.get("selected", 0) + 1)
|
||||
event.app.invalidate()
|
||||
|
||||
@kb.add('escape', filter=Condition(lambda: bool(self._model_picker_state)), eager=True)
|
||||
def model_picker_escape(event):
|
||||
"""ESC closes the /model picker."""
|
||||
self._close_model_picker()
|
||||
event.app.current_buffer.reset()
|
||||
event.app.invalidate()
|
||||
|
||||
# --- History navigation: up/down browse history in normal input mode ---
|
||||
# The TextArea is multiline, so by default up/down only move the cursor.
|
||||
# Buffer.auto_up/auto_down handle both: cursor movement when multi-line,
|
||||
@@ -9530,22 +9494,6 @@ class HermesCLI:
|
||||
|
||||
box_width = _panel_box_width(title, [hint] + choices, min_width=46, max_width=84)
|
||||
inner_text_width = max(8, box_width - 6)
|
||||
selected = state.get("selected", 0)
|
||||
|
||||
# Scrolling viewport: the panel renders into a Window with no max
|
||||
# height, so without limiting visible items the bottom border and
|
||||
# any items past the available terminal rows get clipped on long
|
||||
# provider catalogs (e.g. Ollama Cloud's 36+ models).
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
term_rows = get_app().output.get_size().rows
|
||||
except Exception:
|
||||
term_rows = shutil.get_terminal_size((100, 24)).lines
|
||||
scroll_offset, visible = HermesCLI._compute_model_picker_viewport(
|
||||
selected, state.get("_scroll_offset", 0), len(choices), term_rows,
|
||||
)
|
||||
state["_scroll_offset"] = scroll_offset
|
||||
|
||||
lines = []
|
||||
lines.append(('class:clarify-border', '╭─ '))
|
||||
lines.append(('class:clarify-title', title))
|
||||
@@ -9553,8 +9501,8 @@ class HermesCLI:
|
||||
_append_blank_panel_line(lines, 'class:clarify-border', box_width)
|
||||
_append_panel_line(lines, 'class:clarify-border', 'class:clarify-hint', hint, box_width)
|
||||
_append_blank_panel_line(lines, 'class:clarify-border', box_width)
|
||||
for idx in range(scroll_offset, scroll_offset + visible):
|
||||
choice = choices[idx]
|
||||
selected = state.get("selected", 0)
|
||||
for idx, choice in enumerate(choices):
|
||||
style = 'class:clarify-selected' if idx == selected else 'class:clarify-choice'
|
||||
prefix = '❯ ' if idx == selected else ' '
|
||||
for wrapped in _wrap_panel_text(prefix + choice, inner_text_width, subsequent_indent=' '):
|
||||
|
||||
+107
-171
@@ -27,7 +27,7 @@ except ImportError:
|
||||
except ImportError:
|
||||
msvcrt = None
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
# Add parent directory to path for imports BEFORE repo-level imports.
|
||||
# Without this, standalone invocations (e.g. after `hermes update` reloads
|
||||
@@ -49,25 +49,6 @@ _KNOWN_DELIVERY_PLATFORMS = frozenset({
|
||||
"qqbot",
|
||||
})
|
||||
|
||||
# Platforms that support a configured cron/notification home target, mapped to
|
||||
# the environment variable used by gateway setup/runtime config.
|
||||
_HOME_TARGET_ENV_VARS = {
|
||||
"matrix": "MATRIX_HOME_ROOM",
|
||||
"telegram": "TELEGRAM_HOME_CHANNEL",
|
||||
"discord": "DISCORD_HOME_CHANNEL",
|
||||
"slack": "SLACK_HOME_CHANNEL",
|
||||
"signal": "SIGNAL_HOME_CHANNEL",
|
||||
"mattermost": "MATTERMOST_HOME_CHANNEL",
|
||||
"sms": "SMS_HOME_CHANNEL",
|
||||
"email": "EMAIL_HOME_ADDRESS",
|
||||
"dingtalk": "DINGTALK_HOME_CHANNEL",
|
||||
"feishu": "FEISHU_HOME_CHANNEL",
|
||||
"wecom": "WECOM_HOME_CHANNEL",
|
||||
"weixin": "WEIXIN_HOME_CHANNEL",
|
||||
"bluebubbles": "BLUEBUBBLES_HOME_CHANNEL",
|
||||
"qqbot": "QQ_HOME_CHANNEL",
|
||||
}
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
|
||||
# Sentinel: when a cron agent has nothing new to report, it can start its
|
||||
@@ -95,23 +76,15 @@ def _resolve_origin(job: dict) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
|
||||
def _get_home_target_chat_id(platform_name: str) -> str:
|
||||
"""Return the configured home target chat/room ID for a delivery platform."""
|
||||
env_var = _HOME_TARGET_ENV_VARS.get(platform_name.lower())
|
||||
if not env_var:
|
||||
return ""
|
||||
return os.getenv(env_var, "")
|
||||
|
||||
|
||||
def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[dict]:
|
||||
"""Resolve one concrete auto-delivery target for a cron job."""
|
||||
|
||||
def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
"""Resolve the concrete auto-delivery target for a cron job, if any."""
|
||||
deliver = job.get("deliver", "local")
|
||||
origin = _resolve_origin(job)
|
||||
|
||||
if deliver_value == "local":
|
||||
if deliver == "local":
|
||||
return None
|
||||
|
||||
if deliver_value == "origin":
|
||||
if deliver == "origin":
|
||||
if origin:
|
||||
return {
|
||||
"platform": origin["platform"],
|
||||
@@ -120,8 +93,8 @@ def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[d
|
||||
}
|
||||
# Origin missing (e.g. job created via API/script) — try each
|
||||
# platform's home channel as a fallback instead of silently dropping.
|
||||
for platform_name in _HOME_TARGET_ENV_VARS:
|
||||
chat_id = _get_home_target_chat_id(platform_name)
|
||||
for platform_name in ("matrix", "telegram", "discord", "slack", "bluebubbles"):
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
if chat_id:
|
||||
logger.info(
|
||||
"Job '%s' has deliver=origin but no origin; falling back to %s home channel",
|
||||
@@ -135,8 +108,8 @@ def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[d
|
||||
}
|
||||
return None
|
||||
|
||||
if ":" in deliver_value:
|
||||
platform_name, rest = deliver_value.split(":", 1)
|
||||
if ":" in deliver:
|
||||
platform_name, rest = deliver.split(":", 1)
|
||||
platform_key = platform_name.lower()
|
||||
|
||||
from tools.send_message_tool import _parse_target_ref
|
||||
@@ -166,7 +139,7 @@ def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[d
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
|
||||
platform_name = deliver_value
|
||||
platform_name = deliver
|
||||
if origin and origin.get("platform") == platform_name:
|
||||
return {
|
||||
"platform": platform_name,
|
||||
@@ -176,7 +149,7 @@ def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[d
|
||||
|
||||
if platform_name.lower() not in _KNOWN_DELIVERY_PLATFORMS:
|
||||
return None
|
||||
chat_id = _get_home_target_chat_id(platform_name)
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
if not chat_id:
|
||||
return None
|
||||
|
||||
@@ -187,30 +160,6 @@ def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[d
|
||||
}
|
||||
|
||||
|
||||
def _resolve_delivery_targets(job: dict) -> List[dict]:
|
||||
"""Resolve all concrete auto-delivery targets for a cron job (supports comma-separated deliver)."""
|
||||
deliver = job.get("deliver", "local")
|
||||
if deliver == "local":
|
||||
return []
|
||||
parts = [p.strip() for p in str(deliver).split(",") if p.strip()]
|
||||
seen = set()
|
||||
targets = []
|
||||
for part in parts:
|
||||
target = _resolve_single_delivery_target(job, part)
|
||||
if target:
|
||||
key = (target["platform"].lower(), str(target["chat_id"]), target.get("thread_id"))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
targets.append(target)
|
||||
return targets
|
||||
|
||||
|
||||
def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
"""Resolve the concrete auto-delivery target for a cron job, if any."""
|
||||
targets = _resolve_delivery_targets(job)
|
||||
return targets[0] if targets else None
|
||||
|
||||
|
||||
# Media extension sets — keep in sync with gateway/platforms/base.py:_process_message_background
|
||||
_AUDIO_EXTS = frozenset({'.ogg', '.opus', '.mp3', '.wav', '.m4a'})
|
||||
_VIDEO_EXTS = frozenset({'.mp4', '.mov', '.avi', '.mkv', '.webm', '.3gp'})
|
||||
@@ -251,7 +200,7 @@ def _send_media_via_adapter(adapter, chat_id: str, media_files: list, metadata:
|
||||
|
||||
def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Optional[str]:
|
||||
"""
|
||||
Deliver job output to the configured target(s) (origin chat, specific platform, etc.).
|
||||
Deliver job output to the configured target (origin chat, specific platform, etc.).
|
||||
|
||||
When ``adapters`` and ``loop`` are provided (gateway is running), tries to
|
||||
use the live adapter first — this supports E2EE rooms (e.g. Matrix) where
|
||||
@@ -260,14 +209,33 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
|
||||
Returns None on success, or an error string on failure.
|
||||
"""
|
||||
targets = _resolve_delivery_targets(job)
|
||||
if not targets:
|
||||
target = _resolve_delivery_target(job)
|
||||
if not target:
|
||||
if job.get("deliver", "local") != "local":
|
||||
msg = f"no delivery target resolved for deliver={job.get('deliver', 'local')}"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
return None # local-only jobs don't deliver — not a failure
|
||||
|
||||
platform_name = target["platform"]
|
||||
chat_id = target["chat_id"]
|
||||
thread_id = target.get("thread_id")
|
||||
|
||||
# Diagnostic: log thread_id for topic-aware delivery debugging
|
||||
origin = job.get("origin") or {}
|
||||
origin_thread = origin.get("thread_id")
|
||||
if origin_thread and not thread_id:
|
||||
logger.warning(
|
||||
"Job '%s': origin has thread_id=%s but delivery target lost it "
|
||||
"(deliver=%s, target=%s)",
|
||||
job["id"], origin_thread, job.get("deliver", "local"), target,
|
||||
)
|
||||
elif thread_id:
|
||||
logger.debug(
|
||||
"Job '%s': delivering to %s:%s thread_id=%s",
|
||||
job["id"], platform_name, chat_id, thread_id,
|
||||
)
|
||||
|
||||
from tools.send_message_tool import _send_to_platform
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
|
||||
@@ -290,6 +258,24 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
"bluebubbles": Platform.BLUEBUBBLES,
|
||||
"qqbot": Platform.QQBOT,
|
||||
}
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
msg = f"unknown platform '{platform_name}'"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
try:
|
||||
config = load_gateway_config()
|
||||
except Exception as e:
|
||||
msg = f"failed to load gateway config: {e}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
msg = f"platform '{platform_name}' not configured/enabled"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
# Optionally wrap the content with a header/footer so the user knows this
|
||||
# is a cron delivery. Wrapping is on by default; set cron.wrap_response: false
|
||||
@@ -318,117 +304,67 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
media_files, cleaned_delivery_content = BasePlatformAdapter.extract_media(delivery_content)
|
||||
|
||||
# Prefer the live adapter when the gateway is running — this supports E2EE
|
||||
# rooms (e.g. Matrix) where the standalone HTTP path cannot encrypt.
|
||||
runtime_adapter = (adapters or {}).get(platform)
|
||||
if runtime_adapter is not None and loop is not None and getattr(loop, "is_running", lambda: False)():
|
||||
send_metadata = {"thread_id": thread_id} if thread_id else None
|
||||
try:
|
||||
# Send cleaned text (MEDIA tags stripped) — not the raw content
|
||||
text_to_send = cleaned_delivery_content.strip()
|
||||
adapter_ok = True
|
||||
if text_to_send:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
runtime_adapter.send(chat_id, text_to_send, metadata=send_metadata),
|
||||
loop,
|
||||
)
|
||||
send_result = future.result(timeout=60)
|
||||
if send_result and not getattr(send_result, "success", True):
|
||||
err = getattr(send_result, "error", "unknown")
|
||||
logger.warning(
|
||||
"Job '%s': live adapter send to %s:%s failed (%s), falling back to standalone",
|
||||
job["id"], platform_name, chat_id, err,
|
||||
)
|
||||
adapter_ok = False # fall through to standalone path
|
||||
|
||||
# Send extracted media files as native attachments via the live adapter
|
||||
if adapter_ok and media_files:
|
||||
_send_media_via_adapter(runtime_adapter, chat_id, media_files, send_metadata, loop, job)
|
||||
|
||||
if adapter_ok:
|
||||
logger.info("Job '%s': delivered to %s:%s via live adapter", job["id"], platform_name, chat_id)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Job '%s': live adapter delivery to %s:%s failed (%s), falling back to standalone",
|
||||
job["id"], platform_name, chat_id, e,
|
||||
)
|
||||
|
||||
# Standalone path: run the async send in a fresh event loop (safe from any thread)
|
||||
coro = _send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files)
|
||||
try:
|
||||
config = load_gateway_config()
|
||||
result = asyncio.run(coro)
|
||||
except RuntimeError:
|
||||
# asyncio.run() checks for a running loop before awaiting the coroutine;
|
||||
# when it raises, the original coro was never started — close it to
|
||||
# prevent "coroutine was never awaited" RuntimeWarning, then retry in a
|
||||
# fresh thread that has no running loop.
|
||||
coro.close()
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files))
|
||||
result = future.result(timeout=30)
|
||||
except Exception as e:
|
||||
msg = f"failed to load gateway config: {e}"
|
||||
msg = f"delivery to {platform_name}:{chat_id} failed: {e}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
delivery_errors = []
|
||||
if result and result.get("error"):
|
||||
msg = f"delivery error: {result['error']}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
for target in targets:
|
||||
platform_name = target["platform"]
|
||||
chat_id = target["chat_id"]
|
||||
thread_id = target.get("thread_id")
|
||||
|
||||
# Diagnostic: log thread_id for topic-aware delivery debugging
|
||||
origin = job.get("origin") or {}
|
||||
origin_thread = origin.get("thread_id")
|
||||
if origin_thread and not thread_id:
|
||||
logger.warning(
|
||||
"Job '%s': origin has thread_id=%s but delivery target lost it "
|
||||
"(deliver=%s, target=%s)",
|
||||
job["id"], origin_thread, job.get("deliver", "local"), target,
|
||||
)
|
||||
elif thread_id:
|
||||
logger.debug(
|
||||
"Job '%s': delivering to %s:%s thread_id=%s",
|
||||
job["id"], platform_name, chat_id, thread_id,
|
||||
)
|
||||
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
msg = f"unknown platform '{platform_name}'"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
# Prefer the live adapter when the gateway is running — this supports E2EE
|
||||
# rooms (e.g. Matrix) where the standalone HTTP path cannot encrypt.
|
||||
runtime_adapter = (adapters or {}).get(platform)
|
||||
delivered = False
|
||||
if runtime_adapter is not None and loop is not None and getattr(loop, "is_running", lambda: False)():
|
||||
send_metadata = {"thread_id": thread_id} if thread_id else None
|
||||
try:
|
||||
# Send cleaned text (MEDIA tags stripped) — not the raw content
|
||||
text_to_send = cleaned_delivery_content.strip()
|
||||
adapter_ok = True
|
||||
if text_to_send:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
runtime_adapter.send(chat_id, text_to_send, metadata=send_metadata),
|
||||
loop,
|
||||
)
|
||||
send_result = future.result(timeout=60)
|
||||
if send_result and not getattr(send_result, "success", True):
|
||||
err = getattr(send_result, "error", "unknown")
|
||||
logger.warning(
|
||||
"Job '%s': live adapter send to %s:%s failed (%s), falling back to standalone",
|
||||
job["id"], platform_name, chat_id, err,
|
||||
)
|
||||
adapter_ok = False # fall through to standalone path
|
||||
|
||||
# Send extracted media files as native attachments via the live adapter
|
||||
if adapter_ok and media_files:
|
||||
_send_media_via_adapter(runtime_adapter, chat_id, media_files, send_metadata, loop, job)
|
||||
|
||||
if adapter_ok:
|
||||
logger.info("Job '%s': delivered to %s:%s via live adapter", job["id"], platform_name, chat_id)
|
||||
delivered = True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Job '%s': live adapter delivery to %s:%s failed (%s), falling back to standalone",
|
||||
job["id"], platform_name, chat_id, e,
|
||||
)
|
||||
|
||||
if not delivered:
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
msg = f"platform '{platform_name}' not configured/enabled"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
# Standalone path: run the async send in a fresh event loop (safe from any thread)
|
||||
coro = _send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files)
|
||||
try:
|
||||
result = asyncio.run(coro)
|
||||
except RuntimeError:
|
||||
# asyncio.run() checks for a running loop before awaiting the coroutine;
|
||||
# when it raises, the original coro was never started — close it to
|
||||
# prevent "coroutine was never awaited" RuntimeWarning, then retry in a
|
||||
# fresh thread that has no running loop.
|
||||
coro.close()
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files))
|
||||
result = future.result(timeout=30)
|
||||
except Exception as e:
|
||||
msg = f"delivery to {platform_name}:{chat_id} failed: {e}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
if result and result.get("error"):
|
||||
msg = f"delivery error: {result['error']}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
logger.info("Job '%s': delivered to %s:%s", job["id"], platform_name, chat_id)
|
||||
|
||||
if delivery_errors:
|
||||
return "; ".join(delivery_errors)
|
||||
logger.info("Job '%s': delivered to %s:%s", job["id"], platform_name, chat_id)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -307,14 +307,6 @@ class GatewayConfig:
|
||||
# QQBot uses extra dict for app credentials
|
||||
elif platform == Platform.QQBOT and config.extra.get("app_id") and config.extra.get("client_secret"):
|
||||
connected.append(platform)
|
||||
# DingTalk uses client_id/client_secret from config.extra or env vars
|
||||
elif platform == Platform.DINGTALK and (
|
||||
config.extra.get("client_id") or os.getenv("DINGTALK_CLIENT_ID")
|
||||
) and (
|
||||
config.extra.get("client_secret") or os.getenv("DINGTALK_CLIENT_SECRET")
|
||||
):
|
||||
connected.append(platform)
|
||||
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
|
||||
@@ -625,20 +617,6 @@ def load_gateway_config() -> GatewayConfig:
|
||||
if isinstance(ntc, list):
|
||||
ntc = ",".join(str(v) for v in ntc)
|
||||
os.environ["DISCORD_NO_THREAD_CHANNELS"] = str(ntc)
|
||||
# allow_mentions: granular control over what the bot can ping.
|
||||
# Safe defaults (no @everyone/roles) are applied in the adapter;
|
||||
# these YAML keys only override when set and let users opt back
|
||||
# into unsafe modes (e.g. roles=true) if they actually want it.
|
||||
allow_mentions_cfg = discord_cfg.get("allow_mentions")
|
||||
if isinstance(allow_mentions_cfg, dict):
|
||||
for yaml_key, env_key in (
|
||||
("everyone", "DISCORD_ALLOW_MENTION_EVERYONE"),
|
||||
("roles", "DISCORD_ALLOW_MENTION_ROLES"),
|
||||
("users", "DISCORD_ALLOW_MENTION_USERS"),
|
||||
("replied_user", "DISCORD_ALLOW_MENTION_REPLIED_USER"),
|
||||
):
|
||||
if yaml_key in allow_mentions_cfg and not os.getenv(env_key):
|
||||
os.environ[env_key] = str(allow_mentions_cfg[yaml_key]).lower()
|
||||
|
||||
# Telegram settings → env vars (env vars take precedence)
|
||||
telegram_cfg = yaml_cfg.get("telegram", {})
|
||||
@@ -685,24 +663,6 @@ def load_gateway_config() -> GatewayConfig:
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["WHATSAPP_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
|
||||
# DingTalk settings → env vars (env vars take precedence)
|
||||
dingtalk_cfg = yaml_cfg.get("dingtalk", {})
|
||||
if isinstance(dingtalk_cfg, dict):
|
||||
if "require_mention" in dingtalk_cfg and not os.getenv("DINGTALK_REQUIRE_MENTION"):
|
||||
os.environ["DINGTALK_REQUIRE_MENTION"] = str(dingtalk_cfg["require_mention"]).lower()
|
||||
if "mention_patterns" in dingtalk_cfg and not os.getenv("DINGTALK_MENTION_PATTERNS"):
|
||||
os.environ["DINGTALK_MENTION_PATTERNS"] = json.dumps(dingtalk_cfg["mention_patterns"])
|
||||
frc = dingtalk_cfg.get("free_response_chats")
|
||||
if frc is not None and not os.getenv("DINGTALK_FREE_RESPONSE_CHATS"):
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["DINGTALK_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
allowed = dingtalk_cfg.get("allowed_users")
|
||||
if allowed is not None and not os.getenv("DINGTALK_ALLOWED_USERS"):
|
||||
if isinstance(allowed, list):
|
||||
allowed = ",".join(str(v) for v in allowed)
|
||||
os.environ["DINGTALK_ALLOWED_USERS"] = str(allowed)
|
||||
|
||||
# Matrix settings → env vars (env vars take precedence)
|
||||
matrix_cfg = yaml_cfg.get("matrix", {})
|
||||
if isinstance(matrix_cfg, dict):
|
||||
@@ -1046,25 +1006,6 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
if webhook_secret:
|
||||
config.platforms[Platform.WEBHOOK].extra["secret"] = webhook_secret
|
||||
|
||||
# DingTalk
|
||||
dingtalk_client_id = os.getenv("DINGTALK_CLIENT_ID")
|
||||
dingtalk_client_secret = os.getenv("DINGTALK_CLIENT_SECRET")
|
||||
if dingtalk_client_id and dingtalk_client_secret:
|
||||
if Platform.DINGTALK not in config.platforms:
|
||||
config.platforms[Platform.DINGTALK] = PlatformConfig()
|
||||
config.platforms[Platform.DINGTALK].enabled = True
|
||||
config.platforms[Platform.DINGTALK].extra.update({
|
||||
"client_id": dingtalk_client_id,
|
||||
"client_secret": dingtalk_client_secret,
|
||||
})
|
||||
dingtalk_home = os.getenv("DINGTALK_HOME_CHANNEL")
|
||||
if dingtalk_home:
|
||||
config.platforms[Platform.DINGTALK].home_channel = HomeChannel(
|
||||
platform=Platform.DINGTALK,
|
||||
chat_id=dingtalk_home,
|
||||
name=os.getenv("DINGTALK_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Feishu / Lark
|
||||
feishu_app_id = os.getenv("FEISHU_APP_ID")
|
||||
feishu_app_secret = os.getenv("FEISHU_APP_SECRET")
|
||||
|
||||
@@ -1991,7 +1991,6 @@ class BasePlatformAdapter(ABC):
|
||||
chat_topic: Optional[str] = None,
|
||||
user_id_alt: Optional[str] = None,
|
||||
chat_id_alt: Optional[str] = None,
|
||||
is_bot: bool = False,
|
||||
) -> SessionSource:
|
||||
"""Helper to build a SessionSource for this platform."""
|
||||
# Normalize empty topic to None
|
||||
@@ -2008,7 +2007,6 @@ class BasePlatformAdapter(ABC):
|
||||
chat_topic=chat_topic.strip() if chat_topic else None,
|
||||
user_id_alt=user_id_alt,
|
||||
chat_id_alt=chat_id_alt,
|
||||
is_bot=is_bot,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -12,27 +12,18 @@ Configuration in config.yaml:
|
||||
platforms:
|
||||
dingtalk:
|
||||
enabled: true
|
||||
# Optional group-chat gating (mirrors Slack/Telegram/Discord):
|
||||
require_mention: true # or DINGTALK_REQUIRE_MENTION env var
|
||||
# free_response_chats: # conversations that skip require_mention
|
||||
# - cidABC==
|
||||
# mention_patterns: # regex wake-words (e.g. Chinese bot names)
|
||||
# - "^小马"
|
||||
# allowed_users: # staff_id or sender_id list; "*" = any
|
||||
# - "manager1234"
|
||||
extra:
|
||||
client_id: "your-app-key" # or DINGTALK_CLIENT_ID env var
|
||||
client_secret: "your-secret" # or DINGTALK_CLIENT_SECRET env var
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
try:
|
||||
import dingtalk_stream
|
||||
@@ -101,10 +92,6 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||
# Map chat_id -> session_webhook for reply routing
|
||||
self._session_webhooks: Dict[str, str] = {}
|
||||
|
||||
# Group-chat gating (mirrors Slack/Telegram/Discord/WhatsApp conventions)
|
||||
self._mention_patterns: List[re.Pattern] = self._compile_mention_patterns()
|
||||
self._allowed_users: Set[str] = self._load_allowed_users()
|
||||
|
||||
# -- Connection lifecycle -----------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
@@ -167,19 +154,12 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||
self._running = False
|
||||
self._mark_disconnected()
|
||||
|
||||
websocket = getattr(self._stream_client, "websocket", None)
|
||||
if websocket is not None:
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception as e:
|
||||
logger.debug("[%s] websocket close during disconnect failed: %s", self.name, e)
|
||||
|
||||
if self._stream_task:
|
||||
self._stream_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self._stream_task, timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
logger.debug("[%s] stream task did not exit cleanly during disconnect", self.name)
|
||||
await self._stream_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._stream_task = None
|
||||
|
||||
if self._http_client:
|
||||
@@ -191,118 +171,6 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||
self._dedup.clear()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
# -- Group gating --------------------------------------------------------
|
||||
|
||||
def _dingtalk_require_mention(self) -> bool:
|
||||
"""Return whether group chats should require an explicit bot trigger."""
|
||||
configured = self.config.extra.get("require_mention")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() in ("true", "1", "yes", "on")
|
||||
return bool(configured)
|
||||
return os.getenv("DINGTALK_REQUIRE_MENTION", "false").lower() in ("true", "1", "yes", "on")
|
||||
|
||||
def _dingtalk_free_response_chats(self) -> Set[str]:
|
||||
raw = self.config.extra.get("free_response_chats")
|
||||
if raw is None:
|
||||
raw = os.getenv("DINGTALK_FREE_RESPONSE_CHATS", "")
|
||||
if isinstance(raw, list):
|
||||
return {str(part).strip() for part in raw if str(part).strip()}
|
||||
return {part.strip() for part in str(raw).split(",") if part.strip()}
|
||||
|
||||
def _compile_mention_patterns(self) -> List[re.Pattern]:
|
||||
"""Compile optional regex wake-word patterns for group triggers."""
|
||||
patterns = self.config.extra.get("mention_patterns") if self.config.extra else None
|
||||
if patterns is None:
|
||||
raw = os.getenv("DINGTALK_MENTION_PATTERNS", "").strip()
|
||||
if raw:
|
||||
try:
|
||||
loaded = json.loads(raw)
|
||||
except Exception:
|
||||
loaded = [part.strip() for part in raw.splitlines() if part.strip()]
|
||||
if not loaded:
|
||||
loaded = [part.strip() for part in raw.split(",") if part.strip()]
|
||||
patterns = loaded
|
||||
|
||||
if patterns is None:
|
||||
return []
|
||||
if isinstance(patterns, str):
|
||||
patterns = [patterns]
|
||||
if not isinstance(patterns, list):
|
||||
logger.warning(
|
||||
"[%s] dingtalk mention_patterns must be a list or string; got %s",
|
||||
self.name,
|
||||
type(patterns).__name__,
|
||||
)
|
||||
return []
|
||||
|
||||
compiled: List[re.Pattern] = []
|
||||
for pattern in patterns:
|
||||
if not isinstance(pattern, str) or not pattern.strip():
|
||||
continue
|
||||
try:
|
||||
compiled.append(re.compile(pattern, re.IGNORECASE))
|
||||
except re.error as exc:
|
||||
logger.warning("[%s] Invalid DingTalk mention pattern %r: %s", self.name, pattern, exc)
|
||||
if compiled:
|
||||
logger.info("[%s] Loaded %d DingTalk mention pattern(s)", self.name, len(compiled))
|
||||
return compiled
|
||||
|
||||
def _load_allowed_users(self) -> Set[str]:
|
||||
"""Load allowed-users list from config.extra or env var.
|
||||
|
||||
IDs are matched case-insensitively against the sender's ``staff_id`` and
|
||||
``sender_id``. A wildcard ``*`` disables the check.
|
||||
"""
|
||||
raw = self.config.extra.get("allowed_users") if self.config.extra else None
|
||||
if raw is None:
|
||||
raw = os.getenv("DINGTALK_ALLOWED_USERS", "")
|
||||
if isinstance(raw, list):
|
||||
items = [str(part).strip() for part in raw if str(part).strip()]
|
||||
else:
|
||||
items = [part.strip() for part in str(raw).split(",") if part.strip()]
|
||||
return {item.lower() for item in items}
|
||||
|
||||
def _is_user_allowed(self, sender_id: str, sender_staff_id: str) -> bool:
|
||||
if not self._allowed_users or "*" in self._allowed_users:
|
||||
return True
|
||||
candidates = {(sender_id or "").lower(), (sender_staff_id or "").lower()}
|
||||
candidates.discard("")
|
||||
return bool(candidates & self._allowed_users)
|
||||
|
||||
def _message_mentions_bot(self, message: "ChatbotMessage") -> bool:
|
||||
"""True if the bot was @-mentioned in a group message.
|
||||
|
||||
dingtalk-stream sets ``is_in_at_list`` on the incoming ChatbotMessage
|
||||
when the bot is addressed via @-mention.
|
||||
"""
|
||||
return bool(getattr(message, "is_in_at_list", False))
|
||||
|
||||
def _message_matches_mention_patterns(self, text: str) -> bool:
|
||||
if not text or not self._mention_patterns:
|
||||
return False
|
||||
return any(pattern.search(text) for pattern in self._mention_patterns)
|
||||
|
||||
def _should_process_message(self, message: "ChatbotMessage", text: str, is_group: bool, chat_id: str) -> bool:
|
||||
"""Apply DingTalk group trigger rules.
|
||||
|
||||
DMs remain unrestricted (subject to ``allowed_users`` which is enforced
|
||||
earlier). Group messages are accepted when:
|
||||
- the chat is explicitly allowlisted in ``free_response_chats``
|
||||
- ``require_mention`` is disabled
|
||||
- the bot is @mentioned (``is_in_at_list``)
|
||||
- the text matches a configured regex wake-word pattern
|
||||
"""
|
||||
if not is_group:
|
||||
return True
|
||||
if chat_id and chat_id in self._dingtalk_free_response_chats():
|
||||
return True
|
||||
if not self._dingtalk_require_mention():
|
||||
return True
|
||||
if self._message_mentions_bot(message):
|
||||
return True
|
||||
return self._message_matches_mention_patterns(text)
|
||||
|
||||
# -- Inbound message processing -----------------------------------------
|
||||
|
||||
async def _on_message(self, message: "ChatbotMessage") -> None:
|
||||
@@ -328,22 +196,6 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||
chat_id = conversation_id or sender_id
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
# Allowed-users gate (applies to both DM and group)
|
||||
if not self._is_user_allowed(sender_id, sender_staff_id):
|
||||
logger.debug(
|
||||
"[%s] Dropping message from non-allowlisted user staff_id=%s sender_id=%s",
|
||||
self.name, sender_staff_id, sender_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Group mention/pattern gate
|
||||
if not self._should_process_message(message, text, is_group, chat_id):
|
||||
logger.debug(
|
||||
"[%s] Dropping group message that failed mention gate message_id=%s chat_id=%s",
|
||||
self.name, msg_id, chat_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Store session webhook for reply routing (validate origin to prevent SSRF)
|
||||
session_webhook = getattr(message, "session_webhook", None) or ""
|
||||
if session_webhook and chat_id and _DINGTALK_WEBHOOK_RE.match(session_webhook):
|
||||
@@ -483,39 +335,13 @@ class _IncomingHandler(ChatbotHandler if DINGTALK_STREAM_AVAILABLE else object):
|
||||
"""Called by dingtalk-stream when a message arrives.
|
||||
|
||||
dingtalk-stream >= 0.24 passes a CallbackMessage whose `.data` contains
|
||||
the chatbot payload. Convert it to ChatbotMessage via
|
||||
``ChatbotMessage.from_dict()``.
|
||||
|
||||
Message processing is dispatched as a background task so that this
|
||||
method returns the ACK immediately — blocking here would prevent the
|
||||
SDK from sending heartbeats, eventually causing a disconnect.
|
||||
the chatbot payload. Convert it to ChatbotMessage and await the adapter
|
||||
handler directly on the main event loop.
|
||||
"""
|
||||
try:
|
||||
data = callback_message.data
|
||||
chatbot_msg = ChatbotMessage.from_dict(data)
|
||||
|
||||
# Ensure session_webhook is populated even if the SDK's
|
||||
# from_dict() did not map it (field name mismatch across
|
||||
# SDK versions).
|
||||
if not getattr(chatbot_msg, "session_webhook", None):
|
||||
webhook = (
|
||||
data.get("sessionWebhook")
|
||||
or data.get("session_webhook")
|
||||
or ""
|
||||
)
|
||||
if webhook:
|
||||
chatbot_msg.session_webhook = webhook
|
||||
|
||||
# Fire-and-forget: return ACK immediately, process in background.
|
||||
asyncio.create_task(self._safe_on_message(chatbot_msg))
|
||||
except Exception:
|
||||
logger.exception("[DingTalk] Error preparing incoming message")
|
||||
|
||||
return dingtalk_stream.AckMessage.STATUS_OK, "OK"
|
||||
|
||||
async def _safe_on_message(self, chatbot_msg: "ChatbotMessage") -> None:
|
||||
"""Wrapper that catches exceptions from _on_message."""
|
||||
try:
|
||||
chatbot_msg = ChatbotMessage.from_dict(callback_message.data)
|
||||
await self._adapter._on_message(chatbot_msg)
|
||||
except Exception:
|
||||
logger.exception("[DingTalk] Error processing incoming message")
|
||||
|
||||
return dingtalk_stream.AckMessage.STATUS_OK, "OK"
|
||||
|
||||
+90
-385
@@ -51,9 +51,7 @@ from gateway.platforms.base import (
|
||||
ProcessingOutcome,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_url,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
@@ -82,41 +80,6 @@ def check_discord_requirements() -> bool:
|
||||
return DISCORD_AVAILABLE
|
||||
|
||||
|
||||
def _build_allowed_mentions():
|
||||
"""Build Discord ``AllowedMentions`` with safe defaults, overridable via env.
|
||||
|
||||
Discord bots default to parsing ``@everyone``, ``@here``, role pings, and
|
||||
user pings when ``allowed_mentions`` is unset on the client — any LLM
|
||||
output or echoed user content that contains ``@everyone`` would therefore
|
||||
ping the whole server. We explicitly deny ``@everyone`` and role pings
|
||||
by default and keep user / replied-user pings enabled so normal
|
||||
conversation still works.
|
||||
|
||||
Override via environment variables (or ``discord.allow_mentions.*`` in
|
||||
config.yaml):
|
||||
|
||||
DISCORD_ALLOW_MENTION_EVERYONE default false — @everyone + @here
|
||||
DISCORD_ALLOW_MENTION_ROLES default false — @role pings
|
||||
DISCORD_ALLOW_MENTION_USERS default true — @user pings
|
||||
DISCORD_ALLOW_MENTION_REPLIED_USER default true — reply-ping author
|
||||
"""
|
||||
if not DISCORD_AVAILABLE:
|
||||
return None
|
||||
|
||||
def _b(name: str, default: bool) -> bool:
|
||||
raw = os.getenv(name, "").strip().lower()
|
||||
if not raw:
|
||||
return default
|
||||
return raw in ("true", "1", "yes", "on")
|
||||
|
||||
return discord.AllowedMentions(
|
||||
everyone=_b("DISCORD_ALLOW_MENTION_EVERYONE", False),
|
||||
roles=_b("DISCORD_ALLOW_MENTION_ROLES", False),
|
||||
users=_b("DISCORD_ALLOW_MENTION_USERS", True),
|
||||
replied_user=_b("DISCORD_ALLOW_MENTION_REPLIED_USER", True),
|
||||
)
|
||||
|
||||
|
||||
class VoiceReceiver:
|
||||
"""Captures and decodes voice audio from a Discord voice channel.
|
||||
|
||||
@@ -495,7 +458,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._client: Optional[commands.Bot] = None
|
||||
self._ready_event = asyncio.Event()
|
||||
self._allowed_user_ids: set = set() # For button approval authorization
|
||||
self._allowed_role_ids: set = set() # For DISCORD_ALLOWED_ROLES filtering
|
||||
# Voice channel state (per-guild)
|
||||
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
|
||||
# Text batching: merge rapid successive messages (Telegram-style)
|
||||
@@ -574,15 +536,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if uid.strip()
|
||||
}
|
||||
|
||||
# Parse DISCORD_ALLOWED_ROLES — comma-separated role IDs.
|
||||
# Users with ANY of these roles can interact with the bot.
|
||||
roles_env = os.getenv("DISCORD_ALLOWED_ROLES", "")
|
||||
if roles_env:
|
||||
self._allowed_role_ids = {
|
||||
int(rid.strip()) for rid in roles_env.split(",")
|
||||
if rid.strip().isdigit()
|
||||
}
|
||||
|
||||
# Set up intents.
|
||||
# Message Content is required for normal text replies.
|
||||
# Server Members is only needed when the allowlist contains usernames
|
||||
@@ -594,10 +547,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
intents.message_content = True
|
||||
intents.dm_messages = True
|
||||
intents.guild_messages = True
|
||||
intents.members = (
|
||||
any(not entry.isdigit() for entry in self._allowed_user_ids)
|
||||
or bool(self._allowed_role_ids) # Need members intent for role lookup
|
||||
)
|
||||
intents.members = any(not entry.isdigit() for entry in self._allowed_user_ids)
|
||||
intents.voice_states = True
|
||||
|
||||
# Resolve proxy (DISCORD_PROXY > generic env vars > macOS system proxy)
|
||||
@@ -606,15 +556,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if proxy_url:
|
||||
logger.info("[%s] Using proxy for Discord: %s", self.name, proxy_url)
|
||||
|
||||
# Create bot — proxy= for HTTP, connector= for SOCKS.
|
||||
# allowed_mentions is set with safe defaults (no @everyone/roles)
|
||||
# so LLM output or echoed user content can't ping the whole
|
||||
# server; override per DISCORD_ALLOW_MENTION_* env vars or the
|
||||
# discord.allow_mentions.* block in config.yaml.
|
||||
# Create bot — proxy= for HTTP, connector= for SOCKS
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
allowed_mentions=_build_allowed_mentions(),
|
||||
**proxy_kwargs_for_bot(proxy_url),
|
||||
)
|
||||
adapter_self = self # capture for closure
|
||||
@@ -649,13 +594,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return
|
||||
|
||||
# Check if the message author is in the allowed user list
|
||||
if not self._is_allowed_user(str(message.author.id)):
|
||||
return
|
||||
|
||||
# Bot message filtering (DISCORD_ALLOW_BOTS):
|
||||
# "none" — ignore all other bots (default)
|
||||
# "mentions" — accept bot messages only when they @mention us
|
||||
# "all" — accept all bot messages
|
||||
# Must run BEFORE the user allowlist check so that bots
|
||||
# permitted by DISCORD_ALLOW_BOTS are not rejected for
|
||||
# not being in DISCORD_ALLOWED_USERS (fixes #4466).
|
||||
if getattr(message.author, "bot", False):
|
||||
allow_bots = os.getenv("DISCORD_ALLOW_BOTS", "none").lower().strip()
|
||||
if allow_bots == "none":
|
||||
@@ -663,12 +609,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
elif allow_bots == "mentions":
|
||||
if not self._client.user or self._client.user not in message.mentions:
|
||||
return
|
||||
# "all" falls through; bot is permitted — skip the
|
||||
# human-user allowlist below (bots aren't in it).
|
||||
else:
|
||||
# Non-bot: enforce the configured user/role allowlists.
|
||||
if not self._is_allowed_user(str(message.author.id), message.author):
|
||||
return
|
||||
# "all" falls through to handle_message
|
||||
|
||||
# Multi-agent filtering: if the message mentions specific bots
|
||||
# but NOT this bot, the sender is talking to another agent —
|
||||
@@ -892,10 +833,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if reply_to and self._reply_to_mode != "off":
|
||||
try:
|
||||
ref_msg = await channel.fetch_message(int(reply_to))
|
||||
if hasattr(ref_msg, "to_reference"):
|
||||
reference = ref_msg.to_reference(fail_if_not_exists=False)
|
||||
else:
|
||||
reference = ref_msg
|
||||
reference = ref_msg
|
||||
except Exception as e:
|
||||
logger.debug("Could not fetch reply-to message: %s", e)
|
||||
|
||||
@@ -913,20 +851,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
err_text = str(e)
|
||||
if (
|
||||
chunk_reference is not None
|
||||
and (
|
||||
(
|
||||
"error code: 50035" in err_text
|
||||
and "Cannot reply to a system message" in err_text
|
||||
)
|
||||
or "error code: 10008" in err_text
|
||||
)
|
||||
and "error code: 50035" in err_text
|
||||
and "Cannot reply to a system message" in err_text
|
||||
):
|
||||
logger.warning(
|
||||
"[%s] Reply target %s rejected the reply reference; retrying send without reply reference",
|
||||
"[%s] Reply target %s is a Discord system message; retrying send without reply reference",
|
||||
self.name,
|
||||
reply_to,
|
||||
)
|
||||
reference = None
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
reference=None,
|
||||
@@ -1378,48 +1310,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _is_allowed_user(self, user_id: str, author=None) -> bool:
|
||||
"""Check if user is allowed via DISCORD_ALLOWED_USERS or DISCORD_ALLOWED_ROLES.
|
||||
|
||||
Uses OR semantics: if the user matches EITHER allowlist, they're allowed.
|
||||
If both allowlists are empty, everyone is allowed (backwards compatible).
|
||||
When author is a Member, checks .roles directly; otherwise falls back
|
||||
to scanning the bot's mutual guilds for a Member record.
|
||||
"""
|
||||
# ``getattr`` fallbacks here guard against test fixtures that build
|
||||
# an adapter via ``object.__new__(DiscordAdapter)`` and skip __init__
|
||||
# (see AGENTS.md pitfall #17 — same pattern as gateway.run).
|
||||
allowed_users = getattr(self, "_allowed_user_ids", set())
|
||||
allowed_roles = getattr(self, "_allowed_role_ids", set())
|
||||
has_users = bool(allowed_users)
|
||||
has_roles = bool(allowed_roles)
|
||||
if not has_users and not has_roles:
|
||||
def _is_allowed_user(self, user_id: str) -> bool:
|
||||
"""Check if user is in DISCORD_ALLOWED_USERS."""
|
||||
if not self._allowed_user_ids:
|
||||
return True
|
||||
# Check user ID allowlist
|
||||
if has_users and user_id in allowed_users:
|
||||
return True
|
||||
# Check role allowlist
|
||||
if has_roles:
|
||||
# Try direct role check from Member object
|
||||
direct_roles = getattr(author, "roles", None) if author is not None else None
|
||||
if direct_roles:
|
||||
if any(getattr(r, "id", None) in allowed_roles for r in direct_roles):
|
||||
return True
|
||||
# Fallback: scan mutual guilds for member's roles
|
||||
if self._client is not None:
|
||||
try:
|
||||
uid_int = int(user_id)
|
||||
except (TypeError, ValueError):
|
||||
uid_int = None
|
||||
if uid_int is not None:
|
||||
for guild in self._client.guilds:
|
||||
m = guild.get_member(uid_int)
|
||||
if m is None:
|
||||
continue
|
||||
m_roles = getattr(m, "roles", None) or []
|
||||
if any(getattr(r, "id", None) in allowed_roles for r in m_roles):
|
||||
return True
|
||||
return False
|
||||
return user_id in self._allowed_user_ids
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
@@ -2009,23 +1904,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._register_skill_group(tree)
|
||||
|
||||
def _register_skill_group(self, tree) -> None:
|
||||
"""Register a single ``/skill`` command with autocomplete on the name.
|
||||
"""Register a ``/skill`` command group with category subcommand groups.
|
||||
|
||||
Discord enforces an ~8000-byte per-command payload limit. The older
|
||||
nested layout (``/skill <category> <name>``) registered one giant
|
||||
command whose serialized payload grew linearly with the skill
|
||||
catalog — with the default ~75 skills the payload was ~14 KB and
|
||||
``tree.sync()`` rejected the entire slash-command batch (issues
|
||||
#11321, #10259, #11385, #10261, #10214).
|
||||
|
||||
Autocomplete options are fetched dynamically by Discord when the
|
||||
user types — they do NOT count against the per-command registration
|
||||
budget. So we register ONE flat ``/skill`` command with
|
||||
``name: str`` (autocompleted) and ``args: str = ""``. This scales
|
||||
to thousands of skills with no size math, no splitting, and no
|
||||
hidden skills. The slash picker also becomes more discoverable —
|
||||
Discord live-filters by the user's typed prefix against both the
|
||||
skill name and its description.
|
||||
Skills are organized by their directory category under ``SKILLS_DIR``.
|
||||
Each category becomes a subcommand group; root-level skills become
|
||||
direct subcommands. Discord supports 25 subcommand groups × 25
|
||||
subcommands each = 625 skills — well beyond the old 100-command cap.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.commands import discord_skill_commands_by_category
|
||||
@@ -2036,97 +1920,68 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reuse the existing collector for consistent filtering
|
||||
# (per-platform disabled, hub-excluded, name clamping), then
|
||||
# flatten — the category grouping was only useful for the
|
||||
# nested layout.
|
||||
categories, uncategorized, hidden = discord_skill_commands_by_category(
|
||||
reserved_names=existing_names,
|
||||
)
|
||||
entries: list[tuple[str, str, str]] = list(uncategorized)
|
||||
for cat_skills in categories.values():
|
||||
entries.extend(cat_skills)
|
||||
|
||||
if not entries:
|
||||
if not categories and not uncategorized:
|
||||
return
|
||||
|
||||
# Stable alphabetical order so the autocomplete suggestion
|
||||
# list is predictable across restarts.
|
||||
entries.sort(key=lambda t: t[0])
|
||||
|
||||
# name -> (description, cmd_key) — used by both the autocomplete
|
||||
# callback and the handler for O(1) dispatch.
|
||||
skill_lookup: dict[str, tuple[str, str]] = {
|
||||
n: (d, k) for n, d, k in entries
|
||||
}
|
||||
|
||||
async def _autocomplete_name(
|
||||
interaction: "discord.Interaction", current: str,
|
||||
) -> list:
|
||||
"""Filter skills by the user's typed prefix.
|
||||
|
||||
Matches both the skill name and its description so
|
||||
"/skill pdf" surfaces skills whose description mentions
|
||||
PDFs even if the name doesn't. Discord caps this list at
|
||||
25 entries per query.
|
||||
"""
|
||||
q = (current or "").strip().lower()
|
||||
choices: list = []
|
||||
for name, desc, _key in entries:
|
||||
if not q or q in name.lower() or (desc and q in desc.lower()):
|
||||
if desc:
|
||||
label = f"{name} — {desc}"
|
||||
else:
|
||||
label = name
|
||||
# Discord's Choice.name is capped at 100 chars.
|
||||
if len(label) > 100:
|
||||
label = label[:97] + "..."
|
||||
choices.append(
|
||||
discord.app_commands.Choice(name=label, value=name)
|
||||
)
|
||||
if len(choices) >= 25:
|
||||
break
|
||||
return choices
|
||||
|
||||
@discord.app_commands.describe(
|
||||
name="Which skill to run",
|
||||
args="Optional arguments for the skill",
|
||||
)
|
||||
@discord.app_commands.autocomplete(name=_autocomplete_name)
|
||||
async def _skill_handler(
|
||||
interaction: "discord.Interaction", name: str, args: str = "",
|
||||
):
|
||||
entry = skill_lookup.get(name)
|
||||
if not entry:
|
||||
await interaction.response.send_message(
|
||||
f"Unknown skill: `{name}`. Start typing for "
|
||||
f"autocomplete suggestions.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
_desc, cmd_key = entry
|
||||
await self._run_simple_slash(
|
||||
interaction, f"{cmd_key} {args}".strip()
|
||||
)
|
||||
|
||||
cmd = discord.app_commands.Command(
|
||||
skill_group = discord.app_commands.Group(
|
||||
name="skill",
|
||||
description="Run a Hermes skill",
|
||||
callback=_skill_handler,
|
||||
)
|
||||
tree.add_command(cmd)
|
||||
|
||||
# ── Helper: build a callback for a skill command key ──
|
||||
def _make_handler(_key: str):
|
||||
@discord.app_commands.describe(args="Optional arguments for the skill")
|
||||
async def _handler(interaction: discord.Interaction, args: str = ""):
|
||||
await self._run_simple_slash(interaction, f"{_key} {args}".strip())
|
||||
_handler.__name__ = f"skill_{_key.lstrip('/').replace('-', '_')}"
|
||||
return _handler
|
||||
|
||||
# ── Uncategorized (root-level) skills → direct subcommands ──
|
||||
for discord_name, description, cmd_key in uncategorized:
|
||||
cmd = discord.app_commands.Command(
|
||||
name=discord_name,
|
||||
description=description or f"Run the {discord_name} skill",
|
||||
callback=_make_handler(cmd_key),
|
||||
)
|
||||
skill_group.add_command(cmd)
|
||||
|
||||
# ── Category subcommand groups ──
|
||||
for cat_name in sorted(categories):
|
||||
cat_desc = f"{cat_name.replace('-', ' ').title()} skills"
|
||||
if len(cat_desc) > 100:
|
||||
cat_desc = cat_desc[:97] + "..."
|
||||
cat_group = discord.app_commands.Group(
|
||||
name=cat_name,
|
||||
description=cat_desc,
|
||||
parent=skill_group,
|
||||
)
|
||||
for discord_name, description, cmd_key in categories[cat_name]:
|
||||
cmd = discord.app_commands.Command(
|
||||
name=discord_name,
|
||||
description=description or f"Run the {discord_name} skill",
|
||||
callback=_make_handler(cmd_key),
|
||||
)
|
||||
cat_group.add_command(cmd)
|
||||
|
||||
tree.add_command(skill_group)
|
||||
|
||||
total = sum(len(v) for v in categories.values()) + len(uncategorized)
|
||||
logger.info(
|
||||
"[%s] Registered /skill command with %d skill(s) via autocomplete",
|
||||
self.name, len(entries),
|
||||
"[%s] Registered /skill group: %d skill(s) across %d categories"
|
||||
" + %d uncategorized",
|
||||
self.name, total, len(categories), len(uncategorized),
|
||||
)
|
||||
if hidden:
|
||||
logger.info(
|
||||
"[%s] %d skill(s) filtered out of /skill (name clamp / reserved)",
|
||||
logger.warning(
|
||||
"[%s] %d skill(s) not registered (Discord subcommand limits)",
|
||||
self.name, hidden,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] Failed to register /skill command: %s", self.name, exc)
|
||||
logger.warning("[%s] Failed to register /skill group: %s", self.name, exc)
|
||||
|
||||
def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Discord slash command interaction."""
|
||||
@@ -2285,26 +2140,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
from gateway.platforms.base import resolve_channel_prompt
|
||||
return resolve_channel_prompt(self.config.extra, channel_id, parent_id)
|
||||
|
||||
def _discord_require_mention(self) -> bool:
|
||||
"""Return whether Discord channel messages require a bot mention."""
|
||||
configured = self.config.extra.get("require_mention")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() not in ("false", "0", "no", "off")
|
||||
return bool(configured)
|
||||
return os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no", "off")
|
||||
|
||||
def _discord_free_response_channels(self) -> set:
|
||||
"""Return Discord channel IDs where no bot mention is required."""
|
||||
raw = self.config.extra.get("free_response_channels")
|
||||
if raw is None:
|
||||
raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
if isinstance(raw, list):
|
||||
return {str(part).strip() for part in raw if str(part).strip()}
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
return {part.strip() for part in raw.split(",") if part.strip()}
|
||||
return set()
|
||||
|
||||
def _thread_parent_channel(self, channel: Any) -> Any:
|
||||
"""Return the parent text channel when invoked from a thread."""
|
||||
return getattr(channel, "parent", None) or channel
|
||||
@@ -2407,15 +2242,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
Returns the created thread object, or ``None`` on failure.
|
||||
"""
|
||||
# Build a short thread name from the message. Strip Discord mention
|
||||
# syntax (users / roles / channels) so thread titles don't end up
|
||||
# showing raw <@id>, <@&id>, or <#id> markers — the ID isn't
|
||||
# meaningful to humans glancing at the thread list (#6336).
|
||||
# Build a short thread name from the message
|
||||
content = (message.content or "").strip()
|
||||
# <@123>, <@!123>, <@&123>, <#123> — collapse to empty; normalize spaces.
|
||||
content = re.sub(r"<@[!&]?\d+>", "", content)
|
||||
content = re.sub(r"<#\d+>", "", content)
|
||||
content = re.sub(r"\s+", " ", content).strip()
|
||||
thread_name = content[:80] if content else "Hermes"
|
||||
if len(content) > 80:
|
||||
thread_name = thread_name[:77] + "..."
|
||||
@@ -2423,25 +2251,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
thread = await message.create_thread(name=thread_name, auto_archive_duration=1440)
|
||||
return thread
|
||||
except Exception as direct_error:
|
||||
display_name = getattr(getattr(message, "author", None), "display_name", None) or "unknown user"
|
||||
reason = f"Auto-threaded from mention by {display_name}"
|
||||
try:
|
||||
seed_msg = await message.channel.send(f"\U0001f9f5 Thread created by Hermes: **{thread_name}**")
|
||||
thread = await seed_msg.create_thread(
|
||||
name=thread_name,
|
||||
auto_archive_duration=1440,
|
||||
reason=reason,
|
||||
)
|
||||
return thread
|
||||
except Exception as fallback_error:
|
||||
logger.warning(
|
||||
"[%s] Auto-thread creation failed. Direct error: %s. Fallback error: %s",
|
||||
self.name,
|
||||
direct_error,
|
||||
fallback_error,
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Auto-thread creation failed: %s", self.name, e)
|
||||
return None
|
||||
|
||||
async def send_exec_approval(
|
||||
self, chat_id: str, command: str, session_key: str,
|
||||
@@ -2628,124 +2440,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return f"{parent_name} / {thread_name}"
|
||||
return thread_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Attachment download helpers
|
||||
#
|
||||
# Discord attachments (images / audio / documents) are fetched via the
|
||||
# authenticated bot session whenever the Attachment object exposes
|
||||
# ``read()``. That sidesteps two classes of bug that hit the older
|
||||
# plain-HTTP path:
|
||||
#
|
||||
# 1. ``cdn.discordapp.com`` URLs increasingly require bot auth on
|
||||
# download — unauthenticated httpx sees 403 Forbidden.
|
||||
# (issue #8242)
|
||||
# 2. Some user environments (VPNs, corporate DNS, tunnels) resolve
|
||||
# ``cdn.discordapp.com`` to private-looking IPs that our
|
||||
# ``is_safe_url`` guard classifies as SSRF risks. Routing the
|
||||
# fetch through discord.py's own HTTP client handles DNS
|
||||
# internally so our guard isn't consulted for the attachment
|
||||
# path. (issue #6587)
|
||||
#
|
||||
# If ``att.read()`` is unavailable (unexpected object shape / test
|
||||
# stub) or the bot session fetch fails, we fall back to the existing
|
||||
# SSRF-gated URL downloaders. The fallback keeps defense-in-depth
|
||||
# against any future Discord payload-schema drift that could slip a
|
||||
# non-CDN URL into the ``att.url`` field. (issue #11345)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _read_attachment_bytes(self, att) -> Optional[bytes]:
|
||||
"""Read an attachment via discord.py's authenticated bot session.
|
||||
|
||||
Returns the raw bytes on success, or ``None`` if ``att`` doesn't
|
||||
expose a callable ``read()`` or the read itself fails. Callers
|
||||
should treat ``None`` as a signal to fall back to the URL-based
|
||||
downloaders.
|
||||
"""
|
||||
reader = getattr(att, "read", None)
|
||||
if reader is None or not callable(reader):
|
||||
return None
|
||||
try:
|
||||
return await reader()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[Discord] Authenticated attachment read failed for %s: %s",
|
||||
getattr(att, "filename", None) or getattr(att, "url", "<unknown>"),
|
||||
e,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _cache_discord_image(self, att, ext: str) -> str:
|
||||
"""Cache a Discord image attachment to local disk.
|
||||
|
||||
Primary path: ``att.read()`` + ``cache_image_from_bytes``
|
||||
(authenticated, no SSRF gate).
|
||||
|
||||
Fallback: ``cache_image_from_url`` (plain httpx, SSRF-gated).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
if raw_bytes is not None:
|
||||
try:
|
||||
return cache_image_from_bytes(raw_bytes, ext=ext)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"[Discord] cache_image_from_bytes rejected att.read() data; falling back to URL: %s",
|
||||
e,
|
||||
)
|
||||
return await cache_image_from_url(att.url, ext=ext)
|
||||
|
||||
async def _cache_discord_audio(self, att, ext: str) -> str:
|
||||
"""Cache a Discord audio attachment to local disk.
|
||||
|
||||
Primary path: ``att.read()`` + ``cache_audio_from_bytes``
|
||||
(authenticated, no SSRF gate).
|
||||
|
||||
Fallback: ``cache_audio_from_url`` (plain httpx, SSRF-gated).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
if raw_bytes is not None:
|
||||
try:
|
||||
return cache_audio_from_bytes(raw_bytes, ext=ext)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"[Discord] cache_audio_from_bytes failed; falling back to URL: %s",
|
||||
e,
|
||||
)
|
||||
return await cache_audio_from_url(att.url, ext=ext)
|
||||
|
||||
async def _cache_discord_document(self, att, ext: str) -> bytes:
|
||||
"""Download a Discord document attachment and return the raw bytes.
|
||||
|
||||
Primary path: ``att.read()`` (authenticated, no SSRF gate).
|
||||
|
||||
Fallback: SSRF-gated ``aiohttp`` download. This closes the gap
|
||||
where the old document path made raw ``aiohttp.ClientSession``
|
||||
requests with no safety check (#11345). The caller is responsible
|
||||
for passing the returned bytes to ``cache_document_from_bytes``
|
||||
(and, where applicable, for injecting text content).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
if raw_bytes is not None:
|
||||
return raw_bytes
|
||||
|
||||
# Fallback: SSRF-gated URL download.
|
||||
if not is_safe_url(att.url):
|
||||
raise ValueError(
|
||||
f"Blocked unsafe attachment URL (SSRF protection): {att.url}"
|
||||
)
|
||||
import aiohttp
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
**_req_kw,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
return await resp.read()
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
@@ -2788,11 +2482,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_ids)
|
||||
return
|
||||
|
||||
free_channels = self._discord_free_response_channels()
|
||||
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
if parent_channel_id:
|
||||
channel_ids.add(parent_channel_id)
|
||||
|
||||
require_mention = self._discord_require_mention()
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
# Voice-linked text channels act as free-response while voice is active.
|
||||
# Only the exact bound channel gets the exemption, not sibling threads.
|
||||
voice_linked_ids = {str(ch_id) for ch_id in self._voice_text_channels.values()}
|
||||
@@ -2820,10 +2515,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not is_thread and not isinstance(message.channel, discord.DMChannel):
|
||||
no_thread_channels_raw = os.getenv("DISCORD_NO_THREAD_CHANNELS", "")
|
||||
no_thread_channels = {ch.strip() for ch in no_thread_channels_raw.split(",") if ch.strip()}
|
||||
skip_thread = bool(channel_ids & no_thread_channels) or is_free_channel
|
||||
skip_thread = bool(channel_ids & no_thread_channels)
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in ("true", "1", "yes")
|
||||
is_reply_message = getattr(message, "type", None) == discord.MessageType.reply
|
||||
if auto_thread and not skip_thread and not is_voice_linked_channel and not is_reply_message:
|
||||
if auto_thread and not skip_thread and not is_voice_linked_channel:
|
||||
thread = await self._auto_create_thread(message)
|
||||
if thread:
|
||||
is_thread = True
|
||||
@@ -2884,7 +2578,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
user_name=message.author.display_name,
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
is_bot=getattr(message.author, "bot", False),
|
||||
)
|
||||
|
||||
# Build media URLs -- download image attachments to local cache so the
|
||||
@@ -2900,7 +2593,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||||
if ext not in (".jpg", ".jpeg", ".png", ".gif", ".webp"):
|
||||
ext = ".jpg"
|
||||
cached_path = await self._cache_discord_image(att, ext)
|
||||
cached_path = await cache_image_from_url(att.url, ext=ext)
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(content_type)
|
||||
print(f"[Discord] Cached user image: {cached_path}", flush=True)
|
||||
@@ -2914,7 +2607,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||||
if ext not in (".ogg", ".mp3", ".wav", ".webm", ".m4a"):
|
||||
ext = ".ogg"
|
||||
cached_path = await self._cache_discord_audio(att, ext)
|
||||
cached_path = await cache_audio_from_url(att.url, ext=ext)
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(content_type)
|
||||
print(f"[Discord] Cached user audio: {cached_path}", flush=True)
|
||||
@@ -2945,7 +2638,19 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
)
|
||||
else:
|
||||
try:
|
||||
raw_bytes = await self._cache_discord_document(att, ext)
|
||||
import aiohttp
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
**_req_kw,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
raw_bytes = await resp.read()
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, att.filename or f"document{ext}"
|
||||
)
|
||||
|
||||
@@ -64,7 +64,6 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
_ssrf_redirect_guard,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_bytes,
|
||||
)
|
||||
@@ -227,11 +226,7 @@ class QQAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
try:
|
||||
self._http_client = httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
)
|
||||
self._http_client = httpx.AsyncClient(timeout=30.0, follow_redirects=True)
|
||||
|
||||
# 1. Get access token
|
||||
await self._ensure_token()
|
||||
@@ -1106,11 +1101,6 @@ class QQAdapter(BasePlatformAdapter):
|
||||
is_pre_wav = True
|
||||
logger.info("[QQ] STT: using voice_wav_url (pre-converted WAV)")
|
||||
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(download_url):
|
||||
logger.warning("[QQ] STT blocked unsafe URL: %s", download_url[:80])
|
||||
return None
|
||||
|
||||
try:
|
||||
# 2. Download audio (QQ CDN requires Authorization header)
|
||||
if not self._http_client:
|
||||
@@ -1535,33 +1525,6 @@ class QQAdapter(BasePlatformAdapter):
|
||||
|
||||
raise last_exc # type: ignore[misc]
|
||||
|
||||
# Maximum time (seconds) to wait for reconnection before giving up on send.
|
||||
_RECONNECT_WAIT_SECONDS = 15.0
|
||||
# How often (seconds) to poll is_connected while waiting.
|
||||
_RECONNECT_POLL_INTERVAL = 0.5
|
||||
|
||||
async def _wait_for_reconnection(self) -> bool:
|
||||
"""Wait for the WebSocket listener to reconnect.
|
||||
|
||||
The listener loop (_listen_loop) auto-reconnects on disconnect, but
|
||||
there is a race window where send() is called right after a disconnect
|
||||
and before the reconnect completes. This method polls is_connected
|
||||
for up to _RECONNECT_WAIT_SECONDS.
|
||||
|
||||
Returns True if reconnected, False if still disconnected.
|
||||
"""
|
||||
logger.info("[%s] Not connected — waiting for reconnection (up to %.0fs)",
|
||||
self.name, self._RECONNECT_WAIT_SECONDS)
|
||||
waited = 0.0
|
||||
while waited < self._RECONNECT_WAIT_SECONDS:
|
||||
await asyncio.sleep(self._RECONNECT_POLL_INTERVAL)
|
||||
waited += self._RECONNECT_POLL_INTERVAL
|
||||
if self.is_connected:
|
||||
logger.info("[%s] Reconnected after %.1fs", self.name, waited)
|
||||
return True
|
||||
logger.warning("[%s] Still not connected after %.0fs", self.name, self._RECONNECT_WAIT_SECONDS)
|
||||
return False
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -1577,8 +1540,7 @@ class QQAdapter(BasePlatformAdapter):
|
||||
del metadata
|
||||
|
||||
if not self.is_connected:
|
||||
if not await self._wait_for_reconnection():
|
||||
return SendResult(success=False, error="Not connected", retryable=True)
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not content or not content.strip():
|
||||
return SendResult(success=True)
|
||||
@@ -1779,8 +1741,7 @@ class QQAdapter(BasePlatformAdapter):
|
||||
) -> SendResult:
|
||||
"""Upload media and send as a native message."""
|
||||
if not self.is_connected:
|
||||
if not await self._wait_for_reconnection():
|
||||
return SendResult(success=False, error="Not connected", retryable=True)
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Resolve media source
|
||||
|
||||
+86
-310
@@ -28,7 +28,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import quote, urlparse
|
||||
from urllib.parse import quote
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,28 +96,6 @@ MEDIA_VIDEO = 2
|
||||
MEDIA_FILE = 3
|
||||
MEDIA_VOICE = 4
|
||||
|
||||
_LIVE_ADAPTERS: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _make_ssl_connector() -> Optional["aiohttp.TCPConnector"]:
|
||||
"""Return a TCPConnector with a certifi CA bundle, or None if certifi is unavailable.
|
||||
|
||||
Tencent's iLink server (``ilinkai.weixin.qq.com``) is not verifiable against
|
||||
some system CA stores (notably Homebrew's OpenSSL on macOS Apple Silicon).
|
||||
When ``certifi`` is installed, use its Mozilla CA bundle to guarantee
|
||||
verification. Otherwise fall back to aiohttp's default (which honors
|
||||
``SSL_CERT_FILE`` env var via ``trust_env=True``).
|
||||
"""
|
||||
try:
|
||||
import ssl
|
||||
import certifi
|
||||
except ImportError:
|
||||
return None
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
return None
|
||||
ssl_ctx = ssl.create_default_context(cafile=certifi.where())
|
||||
return aiohttp.TCPConnector(ssl=ssl_ctx)
|
||||
|
||||
ITEM_TEXT = 1
|
||||
ITEM_IMAGE = 2
|
||||
ITEM_VOICE = 3
|
||||
@@ -420,12 +398,7 @@ async def _send_message(
|
||||
text: str,
|
||||
context_token: Optional[str],
|
||||
client_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a text message via iLink sendmessage API.
|
||||
|
||||
Returns the raw API response dict (may contain error codes like
|
||||
``errcode: -14`` for session expiry that the caller can inspect).
|
||||
"""
|
||||
) -> None:
|
||||
if not text or not text.strip():
|
||||
raise ValueError("_send_message: text must not be empty")
|
||||
message: Dict[str, Any] = {
|
||||
@@ -438,7 +411,7 @@ async def _send_message(
|
||||
}
|
||||
if context_token:
|
||||
message["context_token"] = context_token
|
||||
return await _api_post(
|
||||
await _api_post(
|
||||
session,
|
||||
base_url=base_url,
|
||||
endpoint=EP_SEND_MESSAGE,
|
||||
@@ -560,39 +533,6 @@ async def _download_bytes(
|
||||
return await response.read()
|
||||
|
||||
|
||||
_WEIXIN_CDN_ALLOWLIST: frozenset[str] = frozenset(
|
||||
{
|
||||
"novac2c.cdn.weixin.qq.com",
|
||||
"ilinkai.weixin.qq.com",
|
||||
"wx.qlogo.cn",
|
||||
"thirdwx.qlogo.cn",
|
||||
"res.wx.qq.com",
|
||||
"mmbiz.qpic.cn",
|
||||
"mmbiz.qlogo.cn",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _assert_weixin_cdn_url(url: str) -> None:
|
||||
"""Raise ValueError if *url* does not point at a known WeChat CDN host."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
scheme = parsed.scheme.lower()
|
||||
host = parsed.hostname or ""
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise ValueError(f"Unparseable media URL: {url!r}") from exc
|
||||
|
||||
if scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"Media URL has disallowed scheme {scheme!r}; only http/https are permitted."
|
||||
)
|
||||
if host not in _WEIXIN_CDN_ALLOWLIST:
|
||||
raise ValueError(
|
||||
f"Media URL host {host!r} is not in the WeChat CDN allowlist. "
|
||||
"Refusing to fetch to prevent SSRF."
|
||||
)
|
||||
|
||||
|
||||
def _media_reference(item: Dict[str, Any], key: str) -> Dict[str, Any]:
|
||||
return (item.get(key) or {}).get("media") or {}
|
||||
|
||||
@@ -613,7 +553,6 @@ async def _download_and_decrypt_media(
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
elif full_url:
|
||||
_assert_weixin_cdn_url(full_url)
|
||||
raw = await _download_bytes(session, url=full_url, timeout_seconds=timeout_seconds)
|
||||
else:
|
||||
raise RuntimeError("media item had neither encrypt_query_param nor full_url")
|
||||
@@ -684,31 +623,42 @@ def _rewrite_table_block_for_weixin(lines: List[str]) -> str:
|
||||
def _normalize_markdown_blocks(content: str) -> str:
|
||||
lines = content.splitlines()
|
||||
result: List[str] = []
|
||||
i = 0
|
||||
in_code_block = False
|
||||
blank_run = 0
|
||||
|
||||
for raw_line in lines:
|
||||
line = raw_line.rstrip()
|
||||
if _FENCE_RE.match(line.strip()):
|
||||
while i < len(lines):
|
||||
line = lines[i].rstrip()
|
||||
fence_match = _FENCE_RE.match(line.strip())
|
||||
if fence_match:
|
||||
in_code_block = not in_code_block
|
||||
result.append(line)
|
||||
blank_run = 0
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_code_block:
|
||||
result.append(line)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if not line.strip():
|
||||
blank_run += 1
|
||||
if blank_run <= 1:
|
||||
result.append("")
|
||||
if (
|
||||
i + 1 < len(lines)
|
||||
and "|" in lines[i]
|
||||
and _TABLE_RULE_RE.match(lines[i + 1].rstrip())
|
||||
):
|
||||
table_lines = [lines[i].rstrip(), lines[i + 1].rstrip()]
|
||||
i += 2
|
||||
while i < len(lines) and "|" in lines[i]:
|
||||
table_lines.append(lines[i].rstrip())
|
||||
i += 1
|
||||
result.append(_rewrite_table_block_for_weixin(table_lines))
|
||||
continue
|
||||
|
||||
blank_run = 0
|
||||
result.append(line)
|
||||
result.append(_MARKDOWN_LINK_RE.sub(r"\1 (\2)", _rewrite_headers_for_weixin(line)))
|
||||
i += 1
|
||||
|
||||
return "\n".join(result).strip()
|
||||
normalized = "\n".join(item.rstrip() for item in result)
|
||||
normalized = re.sub(r"\n{3,}", "\n\n", normalized)
|
||||
return normalized.strip()
|
||||
|
||||
|
||||
def _split_markdown_blocks(content: str) -> List[str]:
|
||||
@@ -754,8 +704,8 @@ def _split_delivery_units_for_weixin(content: str) -> List[str]:
|
||||
|
||||
Weixin can render Markdown, but chat readability is better when top-level
|
||||
line breaks become separate messages. Keep fenced code blocks intact and
|
||||
attach indented continuation lines to the previous top-level line so nested
|
||||
list items do not get torn apart.
|
||||
attach indented continuation lines to the previous top-level line so
|
||||
transformed tables/lists do not get torn apart.
|
||||
"""
|
||||
units: List[str] = []
|
||||
|
||||
@@ -797,9 +747,7 @@ def _looks_like_chatty_line_for_weixin(line: str) -> bool:
|
||||
return False
|
||||
if line.startswith((" ", "\t")):
|
||||
return False
|
||||
if stripped.startswith((">", "-", "*", "【", "#", "|")):
|
||||
return False
|
||||
if _TABLE_RULE_RE.match(stripped):
|
||||
if stripped.startswith((">", "-", "*", "【")):
|
||||
return False
|
||||
if re.match(r"^\*\*[^*]+\*\*$", stripped):
|
||||
return False
|
||||
@@ -809,12 +757,10 @@ def _looks_like_chatty_line_for_weixin(line: str) -> bool:
|
||||
|
||||
|
||||
def _looks_like_heading_line_for_weixin(line: str) -> bool:
|
||||
"""Return True when a short line behaves like a heading."""
|
||||
"""Return True when a short line behaves like a plain-text heading."""
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
return False
|
||||
if _HEADER_RE.match(stripped):
|
||||
return True
|
||||
return len(stripped) <= 24 and stripped.endswith((":", ":"))
|
||||
|
||||
|
||||
@@ -989,7 +935,7 @@ async def qr_login(
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
raise RuntimeError("aiohttp is required for Weixin QR login")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True, connector=_make_ssl_connector()) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
try:
|
||||
qr_resp = await _api_get(
|
||||
session,
|
||||
@@ -1007,10 +953,6 @@ async def qr_login(
|
||||
logger.error("weixin: QR response missing qrcode")
|
||||
return None
|
||||
|
||||
# qrcode_url is the full scannable liteapp URL; qrcode_value is just the hex token
|
||||
# WeChat needs to scan the full URL, not the raw hex string
|
||||
qr_scan_data = qrcode_url if qrcode_url else qrcode_value
|
||||
|
||||
print("\n请使用微信扫描以下二维码:")
|
||||
if qrcode_url:
|
||||
print(qrcode_url)
|
||||
@@ -1018,11 +960,11 @@ async def qr_login(
|
||||
import qrcode
|
||||
|
||||
qr = qrcode.QRCode()
|
||||
qr.add_data(qr_scan_data)
|
||||
qr.add_data(qrcode_url or qrcode_value)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
except Exception as _qr_exc:
|
||||
print(f"(终端二维码渲染失败: {_qr_exc},请直接打开上面的二维码链接)")
|
||||
except Exception:
|
||||
print("(终端二维码渲染失败,请直接打开上面的二维码链接)")
|
||||
|
||||
deadline = time.time() + timeout_seconds
|
||||
current_base_url = ILINK_BASE_URL
|
||||
@@ -1068,17 +1010,8 @@ async def qr_login(
|
||||
)
|
||||
qrcode_value = str(qr_resp.get("qrcode") or "")
|
||||
qrcode_url = str(qr_resp.get("qrcode_img_content") or "")
|
||||
qr_scan_data = qrcode_url if qrcode_url else qrcode_value
|
||||
if qrcode_url:
|
||||
print(qrcode_url)
|
||||
try:
|
||||
import qrcode as _qrcode
|
||||
qr = _qrcode.QRCode()
|
||||
qr.add_data(qr_scan_data)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.error("weixin: QR refresh failed: %s", exc)
|
||||
return None
|
||||
@@ -1126,8 +1059,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
self._hermes_home = hermes_home
|
||||
self._token_store = ContextTokenStore(hermes_home)
|
||||
self._typing_cache = TypingTicketCache()
|
||||
self._poll_session: Optional[aiohttp.ClientSession] = None
|
||||
self._send_session: Optional[aiohttp.ClientSession] = None
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS)
|
||||
|
||||
@@ -1202,17 +1134,14 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
except Exception as exc:
|
||||
logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc)
|
||||
|
||||
self._poll_session = aiohttp.ClientSession(trust_env=True, connector=_make_ssl_connector())
|
||||
self._send_session = aiohttp.ClientSession(trust_env=True, connector=_make_ssl_connector())
|
||||
self._session = aiohttp.ClientSession(trust_env=True)
|
||||
self._token_store.restore(self._account_id)
|
||||
self._poll_task = asyncio.create_task(self._poll_loop(), name="weixin-poll")
|
||||
self._mark_connected()
|
||||
_LIVE_ADAPTERS[self._token] = self
|
||||
logger.info("[%s] Connected account=%s base=%s", self.name, _safe_id(self._account_id), self._base_url)
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
_LIVE_ADAPTERS.pop(self._token, None)
|
||||
self._running = False
|
||||
if self._poll_task and not self._poll_task.done():
|
||||
self._poll_task.cancel()
|
||||
@@ -1221,18 +1150,15 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._poll_task = None
|
||||
if self._poll_session and not self._poll_session.closed:
|
||||
await self._poll_session.close()
|
||||
self._poll_session = None
|
||||
if self._send_session and not self._send_session.closed:
|
||||
await self._send_session.close()
|
||||
self._send_session = None
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self._release_platform_lock()
|
||||
self._mark_disconnected()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
assert self._poll_session is not None
|
||||
assert self._session is not None
|
||||
sync_buf = _load_sync_buf(self._hermes_home, self._account_id)
|
||||
timeout_ms = LONG_POLL_TIMEOUT_MS
|
||||
consecutive_failures = 0
|
||||
@@ -1240,7 +1166,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
while self._running:
|
||||
try:
|
||||
response = await _get_updates(
|
||||
self._poll_session,
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
sync_buf=sync_buf,
|
||||
@@ -1297,7 +1223,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
logger.error("[%s] unhandled inbound error from=%s: %s", self.name, _safe_id(message.get("from_user_id")), exc, exc_info=True)
|
||||
|
||||
async def _process_message(self, message: Dict[str, Any]) -> None:
|
||||
assert self._poll_session is not None
|
||||
assert self._session is not None
|
||||
sender_id = str(message.get("from_user_id") or "").strip()
|
||||
if not sender_id:
|
||||
return
|
||||
@@ -1390,7 +1316,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
media = _media_reference(item, "image_item")
|
||||
try:
|
||||
data = await _download_and_decrypt_media(
|
||||
self._poll_session,
|
||||
self._session,
|
||||
cdn_base_url=self._cdn_base_url,
|
||||
encrypted_query_param=media.get("encrypt_query_param"),
|
||||
aes_key_b64=(item.get("image_item") or {}).get("aeskey")
|
||||
@@ -1408,7 +1334,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
media = _media_reference(item, "video_item")
|
||||
try:
|
||||
data = await _download_and_decrypt_media(
|
||||
self._poll_session,
|
||||
self._session,
|
||||
cdn_base_url=self._cdn_base_url,
|
||||
encrypted_query_param=media.get("encrypt_query_param"),
|
||||
aes_key_b64=media.get("aes_key"),
|
||||
@@ -1427,7 +1353,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
mime = _mime_from_filename(filename)
|
||||
try:
|
||||
data = await _download_and_decrypt_media(
|
||||
self._poll_session,
|
||||
self._session,
|
||||
cdn_base_url=self._cdn_base_url,
|
||||
encrypted_query_param=media.get("encrypt_query_param"),
|
||||
aes_key_b64=media.get("aes_key"),
|
||||
@@ -1446,7 +1372,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
return None
|
||||
try:
|
||||
data = await _download_and_decrypt_media(
|
||||
self._poll_session,
|
||||
self._session,
|
||||
cdn_base_url=self._cdn_base_url,
|
||||
encrypted_query_param=media.get("encrypt_query_param"),
|
||||
aes_key_b64=media.get("aes_key"),
|
||||
@@ -1459,13 +1385,13 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
return None
|
||||
|
||||
async def _maybe_fetch_typing_ticket(self, user_id: str, context_token: Optional[str]) -> None:
|
||||
if not self._poll_session or not self._token:
|
||||
if not self._session or not self._token:
|
||||
return
|
||||
if self._typing_cache.get(user_id):
|
||||
return
|
||||
try:
|
||||
response = await _get_config(
|
||||
self._poll_session,
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
user_id=user_id,
|
||||
@@ -1490,19 +1416,12 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
context_token: Optional[str],
|
||||
client_id: str,
|
||||
) -> None:
|
||||
"""Send a single text chunk with per-chunk retry and backoff.
|
||||
|
||||
On session-expired errors (errcode -14), automatically retries
|
||||
*without* ``context_token`` — iLink accepts tokenless sends as a
|
||||
degraded fallback, which keeps cron-initiated push messages working
|
||||
even when no user message has refreshed the session recently.
|
||||
"""
|
||||
"""Send a single text chunk with per-chunk retry and backoff."""
|
||||
last_error: Optional[Exception] = None
|
||||
retried_without_token = False
|
||||
for attempt in range(self._send_chunk_retries + 1):
|
||||
try:
|
||||
resp = await _send_message(
|
||||
self._send_session,
|
||||
await _send_message(
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
@@ -1510,31 +1429,6 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
context_token=context_token,
|
||||
client_id=client_id,
|
||||
)
|
||||
# Check iLink response for session-expired error
|
||||
if resp and isinstance(resp, dict):
|
||||
ret = resp.get("ret")
|
||||
errcode = resp.get("errcode")
|
||||
if (ret is not None and ret not in (0,)) or (errcode is not None and errcode not in (0,)):
|
||||
is_session_expired = (
|
||||
ret == SESSION_EXPIRED_ERRCODE
|
||||
or errcode == SESSION_EXPIRED_ERRCODE
|
||||
)
|
||||
# Session expired — strip token and retry once
|
||||
if is_session_expired and not retried_without_token and context_token:
|
||||
retried_without_token = True
|
||||
context_token = None
|
||||
self._token_store._cache.pop(
|
||||
self._token_store._key(self._account_id, chat_id), None
|
||||
)
|
||||
logger.warning(
|
||||
"[%s] session expired for %s; retrying without context_token",
|
||||
self.name, _safe_id(chat_id),
|
||||
)
|
||||
continue
|
||||
errmsg = resp.get("errmsg") or resp.get("msg") or "unknown error"
|
||||
raise RuntimeError(
|
||||
f"iLink sendmessage error: ret={ret} errcode={errcode} errmsg={errmsg}"
|
||||
)
|
||||
return
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
@@ -1562,48 +1456,12 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
if not self._send_session or not self._token:
|
||||
if not self._session or not self._token:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
context_token = self._token_store.get(self._account_id, chat_id)
|
||||
last_message_id: Optional[str] = None
|
||||
|
||||
# Extract MEDIA: tags and bare local file paths before text delivery.
|
||||
media_files, cleaned_content = self.extract_media(content)
|
||||
_, image_cleaned = self.extract_images(cleaned_content)
|
||||
local_files, final_content = self.extract_local_files(image_cleaned)
|
||||
|
||||
_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".3gp"}
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
|
||||
async def _deliver_media(path: str, is_voice: bool = False) -> None:
|
||||
ext = Path(path).suffix.lower()
|
||||
if is_voice or ext in _AUDIO_EXTS:
|
||||
await self.send_voice(chat_id=chat_id, audio_path=path, metadata=metadata)
|
||||
elif ext in _VIDEO_EXTS:
|
||||
await self.send_video(chat_id=chat_id, video_path=path, metadata=metadata)
|
||||
elif ext in _IMAGE_EXTS:
|
||||
await self.send_image_file(chat_id=chat_id, image_path=path, metadata=metadata)
|
||||
else:
|
||||
await self.send_document(chat_id=chat_id, file_path=path, metadata=metadata)
|
||||
|
||||
try:
|
||||
# Deliver extracted MEDIA: attachments first.
|
||||
for media_path, is_voice in media_files:
|
||||
try:
|
||||
await _deliver_media(media_path, is_voice)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] media delivery failed for %s: %s", self.name, media_path, exc)
|
||||
|
||||
# Deliver bare local file paths.
|
||||
for file_path in local_files:
|
||||
try:
|
||||
await _deliver_media(file_path, is_voice=False)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] local file delivery failed for %s: %s", self.name, file_path, exc)
|
||||
|
||||
# Deliver text content.
|
||||
chunks = [c for c in self._split_text(self.format_message(final_content)) if c and c.strip()]
|
||||
chunks = [c for c in self._split_text(self.format_message(content)) if c and c.strip()]
|
||||
for idx, chunk in enumerate(chunks):
|
||||
client_id = f"hermes-weixin-{uuid.uuid4().hex}"
|
||||
await self._send_text_chunk(
|
||||
@@ -1621,14 +1479,14 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
||||
if not self._send_session or not self._token:
|
||||
if not self._session or not self._token:
|
||||
return
|
||||
typing_ticket = self._typing_cache.get(chat_id)
|
||||
if not typing_ticket:
|
||||
return
|
||||
try:
|
||||
await _send_typing(
|
||||
self._send_session,
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to_user_id=chat_id,
|
||||
@@ -1639,14 +1497,14 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
logger.debug("[%s] typing start failed for %s: %s", self.name, _safe_id(chat_id), exc)
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
if not self._send_session or not self._token:
|
||||
if not self._session or not self._token:
|
||||
return
|
||||
typing_ticket = self._typing_cache.get(chat_id)
|
||||
if not typing_ticket:
|
||||
return
|
||||
try:
|
||||
await _send_typing(
|
||||
self._send_session,
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to_user_id=chat_id,
|
||||
@@ -1684,35 +1542,24 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
path: str,
|
||||
caption: str = "",
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
del reply_to, kwargs
|
||||
return await self.send_document(
|
||||
chat_id=chat_id,
|
||||
file_path=image_path,
|
||||
caption=caption,
|
||||
metadata=metadata,
|
||||
)
|
||||
return await self.send_document(chat_id, file_path=path, caption=caption, metadata=metadata)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str = "",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
del file_name, reply_to, metadata, kwargs
|
||||
if not self._send_session or not self._token:
|
||||
if not self._session or not self._token:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
message_id = await self._send_file(chat_id, file_path, caption or "")
|
||||
message_id = await self._send_file(chat_id, file_path, caption)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as exc:
|
||||
logger.error("[%s] send_document failed to=%s: %s", self.name, _safe_id(chat_id), exc)
|
||||
@@ -1726,7 +1573,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
if not self._send_session or not self._token:
|
||||
if not self._session or not self._token:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
message_id = await self._send_file(chat_id, video_path, caption or "")
|
||||
@@ -1743,24 +1590,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
if not self._send_session or not self._token:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
# Native outbound Weixin voice bubbles are not proven-working in the
|
||||
# upstream reference implementation. Prefer a reliable file attachment
|
||||
# fallback so users at least receive playable audio, even for .silk.
|
||||
fallback_caption = caption or "[voice message as attachment]"
|
||||
try:
|
||||
message_id = await self._send_file(
|
||||
chat_id,
|
||||
audio_path,
|
||||
fallback_caption,
|
||||
force_file_attachment=True,
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as exc:
|
||||
logger.error("[%s] send_voice failed to=%s: %s", self.name, _safe_id(chat_id), exc)
|
||||
return SendResult(success=False, error=str(exc))
|
||||
return await self.send_document(chat_id, audio_path, caption=caption or "", metadata=metadata)
|
||||
|
||||
async def _download_remote_media(self, url: str) -> str:
|
||||
from tools.url_safety import is_safe_url
|
||||
@@ -1768,8 +1598,8 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
if not is_safe_url(url):
|
||||
raise ValueError(f"Blocked unsafe URL (SSRF protection): {url}")
|
||||
|
||||
assert self._send_session is not None
|
||||
async with self._send_session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
|
||||
assert self._session is not None
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.read()
|
||||
suffix = Path(url.split("?", 1)[0]).suffix or ".bin"
|
||||
@@ -1777,22 +1607,16 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
handle.write(data)
|
||||
return handle.name
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
path: str,
|
||||
caption: str,
|
||||
force_file_attachment: bool = False,
|
||||
) -> str:
|
||||
assert self._send_session is not None and self._token is not None
|
||||
async def _send_file(self, chat_id: str, path: str, caption: str) -> str:
|
||||
assert self._session is not None and self._token is not None
|
||||
plaintext = Path(path).read_bytes()
|
||||
media_type, item_builder = self._outbound_media_builder(path, force_file_attachment=force_file_attachment)
|
||||
media_type, item_builder = self._outbound_media_builder(path)
|
||||
filekey = secrets.token_hex(16)
|
||||
aes_key = secrets.token_bytes(16)
|
||||
rawsize = len(plaintext)
|
||||
rawfilemd5 = hashlib.md5(plaintext).hexdigest()
|
||||
upload_response = await _get_upload_url(
|
||||
self._send_session,
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to_user_id=chat_id,
|
||||
@@ -1818,34 +1642,30 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
raise RuntimeError(f"getUploadUrl returned neither upload_param nor upload_full_url: {upload_response}")
|
||||
|
||||
encrypted_query_param = await _upload_ciphertext(
|
||||
self._send_session,
|
||||
self._session,
|
||||
ciphertext=ciphertext,
|
||||
upload_url=upload_url,
|
||||
)
|
||||
|
||||
context_token = self._token_store.get(self._account_id, chat_id)
|
||||
# The iLink API expects aes_key as base64(hex_string), not base64(raw_bytes).
|
||||
# Sending base64(raw_bytes) causes images to show as grey boxes on the
|
||||
# receiver side because the decryption key doesn't match.
|
||||
aes_key_for_api = base64.b64encode(aes_key.hex().encode("ascii")).decode("ascii")
|
||||
item_kwargs = {
|
||||
"encrypt_query_param": encrypted_query_param,
|
||||
"aes_key_for_api": aes_key_for_api,
|
||||
"ciphertext_size": len(ciphertext),
|
||||
"plaintext_size": rawsize,
|
||||
"filename": Path(path).name,
|
||||
"rawfilemd5": rawfilemd5,
|
||||
}
|
||||
if media_type == MEDIA_VOICE and path.endswith(".silk"):
|
||||
item_kwargs["encode_type"] = 6
|
||||
item_kwargs["sample_rate"] = 24000
|
||||
item_kwargs["bits_per_sample"] = 16
|
||||
media_item = item_builder(**item_kwargs)
|
||||
media_item = item_builder(
|
||||
encrypt_query_param=encrypted_query_param,
|
||||
aes_key_for_api=aes_key_for_api,
|
||||
ciphertext_size=len(ciphertext),
|
||||
plaintext_size=rawsize,
|
||||
filename=Path(path).name,
|
||||
rawfilemd5=rawfilemd5,
|
||||
)
|
||||
|
||||
last_message_id = None
|
||||
if caption:
|
||||
last_message_id = f"hermes-weixin-{uuid.uuid4().hex}"
|
||||
await _send_message(
|
||||
self._send_session,
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
@@ -1856,7 +1676,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
|
||||
last_message_id = f"hermes-weixin-{uuid.uuid4().hex}"
|
||||
await _api_post(
|
||||
self._send_session,
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
endpoint=EP_SEND_MESSAGE,
|
||||
payload={
|
||||
@@ -1875,7 +1695,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
)
|
||||
return last_message_id
|
||||
|
||||
def _outbound_media_builder(self, path: str, force_file_attachment: bool = False):
|
||||
def _outbound_media_builder(self, path: str):
|
||||
mime = mimetypes.guess_type(path)[0] or "application/octet-stream"
|
||||
if mime.startswith("image/"):
|
||||
return MEDIA_IMAGE, lambda **kw: {
|
||||
@@ -1903,7 +1723,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
"video_md5": kw.get("rawfilemd5", ""),
|
||||
},
|
||||
}
|
||||
if path.endswith(".silk") and not force_file_attachment:
|
||||
if mime.startswith("audio/") or path.endswith(".silk"):
|
||||
return MEDIA_VOICE, lambda **kw: {
|
||||
"type": ITEM_VOICE,
|
||||
"voice_item": {
|
||||
@@ -1912,25 +1732,9 @@ class WeixinAdapter(BasePlatformAdapter):
|
||||
"aes_key": kw["aes_key_for_api"],
|
||||
"encrypt_type": 1,
|
||||
},
|
||||
"encode_type": kw.get("encode_type"),
|
||||
"bits_per_sample": kw.get("bits_per_sample"),
|
||||
"sample_rate": kw.get("sample_rate"),
|
||||
"playtime": kw.get("playtime", 0),
|
||||
},
|
||||
}
|
||||
if mime.startswith("audio/"):
|
||||
return MEDIA_FILE, lambda **kw: {
|
||||
"type": ITEM_FILE,
|
||||
"file_item": {
|
||||
"media": {
|
||||
"encrypt_query_param": kw["encrypt_query_param"],
|
||||
"aes_key": kw["aes_key_for_api"],
|
||||
"encrypt_type": 1,
|
||||
},
|
||||
"file_name": kw["filename"],
|
||||
"len": str(kw["plaintext_size"]),
|
||||
},
|
||||
}
|
||||
return MEDIA_FILE, lambda **kw: {
|
||||
"type": ITEM_FILE,
|
||||
"file_item": {
|
||||
@@ -1980,34 +1784,7 @@ async def send_weixin_direct(
|
||||
token_store.restore(account_id)
|
||||
context_token = token_store.get(account_id, chat_id)
|
||||
|
||||
live_adapter = _LIVE_ADAPTERS.get(resolved_token)
|
||||
send_session = getattr(live_adapter, '_send_session', None)
|
||||
if live_adapter is not None and send_session is not None and not send_session.closed:
|
||||
last_result: Optional[SendResult] = None
|
||||
cleaned = live_adapter.format_message(message)
|
||||
if cleaned:
|
||||
last_result = await live_adapter.send(chat_id, cleaned)
|
||||
if not last_result.success:
|
||||
return {"error": f"Weixin send failed: {last_result.error}"}
|
||||
|
||||
for media_path, _is_voice in media_files or []:
|
||||
ext = Path(media_path).suffix.lower()
|
||||
if ext in {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}:
|
||||
last_result = await live_adapter.send_image_file(chat_id, media_path)
|
||||
else:
|
||||
last_result = await live_adapter.send_document(chat_id, media_path)
|
||||
if not last_result.success:
|
||||
return {"error": f"Weixin media send failed: {last_result.error}"}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"platform": "weixin",
|
||||
"chat_id": chat_id,
|
||||
"message_id": last_result.message_id if last_result else None,
|
||||
"context_token_used": bool(context_token),
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True, connector=_make_ssl_connector()) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
adapter = WeixinAdapter(
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
@@ -2020,7 +1797,6 @@ async def send_weixin_direct(
|
||||
},
|
||||
)
|
||||
)
|
||||
adapter._send_session = session
|
||||
adapter._session = session
|
||||
adapter._token = resolved_token
|
||||
adapter._account_id = account_id
|
||||
|
||||
+4
-228
@@ -24,20 +24,11 @@ import signal
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from contextvars import copy_context
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Any, List
|
||||
|
||||
# --- Agent cache tuning ---------------------------------------------------
|
||||
# Bounds the per-session AIAgent cache to prevent unbounded growth in
|
||||
# long-lived gateways (each AIAgent holds LLM clients, tool schemas,
|
||||
# memory providers, etc.). LRU order + idle TTL eviction are enforced
|
||||
# from _enforce_agent_cache_cap() and _session_expiry_watcher() below.
|
||||
_AGENT_CACHE_MAX_SIZE = 128
|
||||
_AGENT_CACHE_IDLE_TTL_SECS = 3600.0 # evict agents idle for >1h
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSL certificate auto-detection for NixOS and other non-standard systems.
|
||||
# Must run BEFORE any HTTP library (discord, aiohttp, etc.) is imported.
|
||||
@@ -631,13 +622,8 @@ class GatewayRunner:
|
||||
# system prompt (including memory) every turn — breaking prefix cache
|
||||
# and costing ~10x more on providers with prompt caching (Anthropic).
|
||||
# Key: session_key, Value: (AIAgent, config_signature_str)
|
||||
#
|
||||
# OrderedDict so _enforce_agent_cache_cap() can pop the least-recently-
|
||||
# used entry (move_to_end() on cache hits, popitem(last=False) for
|
||||
# eviction). Hard cap via _AGENT_CACHE_MAX_SIZE, idle TTL enforced
|
||||
# from _session_expiry_watcher().
|
||||
import threading as _threading
|
||||
self._agent_cache: "OrderedDict[str, tuple]" = OrderedDict()
|
||||
self._agent_cache: Dict[str, tuple] = {}
|
||||
self._agent_cache_lock = _threading.Lock()
|
||||
|
||||
# Per-session model overrides from /model command.
|
||||
@@ -2116,11 +2102,6 @@ class GatewayRunner:
|
||||
_cached_agent = self._running_agents.get(key)
|
||||
if _cached_agent and _cached_agent is not _AGENT_PENDING_SENTINEL:
|
||||
self._cleanup_agent_resources(_cached_agent)
|
||||
# Drop the cache entry so the AIAgent (and its LLM
|
||||
# clients, tool schemas, memory provider refs) can
|
||||
# be garbage-collected. Otherwise the cache grows
|
||||
# unbounded across the gateway's lifetime.
|
||||
self._evict_cached_agent(key)
|
||||
# Mark as flushed and persist to disk so the flag
|
||||
# survives gateway restarts.
|
||||
with self.session_store._lock:
|
||||
@@ -2164,20 +2145,6 @@ class GatewayRunner:
|
||||
logger.info(
|
||||
"Session expiry done: %d flushed", _flushed,
|
||||
)
|
||||
|
||||
# Sweep agents that have been idle beyond the TTL regardless
|
||||
# of session reset policy. This catches sessions with very
|
||||
# long / "never" reset windows, whose cached AIAgents would
|
||||
# otherwise pin memory for the gateway's entire lifetime.
|
||||
try:
|
||||
_idle_evicted = self._sweep_idle_cached_agents()
|
||||
if _idle_evicted:
|
||||
logger.info(
|
||||
"Agent cache idle sweep: evicted %d agent(s)",
|
||||
_idle_evicted,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("Idle agent sweep failed: %s", _e)
|
||||
except Exception as e:
|
||||
logger.debug("Session expiry watcher error: %s", e)
|
||||
# Sleep in small increments so we can stop quickly
|
||||
@@ -2651,9 +2618,6 @@ class GatewayRunner:
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS",
|
||||
Platform.QQBOT: "QQ_ALLOWED_USERS",
|
||||
}
|
||||
platform_group_env_map = {
|
||||
Platform.QQBOT: "QQ_GROUP_ALLOWED_USERS",
|
||||
}
|
||||
platform_allow_all_map = {
|
||||
Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS",
|
||||
Platform.DISCORD: "DISCORD_ALLOW_ALL_USERS",
|
||||
@@ -2678,28 +2642,6 @@ class GatewayRunner:
|
||||
if platform_allow_all_var and os.getenv(platform_allow_all_var, "").lower() in ("true", "1", "yes"):
|
||||
return True
|
||||
|
||||
# Discord bot senders that passed the DISCORD_ALLOW_BOTS platform
|
||||
# filter are already authorized at the platform level — skip the
|
||||
# user allowlist. Without this, bot messages allowed by
|
||||
# DISCORD_ALLOW_BOTS=mentions/all would be rejected here with
|
||||
# "Unauthorized user" (fixes #4466).
|
||||
if source.platform == Platform.DISCORD and getattr(source, "is_bot", False):
|
||||
allow_bots = os.getenv("DISCORD_ALLOW_BOTS", "none").lower().strip()
|
||||
if allow_bots in ("mentions", "all"):
|
||||
return True
|
||||
|
||||
# Discord role-based access (DISCORD_ALLOWED_ROLES): the adapter's
|
||||
# on_message pre-filter already verified role membership — if the
|
||||
# message reached here, the user passed that check. Authorize
|
||||
# directly to avoid the "no allowlists configured" branch below
|
||||
# rejecting role-only setups where DISCORD_ALLOWED_USERS is empty
|
||||
# (issue #7871).
|
||||
if (
|
||||
source.platform == Platform.DISCORD
|
||||
and os.getenv("DISCORD_ALLOWED_ROLES", "").strip()
|
||||
):
|
||||
return True
|
||||
|
||||
# Check pairing store (always checked, regardless of allowlists)
|
||||
platform_name = source.platform.value if source.platform else ""
|
||||
if self.pairing_store.is_approved(platform_name, user_id):
|
||||
@@ -2707,23 +2649,12 @@ class GatewayRunner:
|
||||
|
||||
# Check platform-specific and global allowlists
|
||||
platform_allowlist = os.getenv(platform_env_map.get(source.platform, ""), "").strip()
|
||||
group_allowlist = ""
|
||||
if source.chat_type == "group":
|
||||
group_allowlist = os.getenv(platform_group_env_map.get(source.platform, ""), "").strip()
|
||||
global_allowlist = os.getenv("GATEWAY_ALLOWED_USERS", "").strip()
|
||||
|
||||
if not platform_allowlist and not group_allowlist and not global_allowlist:
|
||||
if not platform_allowlist and not global_allowlist:
|
||||
# No allowlists configured -- check global allow-all flag
|
||||
return os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
|
||||
|
||||
# Some platforms authorize group traffic by chat ID rather than sender ID.
|
||||
if group_allowlist and source.chat_type == "group" and source.chat_id:
|
||||
allowed_group_ids = {
|
||||
chat_id.strip() for chat_id in group_allowlist.split(",") if chat_id.strip()
|
||||
}
|
||||
if "*" in allowed_group_ids or source.chat_id in allowed_group_ids:
|
||||
return True
|
||||
|
||||
# Check if user is in any allowlist
|
||||
allowed_ids = set()
|
||||
if platform_allowlist:
|
||||
@@ -5866,7 +5797,7 @@ class GatewayRunner:
|
||||
pass
|
||||
|
||||
# Send media files
|
||||
for media_path, _is_voice in (media_files or []):
|
||||
for media_path in (media_files or []):
|
||||
try:
|
||||
await adapter.send_document(
|
||||
chat_id=source.chat_id,
|
||||
@@ -6044,7 +5975,7 @@ class GatewayRunner:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for media_path, _is_voice in (media_files or []):
|
||||
for media_path in (media_files or []):
|
||||
try:
|
||||
await adapter.send_file(chat_id=source.chat_id, file_path=media_path)
|
||||
except Exception:
|
||||
@@ -7920,153 +7851,6 @@ class GatewayRunner:
|
||||
with _lock:
|
||||
self._agent_cache.pop(session_key, None)
|
||||
|
||||
def _release_evicted_agent_soft(self, agent: Any) -> None:
|
||||
"""Soft cleanup for cache-evicted agents — preserves session tool state.
|
||||
|
||||
Called from _enforce_agent_cache_cap and _sweep_idle_cached_agents.
|
||||
Distinct from _cleanup_agent_resources (full teardown) because a
|
||||
cache-evicted session may resume at any time — its terminal
|
||||
sandbox, browser daemon, and tracked bg processes must outlive
|
||||
the Python AIAgent instance so the next agent built for the
|
||||
same task_id inherits them.
|
||||
"""
|
||||
if agent is None:
|
||||
return
|
||||
try:
|
||||
if hasattr(agent, "release_clients"):
|
||||
agent.release_clients()
|
||||
else:
|
||||
# Older agent instance (shouldn't happen in practice) —
|
||||
# fall back to the legacy full-close path.
|
||||
self._cleanup_agent_resources(agent)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _enforce_agent_cache_cap(self) -> None:
|
||||
"""Evict oldest cached agents when cache exceeds _AGENT_CACHE_MAX_SIZE.
|
||||
|
||||
Must be called with _agent_cache_lock held. Resource cleanup
|
||||
(memory provider shutdown, tool resource close) is scheduled
|
||||
on a daemon thread so the caller doesn't block on slow teardown
|
||||
while holding the cache lock.
|
||||
|
||||
Agents currently in _running_agents are SKIPPED — their clients,
|
||||
terminal sandboxes, background processes, and child subagents
|
||||
are all in active use by the running turn. Evicting them would
|
||||
tear down those resources mid-turn and crash the request. If
|
||||
every candidate in the LRU order is active, we simply leave the
|
||||
cache over the cap; it will be re-checked on the next insert.
|
||||
"""
|
||||
_cache = getattr(self, "_agent_cache", None)
|
||||
if _cache is None:
|
||||
return
|
||||
# OrderedDict.popitem(last=False) pops oldest; plain dict lacks the
|
||||
# arg so skip enforcement if a test fixture swapped the cache type.
|
||||
if not hasattr(_cache, "move_to_end"):
|
||||
return
|
||||
|
||||
# Snapshot of agent instances that are actively mid-turn. Use id()
|
||||
# so the lookup is O(1) and doesn't depend on AIAgent.__eq__ (which
|
||||
# MagicMock overrides in tests).
|
||||
running_ids = {
|
||||
id(a)
|
||||
for a in getattr(self, "_running_agents", {}).values()
|
||||
if a is not None and a is not _AGENT_PENDING_SENTINEL
|
||||
}
|
||||
|
||||
# Walk LRU → MRU and evict excess-LRU entries that aren't mid-turn.
|
||||
# We only consider entries in the first (size - cap) LRU positions
|
||||
# as eviction candidates. If one of those slots is held by an
|
||||
# active agent, we SKIP it without compensating by evicting a
|
||||
# newer entry — that would penalise a freshly-inserted session
|
||||
# (which has no cache history to retain) while protecting an
|
||||
# already-cached long-running one. The cache may therefore stay
|
||||
# temporarily over cap; it will re-check on the next insert,
|
||||
# after active turns have finished.
|
||||
excess = max(0, len(_cache) - _AGENT_CACHE_MAX_SIZE)
|
||||
evict_plan: List[tuple] = [] # [(key, agent), ...]
|
||||
if excess > 0:
|
||||
ordered_keys = list(_cache.keys())
|
||||
for key in ordered_keys[:excess]:
|
||||
entry = _cache.get(key)
|
||||
agent = entry[0] if isinstance(entry, tuple) and entry else None
|
||||
if agent is not None and id(agent) in running_ids:
|
||||
continue # active mid-turn; don't evict, don't substitute
|
||||
evict_plan.append((key, agent))
|
||||
|
||||
for key, _ in evict_plan:
|
||||
_cache.pop(key, None)
|
||||
|
||||
remaining_over_cap = len(_cache) - _AGENT_CACHE_MAX_SIZE
|
||||
if remaining_over_cap > 0:
|
||||
logger.warning(
|
||||
"Agent cache over cap (%d > %d); %d excess slot(s) held by "
|
||||
"mid-turn agents — will re-check on next insert.",
|
||||
len(_cache), _AGENT_CACHE_MAX_SIZE, remaining_over_cap,
|
||||
)
|
||||
|
||||
for key, agent in evict_plan:
|
||||
logger.info(
|
||||
"Agent cache at cap; evicting LRU session=%s (cache_size=%d)",
|
||||
key, len(_cache),
|
||||
)
|
||||
if agent is not None:
|
||||
threading.Thread(
|
||||
target=self._release_evicted_agent_soft,
|
||||
args=(agent,),
|
||||
daemon=True,
|
||||
name=f"agent-cache-evict-{key[:24]}",
|
||||
).start()
|
||||
|
||||
def _sweep_idle_cached_agents(self) -> int:
|
||||
"""Evict cached agents whose AIAgent has been idle > _AGENT_CACHE_IDLE_TTL_SECS.
|
||||
|
||||
Safe to call from the session expiry watcher without holding the
|
||||
cache lock — acquires it internally. Returns the number of entries
|
||||
evicted. Resource cleanup is scheduled on daemon threads.
|
||||
|
||||
Agents currently in _running_agents are SKIPPED for the same reason
|
||||
as _enforce_agent_cache_cap: tearing down an active turn's clients
|
||||
mid-flight would crash the request.
|
||||
"""
|
||||
_cache = getattr(self, "_agent_cache", None)
|
||||
_lock = getattr(self, "_agent_cache_lock", None)
|
||||
if _cache is None or _lock is None:
|
||||
return 0
|
||||
now = time.time()
|
||||
to_evict: List[tuple] = []
|
||||
running_ids = {
|
||||
id(a)
|
||||
for a in getattr(self, "_running_agents", {}).values()
|
||||
if a is not None and a is not _AGENT_PENDING_SENTINEL
|
||||
}
|
||||
with _lock:
|
||||
for key, entry in list(_cache.items()):
|
||||
agent = entry[0] if isinstance(entry, tuple) and entry else None
|
||||
if agent is None:
|
||||
continue
|
||||
if id(agent) in running_ids:
|
||||
continue # mid-turn — don't tear it down
|
||||
last_activity = getattr(agent, "_last_activity_ts", None)
|
||||
if last_activity is None:
|
||||
continue
|
||||
if (now - last_activity) > _AGENT_CACHE_IDLE_TTL_SECS:
|
||||
to_evict.append((key, agent))
|
||||
for key, _ in to_evict:
|
||||
_cache.pop(key, None)
|
||||
for key, agent in to_evict:
|
||||
logger.info(
|
||||
"Agent cache idle-TTL evict: session=%s (idle=%.0fs)",
|
||||
key, now - getattr(agent, "_last_activity_ts", now),
|
||||
)
|
||||
threading.Thread(
|
||||
target=self._release_evicted_agent_soft,
|
||||
args=(agent,),
|
||||
daemon=True,
|
||||
name=f"agent-cache-idle-{key[:24]}",
|
||||
).start()
|
||||
return len(to_evict)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Proxy mode: forward messages to a remote Hermes API server
|
||||
# ------------------------------------------------------------------
|
||||
@@ -8834,13 +8618,6 @@ class GatewayRunner:
|
||||
cached = _cache.get(session_key)
|
||||
if cached and cached[1] == _sig:
|
||||
agent = cached[0]
|
||||
# Refresh LRU order so the cap enforcement evicts
|
||||
# truly-oldest entries, not the one we just used.
|
||||
if hasattr(_cache, "move_to_end"):
|
||||
try:
|
||||
_cache.move_to_end(session_key)
|
||||
except KeyError:
|
||||
pass
|
||||
# Reset activity timestamp so the inactivity timeout
|
||||
# handler doesn't see stale idle time from the previous
|
||||
# turn and immediately kill this agent. (#9051)
|
||||
@@ -8879,7 +8656,6 @@ class GatewayRunner:
|
||||
if _cache_lock and _cache is not None:
|
||||
with _cache_lock:
|
||||
_cache[session_key] = (agent, _sig)
|
||||
self._enforce_agent_cache_cap()
|
||||
logger.debug("Created new agent for session %s (sig=%s)", session_key, _sig)
|
||||
|
||||
# Per-message state — callbacks and reasoning config change every
|
||||
|
||||
@@ -82,7 +82,6 @@ class SessionSource:
|
||||
chat_topic: Optional[str] = None # Channel topic/description (Discord, Slack)
|
||||
user_id_alt: Optional[str] = None # Signal UUID (alternative to phone number)
|
||||
chat_id_alt: Optional[str] = None # Signal group internal ID
|
||||
is_bot: bool = False # True when the message author is a bot/webhook (Discord)
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
|
||||
@@ -773,28 +773,6 @@ def is_source_suppressed(provider_id: str, source: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def unsuppress_credential_source(provider_id: str, source: str) -> bool:
|
||||
"""Clear a suppression marker so the source will be re-seeded on the next load.
|
||||
|
||||
Returns True if a marker was cleared, False if no marker existed.
|
||||
"""
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
suppressed = auth_store.get("suppressed_sources")
|
||||
if not isinstance(suppressed, dict):
|
||||
return False
|
||||
provider_list = suppressed.get(provider_id)
|
||||
if not isinstance(provider_list, list) or source not in provider_list:
|
||||
return False
|
||||
provider_list.remove(source)
|
||||
if not provider_list:
|
||||
suppressed.pop(provider_id, None)
|
||||
if not suppressed:
|
||||
auth_store.pop("suppressed_sources", None)
|
||||
_save_auth_store(auth_store)
|
||||
return True
|
||||
|
||||
|
||||
def get_provider_auth_state(provider_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return persisted auth state for a provider, or None."""
|
||||
auth_store = _load_auth_store()
|
||||
|
||||
@@ -233,9 +233,6 @@ def auth_add_command(args) -> None:
|
||||
return
|
||||
|
||||
if provider == "openai-codex":
|
||||
# Clear any existing suppression marker so a re-link after `hermes auth
|
||||
# remove openai-codex` works without the new tokens being skipped.
|
||||
auth_mod.unsuppress_credential_source(provider, "device_code")
|
||||
creds = auth_mod._codex_device_code_login()
|
||||
label = (getattr(args, "label", None) or "").strip() or label_from_token(
|
||||
creds["tokens"]["access_token"],
|
||||
@@ -355,34 +352,7 @@ def auth_remove_command(args) -> None:
|
||||
# If this was a singleton-seeded credential (OAuth device_code, hermes_pkce),
|
||||
# clear the underlying auth store / credential file so it doesn't get
|
||||
# re-seeded on the next load_pool() call.
|
||||
elif provider == "openai-codex" and (
|
||||
removed.source == "device_code" or removed.source.endswith(":device_code")
|
||||
):
|
||||
# Codex tokens live in TWO places: the Hermes auth store and
|
||||
# ~/.codex/auth.json (the Codex CLI shared file). On every refresh,
|
||||
# refresh_codex_oauth_pure() writes to both. So clearing only the
|
||||
# Hermes auth store is not enough — _seed_from_singletons() will
|
||||
# auto-import from ~/.codex/auth.json on the next load_pool() and
|
||||
# the removal is instantly undone. Mark the source as suppressed
|
||||
# so auto-import is skipped; leave ~/.codex/auth.json untouched so
|
||||
# the Codex CLI itself keeps working.
|
||||
from hermes_cli.auth import (
|
||||
_load_auth_store, _save_auth_store, _auth_store_lock,
|
||||
suppress_credential_source,
|
||||
)
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
providers_dict = auth_store.get("providers")
|
||||
if isinstance(providers_dict, dict) and provider in providers_dict:
|
||||
del providers_dict[provider]
|
||||
_save_auth_store(auth_store)
|
||||
print(f"Cleared {provider} OAuth tokens from auth store")
|
||||
suppress_credential_source(provider, "device_code")
|
||||
print("Suppressed openai-codex device_code source — it will not be re-seeded.")
|
||||
print("Note: Codex CLI credentials still live in ~/.codex/auth.json")
|
||||
print("Run `hermes auth add openai-codex` to re-enable if needed.")
|
||||
|
||||
elif removed.source == "device_code" and provider == "nous":
|
||||
elif removed.source == "device_code" and provider in ("openai-codex", "nous"):
|
||||
from hermes_cli.auth import (
|
||||
_load_auth_store, _save_auth_store, _auth_store_lock,
|
||||
)
|
||||
|
||||
@@ -1,294 +0,0 @@
|
||||
"""
|
||||
DingTalk Device Flow authorization.
|
||||
|
||||
Implements the same 3-step registration flow as dingtalk-openclaw-connector:
|
||||
1. POST /app/registration/init → get nonce
|
||||
2. POST /app/registration/begin → get device_code + verification_uri_complete
|
||||
3. POST /app/registration/poll → poll until SUCCESS → get client_id + client_secret
|
||||
|
||||
The verification_uri_complete is rendered as a QR code in the terminal so the
|
||||
user can scan it with DingTalk to authorize, yielding AppKey + AppSecret
|
||||
automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Configuration ──────────────────────────────────────────────────────────
|
||||
|
||||
REGISTRATION_BASE_URL = os.environ.get(
|
||||
"DINGTALK_REGISTRATION_BASE_URL", "https://oapi.dingtalk.com"
|
||||
).rstrip("/")
|
||||
|
||||
REGISTRATION_SOURCE = os.environ.get("DINGTALK_REGISTRATION_SOURCE", "openClaw")
|
||||
|
||||
|
||||
# ── API helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
class RegistrationError(Exception):
|
||||
"""Raised when a DingTalk registration API call fails."""
|
||||
|
||||
|
||||
def _api_post(path: str, payload: dict) -> dict:
|
||||
"""POST to the registration API and return the parsed JSON body."""
|
||||
url = f"{REGISTRATION_BASE_URL}{path}"
|
||||
try:
|
||||
resp = requests.post(url, json=payload, timeout=15)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except requests.RequestException as exc:
|
||||
raise RegistrationError(f"Network error calling {url}: {exc}") from exc
|
||||
|
||||
errcode = data.get("errcode", -1)
|
||||
if errcode != 0:
|
||||
errmsg = data.get("errmsg", "unknown error")
|
||||
raise RegistrationError(f"API error [{path}]: {errmsg} (errcode={errcode})")
|
||||
return data
|
||||
|
||||
|
||||
# ── Core flow ──────────────────────────────────────────────────────────────
|
||||
|
||||
def begin_registration() -> dict:
|
||||
"""Start a device-flow registration.
|
||||
|
||||
Returns a dict with keys:
|
||||
device_code, verification_uri_complete, expires_in, interval
|
||||
"""
|
||||
# Step 1: init → nonce
|
||||
init_data = _api_post("/app/registration/init", {"source": REGISTRATION_SOURCE})
|
||||
nonce = str(init_data.get("nonce", "")).strip()
|
||||
if not nonce:
|
||||
raise RegistrationError("init response missing nonce")
|
||||
|
||||
# Step 2: begin → device_code, verification_uri_complete
|
||||
begin_data = _api_post("/app/registration/begin", {"nonce": nonce})
|
||||
device_code = str(begin_data.get("device_code", "")).strip()
|
||||
verification_uri_complete = str(begin_data.get("verification_uri_complete", "")).strip()
|
||||
if not device_code:
|
||||
raise RegistrationError("begin response missing device_code")
|
||||
if not verification_uri_complete:
|
||||
raise RegistrationError("begin response missing verification_uri_complete")
|
||||
|
||||
return {
|
||||
"device_code": device_code,
|
||||
"verification_uri_complete": verification_uri_complete,
|
||||
"expires_in": int(begin_data.get("expires_in", 7200)),
|
||||
"interval": max(int(begin_data.get("interval", 3)), 2),
|
||||
}
|
||||
|
||||
|
||||
def poll_registration(device_code: str) -> dict:
|
||||
"""Poll the registration status once.
|
||||
|
||||
Returns a dict with keys: status, client_id?, client_secret?, fail_reason?
|
||||
"""
|
||||
data = _api_post("/app/registration/poll", {"device_code": device_code})
|
||||
status_raw = str(data.get("status", "")).strip().upper()
|
||||
if status_raw not in ("WAITING", "SUCCESS", "FAIL", "EXPIRED"):
|
||||
status_raw = "UNKNOWN"
|
||||
return {
|
||||
"status": status_raw,
|
||||
"client_id": str(data.get("client_id", "")).strip() or None,
|
||||
"client_secret": str(data.get("client_secret", "")).strip() or None,
|
||||
"fail_reason": str(data.get("fail_reason", "")).strip() or None,
|
||||
}
|
||||
|
||||
|
||||
def wait_for_registration_success(
|
||||
device_code: str,
|
||||
interval: int = 3,
|
||||
expires_in: int = 7200,
|
||||
on_waiting: Optional[callable] = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""Block until the registration succeeds or times out.
|
||||
|
||||
Returns (client_id, client_secret).
|
||||
"""
|
||||
deadline = time.monotonic() + expires_in
|
||||
retry_window = 120 # 2 minutes for transient errors
|
||||
retry_start = 0.0
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
time.sleep(interval)
|
||||
try:
|
||||
result = poll_registration(device_code)
|
||||
except RegistrationError:
|
||||
if retry_start == 0:
|
||||
retry_start = time.monotonic()
|
||||
if time.monotonic() - retry_start < retry_window:
|
||||
continue
|
||||
raise
|
||||
|
||||
status = result["status"]
|
||||
if status == "WAITING":
|
||||
retry_start = 0
|
||||
if on_waiting:
|
||||
on_waiting()
|
||||
continue
|
||||
if status == "SUCCESS":
|
||||
cid = result["client_id"]
|
||||
csecret = result["client_secret"]
|
||||
if not cid or not csecret:
|
||||
raise RegistrationError("authorization succeeded but credentials are missing")
|
||||
return cid, csecret
|
||||
# FAIL / EXPIRED / UNKNOWN
|
||||
if retry_start == 0:
|
||||
retry_start = time.monotonic()
|
||||
if time.monotonic() - retry_start < retry_window:
|
||||
continue
|
||||
reason = result.get("fail_reason") or status
|
||||
raise RegistrationError(f"authorization failed: {reason}")
|
||||
|
||||
raise RegistrationError("authorization timed out, please retry")
|
||||
|
||||
|
||||
# ── QR code rendering ─────────────────────────────────────────────────────
|
||||
|
||||
def _ensure_qrcode_installed() -> bool:
|
||||
"""Try to import qrcode; if missing, auto-install it via pip/uv."""
|
||||
try:
|
||||
import qrcode # noqa: F401
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import subprocess
|
||||
|
||||
# Try uv first (Hermes convention), then pip
|
||||
for cmd in (
|
||||
[sys.executable, "-m", "uv", "pip", "install", "qrcode"],
|
||||
[sys.executable, "-m", "pip", "install", "-q", "qrcode"],
|
||||
):
|
||||
try:
|
||||
subprocess.check_call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
import qrcode # noqa: F401,F811
|
||||
return True
|
||||
except (subprocess.CalledProcessError, ImportError, FileNotFoundError):
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
def render_qr_to_terminal(url: str) -> bool:
|
||||
"""Render *url* as a compact QR code in the terminal.
|
||||
|
||||
Returns True if the QR code was printed, False if the library is missing.
|
||||
"""
|
||||
try:
|
||||
import qrcode
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
qr = qrcode.QRCode(
|
||||
version=1,
|
||||
error_correction=qrcode.constants.ERROR_CORRECT_L,
|
||||
box_size=1,
|
||||
border=1,
|
||||
)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
|
||||
# Use half-block characters for compact rendering (2 rows per character)
|
||||
matrix = qr.get_matrix()
|
||||
rows = len(matrix)
|
||||
lines: list[str] = []
|
||||
|
||||
TOP_HALF = "\u2580" # ▀
|
||||
BOTTOM_HALF = "\u2584" # ▄
|
||||
FULL_BLOCK = "\u2588" # █
|
||||
EMPTY = " "
|
||||
|
||||
for r in range(0, rows, 2):
|
||||
line_chars: list[str] = []
|
||||
for c in range(len(matrix[r])):
|
||||
top = matrix[r][c]
|
||||
bottom = matrix[r + 1][c] if r + 1 < rows else False
|
||||
if top and bottom:
|
||||
line_chars.append(FULL_BLOCK)
|
||||
elif top:
|
||||
line_chars.append(TOP_HALF)
|
||||
elif bottom:
|
||||
line_chars.append(BOTTOM_HALF)
|
||||
else:
|
||||
line_chars.append(EMPTY)
|
||||
lines.append(" " + "".join(line_chars))
|
||||
|
||||
print("\n".join(lines))
|
||||
return True
|
||||
|
||||
|
||||
# ── High-level entry point for the setup wizard ───────────────────────────
|
||||
|
||||
def dingtalk_qr_auth() -> Optional[Tuple[str, str]]:
|
||||
"""Run the interactive QR-code device-flow authorization.
|
||||
|
||||
Returns (client_id, client_secret) on success, or None if the user
|
||||
cancelled or the flow failed.
|
||||
"""
|
||||
from hermes_cli.setup import print_info, print_success, print_warning, print_error
|
||||
|
||||
print()
|
||||
print_info(" Initializing DingTalk device authorization...")
|
||||
print_info(" Note: the scan page is branded 'OpenClaw' — DingTalk's")
|
||||
print_info(" ecosystem onboarding bridge. Safe to use.")
|
||||
|
||||
try:
|
||||
reg = begin_registration()
|
||||
except RegistrationError as exc:
|
||||
print_error(f" Authorization init failed: {exc}")
|
||||
return None
|
||||
|
||||
url = reg["verification_uri_complete"]
|
||||
|
||||
# Ensure qrcode library is available (auto-install if missing)
|
||||
if not _ensure_qrcode_installed():
|
||||
print_warning(" qrcode library install failed, will show link only.")
|
||||
|
||||
print()
|
||||
print_info(" Please scan the QR code below with DingTalk to authorize:")
|
||||
print()
|
||||
|
||||
if not render_qr_to_terminal(url):
|
||||
print_warning(f" QR code render failed, please open the link below to authorize:")
|
||||
|
||||
print()
|
||||
print_info(f" Or open this link manually: {url}")
|
||||
print()
|
||||
print_info(" Waiting for QR scan authorization... (timeout: 2 hours)")
|
||||
|
||||
dot_count = 0
|
||||
|
||||
def _on_waiting():
|
||||
nonlocal dot_count
|
||||
dot_count += 1
|
||||
if dot_count % 10 == 0:
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
|
||||
try:
|
||||
client_id, client_secret = wait_for_registration_success(
|
||||
device_code=reg["device_code"],
|
||||
interval=reg["interval"],
|
||||
expires_in=reg["expires_in"],
|
||||
on_waiting=_on_waiting,
|
||||
)
|
||||
except RegistrationError as exc:
|
||||
print()
|
||||
print_error(f" Authorization failed: {exc}")
|
||||
return None
|
||||
|
||||
print()
|
||||
print_success(" QR scan authorization successful!")
|
||||
print_success(f" Client ID: {client_id}")
|
||||
print_success(f" Client Secret: {client_secret[:8]}{'*' * (len(client_secret) - 8)}")
|
||||
|
||||
return client_id, client_secret
|
||||
+2
-57
@@ -2211,62 +2211,9 @@ def _setup_sms():
|
||||
|
||||
|
||||
def _setup_dingtalk():
|
||||
"""Configure DingTalk — QR scan (recommended) or manual credential entry."""
|
||||
from hermes_cli.setup import (
|
||||
prompt_choice, prompt_yes_no, print_info, print_success, print_warning,
|
||||
)
|
||||
|
||||
"""Configure DingTalk via the standard platform setup."""
|
||||
dingtalk_platform = next(p for p in _PLATFORMS if p["key"] == "dingtalk")
|
||||
emoji = dingtalk_platform["emoji"]
|
||||
label = dingtalk_platform["label"]
|
||||
|
||||
print()
|
||||
print(color(f" ─── {emoji} {label} Setup ───", Colors.CYAN))
|
||||
|
||||
existing = get_env_value("DINGTALK_CLIENT_ID")
|
||||
if existing:
|
||||
print()
|
||||
print_success(f"{label} is already configured (Client ID: {existing}).")
|
||||
if not prompt_yes_no(f" Reconfigure {label}?", False):
|
||||
return
|
||||
|
||||
print()
|
||||
method = prompt_choice(
|
||||
" Choose setup method",
|
||||
[
|
||||
"QR Code Scan (Recommended, auto-obtain Client ID and Client Secret)",
|
||||
"Manual Input (Client ID and Client Secret)",
|
||||
],
|
||||
default=0,
|
||||
)
|
||||
|
||||
if method == 0:
|
||||
# ── QR-code device-flow authorization ──
|
||||
try:
|
||||
from hermes_cli.dingtalk_auth import dingtalk_qr_auth
|
||||
except ImportError as exc:
|
||||
print_warning(f" QR auth module failed to load ({exc}), falling back to manual input.")
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
return
|
||||
|
||||
result = dingtalk_qr_auth()
|
||||
if result is None:
|
||||
print_warning(" QR auth incomplete, falling back to manual input.")
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
return
|
||||
|
||||
client_id, client_secret = result
|
||||
save_env_value("DINGTALK_CLIENT_ID", client_id)
|
||||
save_env_value("DINGTALK_CLIENT_SECRET", client_secret)
|
||||
save_env_value("DINGTALK_ALLOW_ALL_USERS", "true")
|
||||
print()
|
||||
print_success(f"{emoji} {label} configured via QR scan!")
|
||||
else:
|
||||
# ── Manual entry ──
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
# Also enable allow-all by default for convenience
|
||||
if get_env_value("DINGTALK_CLIENT_ID"):
|
||||
save_env_value("DINGTALK_ALLOW_ALL_USERS", "true")
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
|
||||
|
||||
def _setup_wecom():
|
||||
@@ -2802,8 +2749,6 @@ def gateway_setup():
|
||||
_setup_signal()
|
||||
elif platform["key"] == "weixin":
|
||||
_setup_weixin()
|
||||
elif platform["key"] == "dingtalk":
|
||||
_setup_dingtalk()
|
||||
elif platform["key"] == "feishu":
|
||||
_setup_feishu()
|
||||
else:
|
||||
|
||||
+2
-2
@@ -2472,10 +2472,10 @@ def _model_flow_kimi(config, current_model=""):
|
||||
|
||||
# Step 3: Model selection — show appropriate models for the endpoint
|
||||
if is_coding_plan:
|
||||
# Coding Plan models (kimi-k2.5 first)
|
||||
# Coding Plan models (kimi-for-coding first)
|
||||
model_list = [
|
||||
"kimi-k2.5",
|
||||
"kimi-for-coding",
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-thinking-turbo",
|
||||
]
|
||||
|
||||
@@ -374,26 +374,7 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str:
|
||||
return bare
|
||||
return _dots_to_hyphens(bare)
|
||||
|
||||
# --- Copilot / Copilot ACP: delegate to the Copilot-specific
|
||||
# normalizer. It knows about the alias table (vendor-prefix
|
||||
# stripping for Anthropic/OpenAI, dash-to-dot repair for Claude)
|
||||
# and live-catalog lookups. Without this, vendor-prefixed or
|
||||
# dash-notation Claude IDs survive to the Copilot API and hit
|
||||
# HTTP 400 "model_not_supported". See issue #6879.
|
||||
if provider in {"copilot", "copilot-acp"}:
|
||||
try:
|
||||
from hermes_cli.models import normalize_copilot_model_id
|
||||
|
||||
normalized = normalize_copilot_model_id(name)
|
||||
if normalized:
|
||||
return normalized
|
||||
except Exception:
|
||||
# Fall through to the generic strip-vendor behaviour below
|
||||
# if the Copilot-specific path is unavailable for any reason.
|
||||
pass
|
||||
|
||||
# --- Copilot / Copilot ACP / openai-codex fallback:
|
||||
# strip matching provider prefix, keep dots ---
|
||||
# --- Copilot: strip matching provider prefix, keep dots ---
|
||||
if provider in _STRIP_VENDOR_ONLY_PROVIDERS:
|
||||
stripped = _strip_matching_provider_prefix(name, provider)
|
||||
if stripped == name and name.startswith("openai/"):
|
||||
|
||||
+8
-21
@@ -26,8 +26,7 @@ COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
|
||||
# Fallback OpenRouter snapshot used when the live catalog is unavailable.
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("moonshotai/kimi-k2.5", "recommended"),
|
||||
("anthropic/claude-opus-4.7", ""),
|
||||
("anthropic/claude-opus-4.7", "recommended"),
|
||||
("anthropic/claude-opus-4.6", ""),
|
||||
("anthropic/claude-sonnet-4.6", ""),
|
||||
("qwen/qwen3.6-plus", ""),
|
||||
@@ -50,6 +49,7 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("z-ai/glm-5.1", ""),
|
||||
("z-ai/glm-5v-turbo", ""),
|
||||
("z-ai/glm-5-turbo", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("x-ai/grok-4.20", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b:free", "free"),
|
||||
@@ -75,7 +75,6 @@ def _codex_curated_models() -> list[str]:
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"nous": [
|
||||
"moonshotai/kimi-k2.5",
|
||||
"xiaomi/mimo-v2-pro",
|
||||
"anthropic/claude-opus-4.7",
|
||||
"anthropic/claude-opus-4.6",
|
||||
@@ -97,6 +96,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"z-ai/glm-5.1",
|
||||
"z-ai/glm-5v-turbo",
|
||||
"z-ai/glm-5-turbo",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"x-ai/grok-4.20-beta",
|
||||
"nvidia/nemotron-3-super-120b-a12b",
|
||||
"nvidia/nemotron-3-super-120b-a12b:free",
|
||||
@@ -156,8 +156,8 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"grok-4-1-fast-reasoning",
|
||||
],
|
||||
"kimi-coding": [
|
||||
"kimi-k2.5",
|
||||
"kimi-for-coding",
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-thinking-turbo",
|
||||
"kimi-k2-turbo-preview",
|
||||
@@ -212,7 +212,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"trinity-mini",
|
||||
],
|
||||
"opencode-zen": [
|
||||
"kimi-k2.5",
|
||||
"gpt-5.4-pro",
|
||||
"gpt-5.4",
|
||||
"gpt-5.3-codex",
|
||||
@@ -244,15 +243,16 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"glm-4.6",
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2",
|
||||
"qwen3-coder",
|
||||
"big-pickle",
|
||||
],
|
||||
"opencode-go": [
|
||||
"kimi-k2.5",
|
||||
"glm-5.1",
|
||||
"glm-5",
|
||||
"kimi-k2.5",
|
||||
"mimo-v2-pro",
|
||||
"mimo-v2-omni",
|
||||
"minimax-m2.7",
|
||||
@@ -285,21 +285,21 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
# to https://dashscope-intl.aliyuncs.com/compatible-mode/v1 (OpenAI-compat)
|
||||
# or https://dashscope-intl.aliyuncs.com/apps/anthropic (Anthropic-compat).
|
||||
"alibaba": [
|
||||
"kimi-k2.5",
|
||||
"qwen3.5-plus",
|
||||
"qwen3-coder-plus",
|
||||
"qwen3-coder-next",
|
||||
# Third-party models available on coding-intl
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"kimi-k2.5",
|
||||
"MiniMax-M2.5",
|
||||
],
|
||||
# Curated HF model list — only agentic models that map to OpenRouter defaults.
|
||||
"huggingface": [
|
||||
"moonshotai/Kimi-K2.5",
|
||||
"Qwen/Qwen3.5-397B-A17B",
|
||||
"Qwen/Qwen3.5-35B-A3B",
|
||||
"deepseek-ai/DeepSeek-V3.2",
|
||||
"moonshotai/Kimi-K2.5",
|
||||
"MiniMaxAI/MiniMax-M2.5",
|
||||
"zai-org/GLM-5",
|
||||
"XiaomiMiMo/MiMo-V2-Flash",
|
||||
@@ -1488,19 +1488,6 @@ _COPILOT_MODEL_ALIASES = {
|
||||
"anthropic/claude-sonnet-4.6": "claude-sonnet-4.6",
|
||||
"anthropic/claude-sonnet-4.5": "claude-sonnet-4.5",
|
||||
"anthropic/claude-haiku-4.5": "claude-haiku-4.5",
|
||||
# Dash-notation fallbacks: Hermes' default Claude IDs elsewhere use
|
||||
# hyphens (anthropic native format), but Copilot's API only accepts
|
||||
# dot-notation. Accept both so users who configure copilot + a
|
||||
# default hyphenated Claude model don't hit HTTP 400
|
||||
# "model_not_supported". See issue #6879.
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"claude-haiku-4-5": "claude-haiku-4.5",
|
||||
"anthropic/claude-opus-4-6": "claude-opus-4.6",
|
||||
"anthropic/claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"anthropic/claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"anthropic/claude-haiku-4-5": "claude-haiku-4.5",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -512,7 +512,7 @@ def _get_platform_tools(
|
||||
"""Resolve which individual toolset names are enabled for a platform."""
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
platform_toolsets = config.get("platform_toolsets") or {}
|
||||
platform_toolsets = config.get("platform_toolsets", {})
|
||||
toolset_names = platform_toolsets.get(platform)
|
||||
|
||||
if toolset_names is None or not isinstance(toolset_names, list):
|
||||
|
||||
+2
-2
@@ -39,8 +39,8 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
modal = ["modal>=1.0.0,<2"]
|
||||
daytona = ["daytona>=0.148.0,<1"]
|
||||
dev = ["debugpy>=1.8.0,<2", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "pytest-xdist>=3.0,<4", "mcp>=1.2.0,<2"]
|
||||
messaging = ["python-telegram-bot[webhooks]>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4", "qrcode>=7.0,<8"]
|
||||
dev = ["debugpy>=1.8.0,<2", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "pytest-xdist>=3.0,<4", "pytest-split>=0.9,<1", "mcp>=1.2.0,<2"]
|
||||
messaging = ["python-telegram-bot[webhooks]>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"]
|
||||
cron = ["croniter>=6.0.0,<7"]
|
||||
slack = ["slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"]
|
||||
matrix = ["mautrix[encryption]>=0.20,<1", "Markdown>=3.6,<4", "aiosqlite>=0.20", "asyncpg>=0.29"]
|
||||
|
||||
@@ -3242,53 +3242,6 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def release_clients(self) -> None:
|
||||
"""Release LLM client resources WITHOUT tearing down session tool state.
|
||||
|
||||
Used by the gateway when evicting this agent from _agent_cache for
|
||||
memory-management reasons (LRU cap or idle TTL) — the session may
|
||||
resume at any time with a freshly-built AIAgent that reuses the
|
||||
same task_id / session_id, so we must NOT kill:
|
||||
- process_registry entries for task_id (user's bg shells)
|
||||
- terminal sandbox for task_id (cwd, env, shell state)
|
||||
- browser daemon for task_id (open tabs, cookies)
|
||||
- memory provider (has its own lifecycle; keeps running)
|
||||
|
||||
We DO close:
|
||||
- OpenAI/httpx client pool (big chunk of held memory + sockets;
|
||||
the rebuilt agent gets a fresh client anyway)
|
||||
- Active child subagents (per-turn artefacts; safe to drop)
|
||||
|
||||
Safe to call multiple times. Distinct from close() — which is the
|
||||
hard teardown for actual session boundaries (/new, /reset, session
|
||||
expiry).
|
||||
"""
|
||||
# Close active child agents (per-turn; no cross-turn persistence).
|
||||
try:
|
||||
with self._active_children_lock:
|
||||
children = list(self._active_children)
|
||||
self._active_children.clear()
|
||||
for child in children:
|
||||
try:
|
||||
child.release_clients()
|
||||
except Exception:
|
||||
# Fall back to full close on children; they're per-turn.
|
||||
try:
|
||||
child.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Close the OpenAI/httpx client to release sockets immediately.
|
||||
try:
|
||||
client = getattr(self, "client", None)
|
||||
if client is not None:
|
||||
self._close_openai_client(client, reason="cache_evict", shared=True)
|
||||
self.client = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Release all resources held by this agent instance.
|
||||
|
||||
|
||||
@@ -50,7 +50,6 @@ AUTHOR_MAP = {
|
||||
"16443023+stablegenius49@users.noreply.github.com": "stablegenius49",
|
||||
"185121704+stablegenius49@users.noreply.github.com": "stablegenius49",
|
||||
"101283333+batuhankocyigit@users.noreply.github.com": "batuhankocyigit",
|
||||
"valdi.jorge@gmail.com": "jvcl",
|
||||
"126368201+vilkasdev@users.noreply.github.com": "vilkasdev",
|
||||
"137614867+cutepawss@users.noreply.github.com": "cutepawss",
|
||||
"96793918+memosr@users.noreply.github.com": "memosr",
|
||||
@@ -71,8 +70,6 @@ AUTHOR_MAP = {
|
||||
"27917469+nosleepcassette@users.noreply.github.com": "nosleepcassette",
|
||||
"241404605+MestreY0d4-Uninter@users.noreply.github.com": "MestreY0d4-Uninter",
|
||||
"109555139+davetist@users.noreply.github.com": "davetist",
|
||||
"39405770+yyq4193@users.noreply.github.com": "yyq4193",
|
||||
"Asunfly@users.noreply.github.com": "Asunfly",
|
||||
# contributors (manual mapping from git names)
|
||||
"ahmedsherif95@gmail.com": "asheriif",
|
||||
"dmayhem93@gmail.com": "dmahan93",
|
||||
@@ -86,8 +83,6 @@ AUTHOR_MAP = {
|
||||
"abdullahfarukozden@gmail.com": "Farukest",
|
||||
"lovre.pesut@gmail.com": "rovle",
|
||||
"kevinskysunny@gmail.com": "kevinskysunny",
|
||||
"xiewenxuan462@gmail.com": "yule975",
|
||||
"yiweimeng.dlut@hotmail.com": "meng93",
|
||||
"hakanerten02@hotmail.com": "teyrebaz33",
|
||||
"ruzzgarcn@gmail.com": "Ruzzgar",
|
||||
"alireza78.crypto@gmail.com": "alireza78a",
|
||||
@@ -100,7 +95,6 @@ AUTHOR_MAP = {
|
||||
"mcosma@gmail.com": "wakamex",
|
||||
"clawdia.nash@proton.me": "clawdia-nash",
|
||||
"pickett.austin@gmail.com": "austinpickett",
|
||||
"dangtc94@gmail.com": "dieutx",
|
||||
"jaisehgal11299@gmail.com": "jaisup",
|
||||
"percydikec@gmail.com": "PercyDikec",
|
||||
"dean.kerr@gmail.com": "deankerr",
|
||||
@@ -183,7 +177,6 @@ AUTHOR_MAP = {
|
||||
"juan.ovalle@mistral.ai": "jjovalle99",
|
||||
"julien.talbot@ergonomia.re": "Julientalbot",
|
||||
"kagura.chen28@gmail.com": "kagura-agent",
|
||||
"1342088860@qq.com": "youngDoo",
|
||||
"kamil@gwozdz.me": "kamil-gwozdz",
|
||||
"karamusti912@gmail.com": "MustafaKara7",
|
||||
"kira@ariaki.me": "kira-ariaki",
|
||||
@@ -238,23 +231,7 @@ AUTHOR_MAP = {
|
||||
"zaynjarvis@gmail.com": "ZaynJarvis",
|
||||
"zhiheng.liu@bytedance.com": "ZaynJarvis",
|
||||
"mbelleau@Michels-MacBook-Pro.local": "malaiwah",
|
||||
"michel.belleau@malaiwah.com": "malaiwah",
|
||||
"gnanasekaran.sekareee@gmail.com": "gnanam1990",
|
||||
"jz.pentest@gmail.com": "0xyg3n",
|
||||
"hypnosis.mda@gmail.com": "Hypn0sis",
|
||||
"ywt000818@gmail.com": "OwenYWT",
|
||||
"dhandhalyabhavik@gmail.com": "v1k22",
|
||||
"rucchizhao@zhaochenfeideMacBook-Pro.local": "RucchiZ",
|
||||
"lehaolin98@outlook.com": "LehaoLin",
|
||||
"yuewang1@microsoft.com": "imink",
|
||||
"1736355688@qq.com": "hedgeho9X",
|
||||
"bernylinville@devopsthink.org": "bernylinville",
|
||||
"brian@bde.io": "briandevans",
|
||||
"hubin_ll@qq.com": "LLQWQ",
|
||||
"memosr_email@gmail.com": "memosr",
|
||||
"anthhub@163.com": "anthhub",
|
||||
"shenuu@gmail.com": "shenuu",
|
||||
"xiayh17@gmail.com": "xiayh0107",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Canonical test runner for hermes-agent. Run this instead of calling
|
||||
# `pytest` directly to guarantee your local run matches CI behavior.
|
||||
#
|
||||
# What this script enforces:
|
||||
# * -n 4 xdist workers (CI has 4 cores; -n auto diverges locally)
|
||||
# * TZ=UTC, LANG=C.UTF-8, PYTHONHASHSEED=0 (deterministic)
|
||||
# * Credential env vars blanked (conftest.py also does this, but this
|
||||
# is belt-and-suspenders for anyone running `pytest` outside of
|
||||
# our conftest path — e.g. calling pytest on a single file)
|
||||
# * Proper venv activation
|
||||
#
|
||||
# Usage:
|
||||
# scripts/run_tests.sh # full suite
|
||||
# scripts/run_tests.sh tests/agent/ # one directory
|
||||
# scripts/run_tests.sh tests/agent/test_foo.py::TestClass::test_method
|
||||
# scripts/run_tests.sh --tb=long -v # pass-through pytest args
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# ── Locate repo root ────────────────────────────────────────────────────────
|
||||
# Works whether this is the main checkout or a worktree.
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
|
||||
# ── Activate venv ───────────────────────────────────────────────────────────
|
||||
# Prefer a .venv in the current tree, fall back to the main checkout's venv
|
||||
# (useful for worktrees where we don't always duplicate the venv).
|
||||
VENV=""
|
||||
for candidate in "$REPO_ROOT/.venv" "$REPO_ROOT/venv" "$HOME/.hermes/hermes-agent/venv"; do
|
||||
if [ -f "$candidate/bin/activate" ]; then
|
||||
VENV="$candidate"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$VENV" ]; then
|
||||
echo "error: no virtualenv found in $REPO_ROOT/.venv or $REPO_ROOT/venv" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PYTHON="$VENV/bin/python"
|
||||
|
||||
# ── Ensure pytest-split is installed (required for shard-equivalent runs) ──
|
||||
if ! "$PYTHON" -c "import pytest_split" 2>/dev/null; then
|
||||
echo "→ installing pytest-split into $VENV"
|
||||
"$PYTHON" -m pip install --quiet "pytest-split>=0.9,<1"
|
||||
fi
|
||||
|
||||
# ── Hermetic environment ────────────────────────────────────────────────────
|
||||
# Mirror what CI does in .github/workflows/tests.yml + what conftest.py does.
|
||||
# Unset every credential-shaped var currently in the environment.
|
||||
while IFS='=' read -r name _; do
|
||||
case "$name" in
|
||||
*_API_KEY|*_TOKEN|*_SECRET|*_PASSWORD|*_CREDENTIALS|*_ACCESS_KEY| \
|
||||
*_SECRET_ACCESS_KEY|*_PRIVATE_KEY|*_OAUTH_TOKEN|*_WEBHOOK_SECRET| \
|
||||
*_ENCRYPT_KEY|*_APP_SECRET|*_CLIENT_SECRET|*_CORP_SECRET|*_AES_KEY| \
|
||||
AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_SESSION_TOKEN|FAL_KEY| \
|
||||
GH_TOKEN|GITHUB_TOKEN)
|
||||
unset "$name"
|
||||
;;
|
||||
esac
|
||||
done < <(env)
|
||||
|
||||
# Unset HERMES_* behavioral vars too.
|
||||
unset HERMES_YOLO_MODE HERMES_INTERACTIVE HERMES_QUIET HERMES_TOOL_PROGRESS \
|
||||
HERMES_TOOL_PROGRESS_MODE HERMES_MAX_ITERATIONS HERMES_SESSION_PLATFORM \
|
||||
HERMES_SESSION_CHAT_ID HERMES_SESSION_CHAT_NAME HERMES_SESSION_THREAD_ID \
|
||||
HERMES_SESSION_SOURCE HERMES_SESSION_KEY HERMES_GATEWAY_SESSION \
|
||||
HERMES_PLATFORM HERMES_INFERENCE_PROVIDER HERMES_MANAGED HERMES_DEV \
|
||||
HERMES_CONTAINER HERMES_EPHEMERAL_SYSTEM_PROMPT HERMES_TIMEZONE \
|
||||
HERMES_REDACT_SECRETS HERMES_BACKGROUND_NOTIFICATIONS HERMES_EXEC_ASK \
|
||||
HERMES_HOME_MODE 2>/dev/null || true
|
||||
|
||||
# Pin deterministic runtime.
|
||||
export TZ=UTC
|
||||
export LANG=C.UTF-8
|
||||
export LC_ALL=C.UTF-8
|
||||
export PYTHONHASHSEED=0
|
||||
|
||||
# ── Worker count ────────────────────────────────────────────────────────────
|
||||
# CI uses `-n auto` on ubuntu-latest which gives 4 workers. A 20-core
|
||||
# workstation with `-n auto` gets 20 workers and exposes test-ordering
|
||||
# flakes that CI will never see. Pin to 4 so local matches CI.
|
||||
WORKERS="${HERMES_TEST_WORKERS:-4}"
|
||||
|
||||
# ── Run pytest ──────────────────────────────────────────────────────────────
|
||||
cd "$REPO_ROOT"
|
||||
|
||||
# If the first argument starts with `-` treat all args as pytest flags;
|
||||
# otherwise treat them as test paths.
|
||||
ARGS=("$@")
|
||||
|
||||
echo "▶ running pytest with $WORKERS workers, hermetic env, in $REPO_ROOT"
|
||||
echo " (TZ=UTC LANG=C.UTF-8 PYTHONHASHSEED=0; all credential env vars unset)"
|
||||
|
||||
# -o "addopts=" clears pyproject.toml's `-n auto` so our -n wins.
|
||||
exec "$PYTHON" -m pytest \
|
||||
-o "addopts=" \
|
||||
-n "$WORKERS" \
|
||||
--ignore=tests/integration \
|
||||
--ignore=tests/e2e \
|
||||
-m "not integration" \
|
||||
"${ARGS[@]}"
|
||||
+19
-224
@@ -1,27 +1,7 @@
|
||||
"""Shared fixtures for the hermes-agent test suite.
|
||||
|
||||
Hermetic-test invariants enforced here (see AGENTS.md for rationale):
|
||||
|
||||
1. **No credential env vars.** All provider/credential-shaped env vars
|
||||
(ending in _API_KEY, _TOKEN, _SECRET, _PASSWORD, _CREDENTIALS, etc.)
|
||||
are unset before every test. Local developer keys cannot leak in.
|
||||
2. **Isolated HERMES_HOME.** HERMES_HOME points to a per-test tempdir so
|
||||
code reading ``~/.hermes/*`` via ``get_hermes_home()`` can't see the
|
||||
real one. (We do NOT also redirect HOME — that broke subprocesses in
|
||||
CI. Code using ``Path.home() / ".hermes"`` instead of the canonical
|
||||
``get_hermes_home()`` is a bug to fix at the callsite.)
|
||||
3. **Deterministic runtime.** TZ=UTC, LANG=C.UTF-8, PYTHONHASHSEED=0.
|
||||
4. **No HERMES_SESSION_* inheritance** — the agent's current gateway
|
||||
session must not leak into tests.
|
||||
|
||||
These invariants make the local test run match CI closely. Gaps that
|
||||
remain (CPU count, xdist worker count) are addressed by the canonical
|
||||
test runner at ``scripts/run_tests.sh``.
|
||||
"""
|
||||
"""Shared fixtures for the hermes-agent test suite."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -36,215 +16,30 @@ if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
# ── Credential env-var filter ──────────────────────────────────────────────
|
||||
#
|
||||
# Any env var in the current process matching ONE of these patterns is
|
||||
# unset for every test. Developers' local keys cannot leak into assertions
|
||||
# about "auto-detect provider when key present".
|
||||
|
||||
_CREDENTIAL_SUFFIXES = (
|
||||
"_API_KEY",
|
||||
"_TOKEN",
|
||||
"_SECRET",
|
||||
"_PASSWORD",
|
||||
"_CREDENTIALS",
|
||||
"_ACCESS_KEY",
|
||||
"_SECRET_ACCESS_KEY",
|
||||
"_PRIVATE_KEY",
|
||||
"_OAUTH_TOKEN",
|
||||
"_WEBHOOK_SECRET",
|
||||
"_ENCRYPT_KEY",
|
||||
"_APP_SECRET",
|
||||
"_CLIENT_SECRET",
|
||||
"_CORP_SECRET",
|
||||
"_AES_KEY",
|
||||
)
|
||||
|
||||
# Explicit names (for ones that don't fit the suffix pattern)
|
||||
_CREDENTIAL_NAMES = frozenset({
|
||||
"AWS_ACCESS_KEY_ID",
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
"AWS_SESSION_TOKEN",
|
||||
"ANTHROPIC_TOKEN",
|
||||
"FAL_KEY",
|
||||
"GH_TOKEN",
|
||||
"GITHUB_TOKEN",
|
||||
"OPENAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"NOUS_API_KEY",
|
||||
"GEMINI_API_KEY",
|
||||
"GOOGLE_API_KEY",
|
||||
"GROQ_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"MISTRAL_API_KEY",
|
||||
"DEEPSEEK_API_KEY",
|
||||
"KIMI_API_KEY",
|
||||
"MOONSHOT_API_KEY",
|
||||
"GLM_API_KEY",
|
||||
"ZAI_API_KEY",
|
||||
"MINIMAX_API_KEY",
|
||||
"OLLAMA_API_KEY",
|
||||
"OPENVIKING_API_KEY",
|
||||
"COPILOT_API_KEY",
|
||||
"CLAUDE_CODE_OAUTH_TOKEN",
|
||||
"BROWSERBASE_API_KEY",
|
||||
"FIRECRAWL_API_KEY",
|
||||
"PARALLEL_API_KEY",
|
||||
"EXA_API_KEY",
|
||||
"TAVILY_API_KEY",
|
||||
"WANDB_API_KEY",
|
||||
"ELEVENLABS_API_KEY",
|
||||
"HONCHO_API_KEY",
|
||||
"MEM0_API_KEY",
|
||||
"SUPERMEMORY_API_KEY",
|
||||
"RETAINDB_API_KEY",
|
||||
"HINDSIGHT_API_KEY",
|
||||
"HINDSIGHT_LLM_API_KEY",
|
||||
"TINKER_API_KEY",
|
||||
"DAYTONA_API_KEY",
|
||||
"TWILIO_AUTH_TOKEN",
|
||||
"TELEGRAM_BOT_TOKEN",
|
||||
"DISCORD_BOT_TOKEN",
|
||||
"SLACK_BOT_TOKEN",
|
||||
"SLACK_APP_TOKEN",
|
||||
"MATTERMOST_TOKEN",
|
||||
"MATRIX_ACCESS_TOKEN",
|
||||
"MATRIX_PASSWORD",
|
||||
"MATRIX_RECOVERY_KEY",
|
||||
"HASS_TOKEN",
|
||||
"EMAIL_PASSWORD",
|
||||
"BLUEBUBBLES_PASSWORD",
|
||||
"FEISHU_APP_SECRET",
|
||||
"FEISHU_ENCRYPT_KEY",
|
||||
"FEISHU_VERIFICATION_TOKEN",
|
||||
"DINGTALK_CLIENT_SECRET",
|
||||
"QQ_CLIENT_SECRET",
|
||||
"QQ_STT_API_KEY",
|
||||
"WECOM_SECRET",
|
||||
"WECOM_CALLBACK_CORP_SECRET",
|
||||
"WECOM_CALLBACK_TOKEN",
|
||||
"WECOM_CALLBACK_ENCODING_AES_KEY",
|
||||
"WEIXIN_TOKEN",
|
||||
"MODAL_TOKEN_ID",
|
||||
"MODAL_TOKEN_SECRET",
|
||||
"TERMINAL_SSH_KEY",
|
||||
"SUDO_PASSWORD",
|
||||
"GATEWAY_PROXY_KEY",
|
||||
"API_SERVER_KEY",
|
||||
"TOOL_GATEWAY_USER_TOKEN",
|
||||
"TELEGRAM_WEBHOOK_SECRET",
|
||||
"WEBHOOK_SECRET",
|
||||
"AI_GATEWAY_API_KEY",
|
||||
"VOICE_TOOLS_OPENAI_KEY",
|
||||
"BROWSER_USE_API_KEY",
|
||||
"CUSTOM_API_KEY",
|
||||
"GATEWAY_PROXY_URL",
|
||||
"GEMINI_BASE_URL",
|
||||
"OPENAI_BASE_URL",
|
||||
"OPENROUTER_BASE_URL",
|
||||
"OLLAMA_BASE_URL",
|
||||
"GROQ_BASE_URL",
|
||||
"XAI_BASE_URL",
|
||||
"AI_GATEWAY_BASE_URL",
|
||||
"ANTHROPIC_BASE_URL",
|
||||
})
|
||||
|
||||
|
||||
def _looks_like_credential(name: str) -> bool:
|
||||
"""True if env var name matches a credential-shaped pattern."""
|
||||
if name in _CREDENTIAL_NAMES:
|
||||
return True
|
||||
return any(name.endswith(suf) for suf in _CREDENTIAL_SUFFIXES)
|
||||
|
||||
|
||||
# HERMES_* vars that change test behavior by being set. Unset all of these
|
||||
# unconditionally — individual tests that need them set do so explicitly.
|
||||
_HERMES_BEHAVIORAL_VARS = frozenset({
|
||||
"HERMES_YOLO_MODE",
|
||||
"HERMES_INTERACTIVE",
|
||||
"HERMES_QUIET",
|
||||
"HERMES_TOOL_PROGRESS",
|
||||
"HERMES_TOOL_PROGRESS_MODE",
|
||||
"HERMES_MAX_ITERATIONS",
|
||||
"HERMES_SESSION_PLATFORM",
|
||||
"HERMES_SESSION_CHAT_ID",
|
||||
"HERMES_SESSION_CHAT_NAME",
|
||||
"HERMES_SESSION_THREAD_ID",
|
||||
"HERMES_SESSION_SOURCE",
|
||||
"HERMES_SESSION_KEY",
|
||||
"HERMES_GATEWAY_SESSION",
|
||||
"HERMES_PLATFORM",
|
||||
"HERMES_INFERENCE_PROVIDER",
|
||||
"HERMES_MANAGED",
|
||||
"HERMES_DEV",
|
||||
"HERMES_CONTAINER",
|
||||
"HERMES_EPHEMERAL_SYSTEM_PROMPT",
|
||||
"HERMES_TIMEZONE",
|
||||
"HERMES_REDACT_SECRETS",
|
||||
"HERMES_BACKGROUND_NOTIFICATIONS",
|
||||
"HERMES_EXEC_ASK",
|
||||
"HERMES_HOME_MODE",
|
||||
})
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _hermetic_environment(tmp_path, monkeypatch):
|
||||
"""Blank out all credential/behavioral env vars so local and CI match.
|
||||
|
||||
Also redirects HOME and HERMES_HOME to per-test tempdirs so code that
|
||||
reads ``~/.hermes/*`` can't touch the real one, and pins TZ/LANG so
|
||||
datetime/locale-sensitive tests are deterministic.
|
||||
"""
|
||||
# 1. Blank every credential-shaped env var that's currently set.
|
||||
for name in list(os.environ.keys()):
|
||||
if _looks_like_credential(name):
|
||||
monkeypatch.delenv(name, raising=False)
|
||||
|
||||
# 2. Blank behavioral HERMES_* vars that could change test semantics.
|
||||
for name in _HERMES_BEHAVIORAL_VARS:
|
||||
monkeypatch.delenv(name, raising=False)
|
||||
|
||||
# 3. Redirect HERMES_HOME to a per-test tempdir. Code that reads
|
||||
# ``~/.hermes/*`` via ``get_hermes_home()`` now gets the tempdir.
|
||||
#
|
||||
# NOTE: We do NOT also redirect HOME. Doing so broke CI because
|
||||
# some tests (and their transitive deps) spawn subprocesses that
|
||||
# inherit HOME and expect it to be stable. If a test genuinely
|
||||
# needs HOME isolated, it should set it explicitly in its own
|
||||
# fixture. Any code in the codebase reading ``~/.hermes/*`` via
|
||||
# ``Path.home() / ".hermes"`` instead of ``get_hermes_home()``
|
||||
# is a bug to fix at the callsite.
|
||||
fake_hermes_home = tmp_path / "hermes_test"
|
||||
fake_hermes_home.mkdir()
|
||||
(fake_hermes_home / "sessions").mkdir()
|
||||
(fake_hermes_home / "cron").mkdir()
|
||||
(fake_hermes_home / "memories").mkdir()
|
||||
(fake_hermes_home / "skills").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(fake_hermes_home))
|
||||
|
||||
# 4. Deterministic locale / timezone / hashseed. CI runs in UTC with
|
||||
# C.UTF-8 locale; local dev often doesn't. Pin everything.
|
||||
monkeypatch.setenv("TZ", "UTC")
|
||||
monkeypatch.setenv("LANG", "C.UTF-8")
|
||||
monkeypatch.setenv("LC_ALL", "C.UTF-8")
|
||||
monkeypatch.setenv("PYTHONHASHSEED", "0")
|
||||
|
||||
# 5. Reset plugin singleton so tests don't leak plugins from
|
||||
# ~/.hermes/plugins/ (which, per step 3, is now empty — but the
|
||||
# singleton might still be cached from a previous test).
|
||||
def _isolate_hermes_home(tmp_path, monkeypatch):
|
||||
"""Redirect HERMES_HOME to a temp dir so tests never write to ~/.hermes/."""
|
||||
fake_home = tmp_path / "hermes_test"
|
||||
fake_home.mkdir()
|
||||
(fake_home / "sessions").mkdir()
|
||||
(fake_home / "cron").mkdir()
|
||||
(fake_home / "memories").mkdir()
|
||||
(fake_home / "skills").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(fake_home))
|
||||
# Reset plugin singleton so tests don't leak plugins from ~/.hermes/plugins/
|
||||
try:
|
||||
import hermes_cli.plugins as _plugins_mod
|
||||
monkeypatch.setattr(_plugins_mod, "_plugin_manager", None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Backward-compat alias — old tests reference this fixture name. Keep it
|
||||
# as a no-op wrapper so imports don't break.
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_hermes_home(_hermetic_environment):
|
||||
"""Alias preserved for any test that yields this name explicitly."""
|
||||
return None
|
||||
# Tests should not inherit the agent's current gateway/messaging surface.
|
||||
# Individual tests that need gateway behavior set these explicitly.
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
|
||||
# Avoid making real calls during tests if this key is set in the env files
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@@ -64,60 +64,6 @@ class TestResolveDeliveryTarget:
|
||||
"thread_id": "17585",
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("platform", "env_var", "chat_id"),
|
||||
[
|
||||
("matrix", "MATRIX_HOME_ROOM", "!bot-room:example.org"),
|
||||
("signal", "SIGNAL_HOME_CHANNEL", "+15551234567"),
|
||||
("mattermost", "MATTERMOST_HOME_CHANNEL", "team-town-square"),
|
||||
("sms", "SMS_HOME_CHANNEL", "+15557654321"),
|
||||
("email", "EMAIL_HOME_ADDRESS", "home@example.com"),
|
||||
("dingtalk", "DINGTALK_HOME_CHANNEL", "cidNNN"),
|
||||
("feishu", "FEISHU_HOME_CHANNEL", "oc_home"),
|
||||
("wecom", "WECOM_HOME_CHANNEL", "wecom-home"),
|
||||
("weixin", "WEIXIN_HOME_CHANNEL", "wxid_home"),
|
||||
("qqbot", "QQ_HOME_CHANNEL", "group-openid-home"),
|
||||
],
|
||||
)
|
||||
def test_origin_delivery_without_origin_falls_back_to_supported_home_channels(
|
||||
self, monkeypatch, platform, env_var, chat_id
|
||||
):
|
||||
for fallback_env in (
|
||||
"MATRIX_HOME_ROOM",
|
||||
"MATRIX_HOME_CHANNEL",
|
||||
"TELEGRAM_HOME_CHANNEL",
|
||||
"DISCORD_HOME_CHANNEL",
|
||||
"SLACK_HOME_CHANNEL",
|
||||
"SIGNAL_HOME_CHANNEL",
|
||||
"MATTERMOST_HOME_CHANNEL",
|
||||
"SMS_HOME_CHANNEL",
|
||||
"EMAIL_HOME_ADDRESS",
|
||||
"DINGTALK_HOME_CHANNEL",
|
||||
"BLUEBUBBLES_HOME_CHANNEL",
|
||||
"FEISHU_HOME_CHANNEL",
|
||||
"WECOM_HOME_CHANNEL",
|
||||
"WEIXIN_HOME_CHANNEL",
|
||||
"QQ_HOME_CHANNEL",
|
||||
):
|
||||
monkeypatch.delenv(fallback_env, raising=False)
|
||||
monkeypatch.setenv(env_var, chat_id)
|
||||
|
||||
assert _resolve_delivery_target({"deliver": "origin"}) == {
|
||||
"platform": platform,
|
||||
"chat_id": chat_id,
|
||||
"thread_id": None,
|
||||
}
|
||||
|
||||
def test_bare_matrix_delivery_uses_matrix_home_room(self, monkeypatch):
|
||||
monkeypatch.delenv("MATRIX_HOME_CHANNEL", raising=False)
|
||||
monkeypatch.setenv("MATRIX_HOME_ROOM", "!room123:example.org")
|
||||
|
||||
assert _resolve_delivery_target({"deliver": "matrix"}) == {
|
||||
"platform": "matrix",
|
||||
"chat_id": "!room123:example.org",
|
||||
"thread_id": None,
|
||||
}
|
||||
|
||||
def test_explicit_telegram_topic_target_with_thread_id(self):
|
||||
"""deliver: 'telegram:chat_id:thread_id' parses correctly."""
|
||||
job = {
|
||||
|
||||
@@ -258,785 +258,3 @@ class TestAgentCacheLifecycle:
|
||||
cb3 = lambda *a: None
|
||||
agent.tool_progress_callback = cb3
|
||||
assert agent.tool_progress_callback is cb3
|
||||
|
||||
|
||||
class TestAgentCacheBoundedGrowth:
|
||||
"""LRU cap and idle-TTL eviction prevent unbounded cache growth."""
|
||||
|
||||
def _bounded_runner(self):
|
||||
"""Runner with an OrderedDict cache (matches real gateway init)."""
|
||||
from collections import OrderedDict
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._agent_cache = OrderedDict()
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
return runner
|
||||
|
||||
def _fake_agent(self, last_activity: float | None = None):
|
||||
"""Lightweight stand-in; real AIAgent is heavy to construct."""
|
||||
m = MagicMock()
|
||||
if last_activity is not None:
|
||||
m._last_activity_ts = last_activity
|
||||
else:
|
||||
import time as _t
|
||||
m._last_activity_ts = _t.time()
|
||||
return m
|
||||
|
||||
def test_cap_evicts_lru_when_exceeded(self, monkeypatch):
|
||||
"""Inserting past _AGENT_CACHE_MAX_SIZE pops the oldest entry."""
|
||||
from gateway import run as gw_run
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", 3)
|
||||
runner = self._bounded_runner()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
for i in range(3):
|
||||
runner._agent_cache[f"s{i}"] = (self._fake_agent(), f"sig{i}")
|
||||
|
||||
# Insert a 4th — oldest (s0) must be evicted.
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["s3"] = (self._fake_agent(), "sig3")
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
assert "s0" not in runner._agent_cache
|
||||
assert "s3" in runner._agent_cache
|
||||
assert len(runner._agent_cache) == 3
|
||||
|
||||
def test_cap_respects_move_to_end(self, monkeypatch):
|
||||
"""Entries refreshed via move_to_end are NOT evicted as 'oldest'."""
|
||||
from gateway import run as gw_run
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", 3)
|
||||
runner = self._bounded_runner()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
for i in range(3):
|
||||
runner._agent_cache[f"s{i}"] = (self._fake_agent(), f"sig{i}")
|
||||
|
||||
# Touch s0 — it is now MRU, so s1 becomes LRU.
|
||||
runner._agent_cache.move_to_end("s0")
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["s3"] = (self._fake_agent(), "sig3")
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
assert "s0" in runner._agent_cache # rescued by move_to_end
|
||||
assert "s1" not in runner._agent_cache # now oldest → evicted
|
||||
assert "s3" in runner._agent_cache
|
||||
|
||||
def test_cap_triggers_cleanup_thread(self, monkeypatch):
|
||||
"""Evicted agent has release_clients() called for it (soft cleanup).
|
||||
|
||||
Uses the soft path (_release_evicted_agent_soft), NOT the hard
|
||||
_cleanup_agent_resources — cache eviction must not tear down
|
||||
per-task state (terminal/browser/bg procs).
|
||||
"""
|
||||
from gateway import run as gw_run
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", 1)
|
||||
runner = self._bounded_runner()
|
||||
|
||||
release_calls: list = []
|
||||
cleanup_calls: list = []
|
||||
# Intercept both paths; only release_clients path should fire.
|
||||
def _soft(agent):
|
||||
release_calls.append(agent)
|
||||
runner._release_evicted_agent_soft = _soft
|
||||
runner._cleanup_agent_resources = lambda a: cleanup_calls.append(a)
|
||||
|
||||
old_agent = self._fake_agent()
|
||||
new_agent = self._fake_agent()
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["old"] = (old_agent, "sig_old")
|
||||
runner._agent_cache["new"] = (new_agent, "sig_new")
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
# Cleanup is dispatched to a daemon thread; join briefly to observe.
|
||||
import time as _t
|
||||
deadline = _t.time() + 2.0
|
||||
while _t.time() < deadline and not release_calls:
|
||||
_t.sleep(0.02)
|
||||
assert old_agent in release_calls
|
||||
assert new_agent not in release_calls
|
||||
# Hard-cleanup path must NOT have fired — that's for session expiry only.
|
||||
assert cleanup_calls == []
|
||||
|
||||
def test_idle_ttl_sweep_evicts_stale_agents(self, monkeypatch):
|
||||
"""_sweep_idle_cached_agents removes agents idle past the TTL."""
|
||||
from gateway import run as gw_run
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_IDLE_TTL_SECS", 0.05)
|
||||
runner = self._bounded_runner()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
import time as _t
|
||||
fresh = self._fake_agent(last_activity=_t.time())
|
||||
stale = self._fake_agent(last_activity=_t.time() - 10.0)
|
||||
runner._agent_cache["fresh"] = (fresh, "s1")
|
||||
runner._agent_cache["stale"] = (stale, "s2")
|
||||
|
||||
evicted = runner._sweep_idle_cached_agents()
|
||||
assert evicted == 1
|
||||
assert "stale" not in runner._agent_cache
|
||||
assert "fresh" in runner._agent_cache
|
||||
|
||||
def test_idle_sweep_skips_agents_without_activity_ts(self, monkeypatch):
|
||||
"""Agents missing _last_activity_ts are left alone (defensive)."""
|
||||
from gateway import run as gw_run
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_IDLE_TTL_SECS", 0.01)
|
||||
runner = self._bounded_runner()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
no_ts = MagicMock(spec=[]) # no _last_activity_ts attribute
|
||||
runner._agent_cache["s"] = (no_ts, "sig")
|
||||
|
||||
assert runner._sweep_idle_cached_agents() == 0
|
||||
assert "s" in runner._agent_cache
|
||||
|
||||
def test_plain_dict_cache_is_tolerated(self):
|
||||
"""Test fixtures using plain {} don't crash _enforce_agent_cache_cap."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._agent_cache = {} # plain dict, not OrderedDict
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
# Should be a no-op rather than raising.
|
||||
with runner._agent_cache_lock:
|
||||
for i in range(200):
|
||||
runner._agent_cache[f"s{i}"] = (MagicMock(), f"sig{i}")
|
||||
runner._enforce_agent_cache_cap() # no crash, no eviction
|
||||
|
||||
assert len(runner._agent_cache) == 200
|
||||
|
||||
def test_main_lookup_updates_lru_order(self, monkeypatch):
|
||||
"""Cache hit via the main-lookup path refreshes LRU position."""
|
||||
runner = self._bounded_runner()
|
||||
|
||||
a0 = self._fake_agent()
|
||||
a1 = self._fake_agent()
|
||||
a2 = self._fake_agent()
|
||||
runner._agent_cache["s0"] = (a0, "sig0")
|
||||
runner._agent_cache["s1"] = (a1, "sig1")
|
||||
runner._agent_cache["s2"] = (a2, "sig2")
|
||||
|
||||
# Simulate what _process_message_background does on a cache hit
|
||||
# (minus the agent-state reset which isn't relevant here).
|
||||
with runner._agent_cache_lock:
|
||||
cached = runner._agent_cache.get("s0")
|
||||
if cached and hasattr(runner._agent_cache, "move_to_end"):
|
||||
runner._agent_cache.move_to_end("s0")
|
||||
|
||||
# After the hit, insertion order should be s1, s2, s0.
|
||||
assert list(runner._agent_cache.keys()) == ["s1", "s2", "s0"]
|
||||
|
||||
|
||||
class TestAgentCacheActiveSafety:
|
||||
"""Safety: eviction must not tear down agents currently mid-turn.
|
||||
|
||||
AIAgent.close() kills process_registry entries for the task, cleans
|
||||
the terminal sandbox, closes the OpenAI client, and cascades
|
||||
.close() into active child subagents. Calling it while the agent
|
||||
is still processing would crash the in-flight request. These tests
|
||||
pin that eviction skips any agent present in _running_agents.
|
||||
"""
|
||||
|
||||
def _runner(self):
|
||||
from collections import OrderedDict
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._agent_cache = OrderedDict()
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
runner._running_agents = {}
|
||||
return runner
|
||||
|
||||
def _fake_agent(self, idle_seconds: float = 0.0):
|
||||
import time as _t
|
||||
m = MagicMock()
|
||||
m._last_activity_ts = _t.time() - idle_seconds
|
||||
return m
|
||||
|
||||
def test_cap_skips_active_lru_entry(self, monkeypatch):
|
||||
"""Active LRU entry is skipped; cache stays over cap rather than
|
||||
compensating by evicting a newer entry.
|
||||
|
||||
Rationale: evicting a more-recent entry just because the oldest
|
||||
slot is temporarily locked would punish the most recently-
|
||||
inserted session (which has no cache to preserve) to protect
|
||||
one that happens to be mid-turn. Better to let the cache stay
|
||||
transiently over cap and re-check on the next insert.
|
||||
"""
|
||||
from gateway import run as gw_run
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", 2)
|
||||
runner = self._runner()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
active = self._fake_agent()
|
||||
idle_a = self._fake_agent()
|
||||
idle_b = self._fake_agent()
|
||||
|
||||
# Insertion order: active (oldest), idle_a, idle_b.
|
||||
runner._agent_cache["session-active"] = (active, "sig")
|
||||
runner._agent_cache["session-idle-a"] = (idle_a, "sig")
|
||||
runner._agent_cache["session-idle-b"] = (idle_b, "sig")
|
||||
|
||||
# Mark `active` as mid-turn — it's LRU, but protected.
|
||||
runner._running_agents["session-active"] = active
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
# All three remain; no eviction ran, no cleanup dispatched.
|
||||
assert "session-active" in runner._agent_cache
|
||||
assert "session-idle-a" in runner._agent_cache
|
||||
assert "session-idle-b" in runner._agent_cache
|
||||
assert runner._cleanup_agent_resources.call_count == 0
|
||||
|
||||
def test_cap_evicts_when_multiple_excess_and_some_inactive(self, monkeypatch):
|
||||
"""Mixed active/idle in the LRU excess window: only the idle ones go.
|
||||
|
||||
With CAP=2 and 4 entries, excess=2 (the two oldest). If the
|
||||
oldest is active and the next is idle, we evict exactly one.
|
||||
Cache ends at CAP+1, which is still better than unbounded.
|
||||
"""
|
||||
from gateway import run as gw_run
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", 2)
|
||||
runner = self._runner()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
oldest_active = self._fake_agent()
|
||||
idle_second = self._fake_agent()
|
||||
idle_third = self._fake_agent()
|
||||
idle_fourth = self._fake_agent()
|
||||
|
||||
runner._agent_cache["s1"] = (oldest_active, "sig")
|
||||
runner._agent_cache["s2"] = (idle_second, "sig") # in excess window, idle
|
||||
runner._agent_cache["s3"] = (idle_third, "sig")
|
||||
runner._agent_cache["s4"] = (idle_fourth, "sig")
|
||||
|
||||
runner._running_agents["s1"] = oldest_active # oldest is mid-turn
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
# s1 protected (active), s2 evicted (idle + in excess window),
|
||||
# s3 and s4 untouched (outside excess window).
|
||||
assert "s1" in runner._agent_cache
|
||||
assert "s2" not in runner._agent_cache
|
||||
assert "s3" in runner._agent_cache
|
||||
assert "s4" in runner._agent_cache
|
||||
|
||||
def test_cap_leaves_cache_over_limit_if_all_active(self, monkeypatch, caplog):
|
||||
"""If every over-cap entry is mid-turn, the cache stays over cap.
|
||||
|
||||
Better to temporarily exceed the cap than to crash an in-flight
|
||||
turn by tearing down its clients.
|
||||
"""
|
||||
from gateway import run as gw_run
|
||||
import logging as _logging
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", 1)
|
||||
runner = self._runner()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
a1 = self._fake_agent()
|
||||
a2 = self._fake_agent()
|
||||
a3 = self._fake_agent()
|
||||
runner._agent_cache["s1"] = (a1, "sig")
|
||||
runner._agent_cache["s2"] = (a2, "sig")
|
||||
runner._agent_cache["s3"] = (a3, "sig")
|
||||
|
||||
# All three are mid-turn.
|
||||
runner._running_agents["s1"] = a1
|
||||
runner._running_agents["s2"] = a2
|
||||
runner._running_agents["s3"] = a3
|
||||
|
||||
with caplog.at_level(_logging.WARNING, logger="gateway.run"):
|
||||
with runner._agent_cache_lock:
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
# Cache unchanged because eviction had to skip every candidate.
|
||||
assert len(runner._agent_cache) == 3
|
||||
# _cleanup_agent_resources must NOT have been scheduled.
|
||||
assert runner._cleanup_agent_resources.call_count == 0
|
||||
# And we logged a warning so operators can see the condition.
|
||||
assert any("mid-turn" in r.message for r in caplog.records)
|
||||
|
||||
def test_cap_pending_sentinel_does_not_block_eviction(self, monkeypatch):
|
||||
"""_AGENT_PENDING_SENTINEL in _running_agents is treated as 'not active'.
|
||||
|
||||
The sentinel is set while an agent is being CONSTRUCTED, before the
|
||||
real AIAgent instance exists. Cached agents from other sessions
|
||||
can still be evicted safely.
|
||||
"""
|
||||
from gateway import run as gw_run
|
||||
from gateway.run import _AGENT_PENDING_SENTINEL
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", 1)
|
||||
runner = self._runner()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
a1 = self._fake_agent()
|
||||
a2 = self._fake_agent()
|
||||
runner._agent_cache["s1"] = (a1, "sig")
|
||||
runner._agent_cache["s2"] = (a2, "sig")
|
||||
# Another session is mid-creation — sentinel, no real agent yet.
|
||||
runner._running_agents["s3-being-created"] = _AGENT_PENDING_SENTINEL
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
assert "s1" not in runner._agent_cache # evicted normally
|
||||
assert "s2" in runner._agent_cache
|
||||
|
||||
def test_idle_sweep_skips_active_agent(self, monkeypatch):
|
||||
"""Idle-TTL sweep must not tear down an active agent even if 'stale'."""
|
||||
from gateway import run as gw_run
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_IDLE_TTL_SECS", 0.01)
|
||||
runner = self._runner()
|
||||
runner._cleanup_agent_resources = MagicMock()
|
||||
|
||||
old_but_active = self._fake_agent(idle_seconds=10.0)
|
||||
runner._agent_cache["s1"] = (old_but_active, "sig")
|
||||
runner._running_agents["s1"] = old_but_active
|
||||
|
||||
evicted = runner._sweep_idle_cached_agents()
|
||||
|
||||
assert evicted == 0
|
||||
assert "s1" in runner._agent_cache
|
||||
assert runner._cleanup_agent_resources.call_count == 0
|
||||
|
||||
def test_eviction_does_not_close_active_agent_client(self, monkeypatch):
|
||||
"""Live test: evicting an active agent does NOT null its .client.
|
||||
|
||||
This reproduces the original concern — if eviction fired while an
|
||||
agent was mid-turn, `agent.close()` would set `self.client = None`
|
||||
and the next API call inside the loop would crash. With the
|
||||
active-agent skip, the client stays intact.
|
||||
"""
|
||||
from gateway import run as gw_run
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", 1)
|
||||
runner = self._runner()
|
||||
|
||||
# Build a proper fake agent whose close() matches AIAgent's contract.
|
||||
active = MagicMock()
|
||||
active._last_activity_ts = __import__("time").time()
|
||||
active.client = MagicMock() # simulate an OpenAI client
|
||||
def _real_close():
|
||||
active.client = None # mirrors run_agent.py:3299
|
||||
active.close = _real_close
|
||||
active.shutdown_memory_provider = MagicMock()
|
||||
|
||||
idle = self._fake_agent()
|
||||
|
||||
runner._agent_cache["active-session"] = (active, "sig")
|
||||
runner._agent_cache["idle-session"] = (idle, "sig")
|
||||
runner._running_agents["active-session"] = active
|
||||
|
||||
# Real cleanup function, not mocked — we want to see whether close()
|
||||
# runs on the active agent. (It shouldn't.)
|
||||
with runner._agent_cache_lock:
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
# Let any eviction cleanup threads drain.
|
||||
import time as _t
|
||||
_t.sleep(0.2)
|
||||
|
||||
# The ACTIVE agent's client must still be usable.
|
||||
assert active.client is not None, (
|
||||
"Active agent's client was closed by eviction — "
|
||||
"running turn would crash on its next API call."
|
||||
)
|
||||
|
||||
|
||||
class TestAgentCacheSpilloverLive:
|
||||
"""Live E2E: fill cache with real AIAgent instances and stress it."""
|
||||
|
||||
def _runner(self):
|
||||
from collections import OrderedDict
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._agent_cache = OrderedDict()
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
runner._running_agents = {}
|
||||
return runner
|
||||
|
||||
def _real_agent(self):
|
||||
"""A genuine AIAgent; no API calls are made during these tests."""
|
||||
from run_agent import AIAgent
|
||||
return AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True,
|
||||
skip_context_files=True, skip_memory=True,
|
||||
platform="telegram",
|
||||
)
|
||||
|
||||
def test_fill_to_cap_then_spillover(self, monkeypatch):
|
||||
"""Fill to cap with real agents, insert one more, oldest evicted."""
|
||||
from gateway import run as gw_run
|
||||
|
||||
CAP = 8
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", CAP)
|
||||
runner = self._runner()
|
||||
|
||||
agents = [self._real_agent() for _ in range(CAP)]
|
||||
for i, a in enumerate(agents):
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache[f"s{i}"] = (a, "sig")
|
||||
runner._enforce_agent_cache_cap()
|
||||
assert len(runner._agent_cache) == CAP
|
||||
|
||||
# Spillover insertion.
|
||||
newcomer = self._real_agent()
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["new"] = (newcomer, "sig")
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
# Oldest (s0) evicted, cap still CAP.
|
||||
assert "s0" not in runner._agent_cache
|
||||
assert "new" in runner._agent_cache
|
||||
assert len(runner._agent_cache) == CAP
|
||||
|
||||
# Clean up so pytest doesn't leak resources.
|
||||
for a in agents + [newcomer]:
|
||||
try:
|
||||
a.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_spillover_all_active_keeps_cache_over_cap(self, monkeypatch, caplog):
|
||||
"""Every slot active: cache goes over cap, no one gets torn down."""
|
||||
from gateway import run as gw_run
|
||||
import logging as _logging
|
||||
|
||||
CAP = 4
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", CAP)
|
||||
runner = self._runner()
|
||||
|
||||
agents = [self._real_agent() for _ in range(CAP)]
|
||||
for i, a in enumerate(agents):
|
||||
runner._agent_cache[f"s{i}"] = (a, "sig")
|
||||
runner._running_agents[f"s{i}"] = a # every session mid-turn
|
||||
|
||||
newcomer = self._real_agent()
|
||||
with caplog.at_level(_logging.WARNING, logger="gateway.run"):
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["new"] = (newcomer, "sig")
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
assert len(runner._agent_cache) == CAP + 1 # temporarily over cap
|
||||
# All existing agents still usable.
|
||||
for i, a in enumerate(agents):
|
||||
assert a.client is not None, f"s{i} got closed while active!"
|
||||
# And we warned operators.
|
||||
assert any("mid-turn" in r.message for r in caplog.records)
|
||||
|
||||
for a in agents + [newcomer]:
|
||||
try:
|
||||
a.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_concurrent_inserts_settle_at_cap(self, monkeypatch):
|
||||
"""Many threads inserting in parallel end with len(cache) == CAP."""
|
||||
from gateway import run as gw_run
|
||||
|
||||
CAP = 16
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", CAP)
|
||||
runner = self._runner()
|
||||
|
||||
N_THREADS = 8
|
||||
PER_THREAD = 20 # 8 * 20 = 160 inserts into a 16-slot cache
|
||||
|
||||
def worker(tid: int):
|
||||
for j in range(PER_THREAD):
|
||||
a = self._real_agent()
|
||||
key = f"t{tid}-s{j}"
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache[key] = (a, "sig")
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=worker, args=(t,), daemon=True)
|
||||
for t in range(N_THREADS)
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=30)
|
||||
assert not t.is_alive(), "Worker thread hung — possible deadlock?"
|
||||
|
||||
# Let daemon cleanup threads settle.
|
||||
import time as _t
|
||||
_t.sleep(0.5)
|
||||
|
||||
assert len(runner._agent_cache) == CAP, (
|
||||
f"Expected exactly {CAP} entries after concurrent inserts, "
|
||||
f"got {len(runner._agent_cache)}."
|
||||
)
|
||||
|
||||
def test_evicted_session_next_turn_gets_fresh_agent(self, monkeypatch):
|
||||
"""After eviction, the same session_key can insert a fresh agent.
|
||||
|
||||
Simulates the real spillover flow: evicted session sends another
|
||||
message, which builds a new AIAgent and re-enters the cache.
|
||||
"""
|
||||
from gateway import run as gw_run
|
||||
|
||||
CAP = 2
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_MAX_SIZE", CAP)
|
||||
runner = self._runner()
|
||||
|
||||
a0 = self._real_agent()
|
||||
a1 = self._real_agent()
|
||||
runner._agent_cache["sA"] = (a0, "sig")
|
||||
runner._agent_cache["sB"] = (a1, "sig")
|
||||
|
||||
# 3rd session forces sA (oldest) out.
|
||||
a2 = self._real_agent()
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["sC"] = (a2, "sig")
|
||||
runner._enforce_agent_cache_cap()
|
||||
assert "sA" not in runner._agent_cache
|
||||
|
||||
# Let the eviction cleanup thread run.
|
||||
import time as _t
|
||||
_t.sleep(0.3)
|
||||
|
||||
# Now sA's user sends another message → a fresh agent goes in.
|
||||
a0_new = self._real_agent()
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["sA"] = (a0_new, "sig")
|
||||
runner._enforce_agent_cache_cap()
|
||||
|
||||
assert "sA" in runner._agent_cache
|
||||
assert runner._agent_cache["sA"][0] is a0_new # the new one, not stale
|
||||
# Fresh agent is usable.
|
||||
assert a0_new.client is not None
|
||||
|
||||
for a in (a0, a1, a2, a0_new):
|
||||
try:
|
||||
a.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class TestAgentCacheIdleResume:
|
||||
"""End-to-end: idle-TTL-evicted session resumes cleanly with task state.
|
||||
|
||||
Real-world scenario: user leaves a Telegram session open for 2+ hours.
|
||||
Idle-TTL evicts their cached agent. They come back and send a message.
|
||||
The new agent built for the same session_id must inherit:
|
||||
- Conversation history (from SessionStore — outside cache concern)
|
||||
- Terminal sandbox (same task_id → same _active_environments entry)
|
||||
- Browser daemon (same task_id → same browser session)
|
||||
- Background processes (same task_id → same process_registry entries)
|
||||
The ONLY thing that should reset is the LLM client pool (rebuilt fresh).
|
||||
"""
|
||||
|
||||
def _runner(self):
|
||||
from collections import OrderedDict
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._agent_cache = OrderedDict()
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
runner._running_agents = {}
|
||||
return runner
|
||||
|
||||
def test_release_clients_does_not_touch_process_registry(self, monkeypatch):
|
||||
"""release_clients must not call process_registry.kill_all for task_id."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True,
|
||||
skip_context_files=True, skip_memory=True,
|
||||
session_id="idle-resume-test-session",
|
||||
)
|
||||
|
||||
# Spy on process_registry.kill_all — it MUST NOT be called.
|
||||
from tools import process_registry as _pr
|
||||
kill_all_calls: list = []
|
||||
original_kill_all = _pr.process_registry.kill_all
|
||||
_pr.process_registry.kill_all = lambda **kw: kill_all_calls.append(kw)
|
||||
try:
|
||||
agent.release_clients()
|
||||
finally:
|
||||
_pr.process_registry.kill_all = original_kill_all
|
||||
try:
|
||||
agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assert kill_all_calls == [], (
|
||||
f"release_clients() called process_registry.kill_all — would "
|
||||
f"kill user's bg processes on cache eviction. Calls: {kill_all_calls}"
|
||||
)
|
||||
|
||||
def test_release_clients_does_not_touch_terminal_or_browser(self, monkeypatch):
|
||||
"""release_clients must not call cleanup_vm or cleanup_browser."""
|
||||
from run_agent import AIAgent
|
||||
from tools import terminal_tool as _tt
|
||||
from tools import browser_tool as _bt
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True,
|
||||
skip_context_files=True, skip_memory=True,
|
||||
session_id="idle-resume-test-2",
|
||||
)
|
||||
|
||||
vm_calls: list = []
|
||||
browser_calls: list = []
|
||||
original_vm = _tt.cleanup_vm
|
||||
original_browser = _bt.cleanup_browser
|
||||
_tt.cleanup_vm = lambda tid: vm_calls.append(tid)
|
||||
_bt.cleanup_browser = lambda tid: browser_calls.append(tid)
|
||||
try:
|
||||
agent.release_clients()
|
||||
finally:
|
||||
_tt.cleanup_vm = original_vm
|
||||
_bt.cleanup_browser = original_browser
|
||||
try:
|
||||
agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assert vm_calls == [], (
|
||||
f"release_clients() tore down terminal sandbox — user's cwd, "
|
||||
f"env, and bg shells would be gone on resume. Calls: {vm_calls}"
|
||||
)
|
||||
assert browser_calls == [], (
|
||||
f"release_clients() tore down browser session — user's open "
|
||||
f"tabs and cookies gone on resume. Calls: {browser_calls}"
|
||||
)
|
||||
|
||||
def test_release_clients_closes_llm_client(self):
|
||||
"""release_clients IS expected to close the OpenAI/httpx client."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True,
|
||||
skip_context_files=True, skip_memory=True,
|
||||
)
|
||||
# Clients are lazy-built; force one to exist so we can verify close.
|
||||
assert agent.client is not None # __init__ builds it
|
||||
|
||||
agent.release_clients()
|
||||
|
||||
# Post-release: client reference is dropped (memory freed).
|
||||
assert agent.client is None
|
||||
|
||||
def test_close_vs_release_full_teardown_difference(self, monkeypatch):
|
||||
"""close() tears down task state; release_clients() does not.
|
||||
|
||||
This pins the semantic contract: session-expiry path uses close()
|
||||
(full teardown — session is done), cache-eviction path uses
|
||||
release_clients() (soft — session may resume).
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
from tools import terminal_tool as _tt
|
||||
|
||||
# Agent A: evicted from cache (soft) — terminal survives.
|
||||
# Agent B: session expired (hard) — terminal torn down.
|
||||
agent_a = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True,
|
||||
skip_context_files=True, skip_memory=True,
|
||||
session_id="soft-session",
|
||||
)
|
||||
agent_b = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True,
|
||||
skip_context_files=True, skip_memory=True,
|
||||
session_id="hard-session",
|
||||
)
|
||||
|
||||
vm_calls: list = []
|
||||
original_vm = _tt.cleanup_vm
|
||||
_tt.cleanup_vm = lambda tid: vm_calls.append(tid)
|
||||
try:
|
||||
agent_a.release_clients() # cache eviction
|
||||
agent_b.close() # session expiry
|
||||
finally:
|
||||
_tt.cleanup_vm = original_vm
|
||||
try:
|
||||
agent_a.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Only agent_b's task_id should appear in cleanup calls.
|
||||
assert "hard-session" in vm_calls
|
||||
assert "soft-session" not in vm_calls
|
||||
|
||||
def test_idle_evicted_session_rebuild_inherits_task_id(self, monkeypatch):
|
||||
"""After idle-TTL eviction, a fresh agent with the same session_id
|
||||
gets the same task_id — so tool state (terminal/browser/bg procs)
|
||||
that persisted across eviction is reachable via the new agent.
|
||||
"""
|
||||
from gateway import run as gw_run
|
||||
from run_agent import AIAgent
|
||||
|
||||
monkeypatch.setattr(gw_run, "_AGENT_CACHE_IDLE_TTL_SECS", 0.01)
|
||||
runner = self._runner()
|
||||
|
||||
# Build an agent representing a stale (idle) session.
|
||||
SESSION_ID = "long-lived-user-session"
|
||||
old = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True,
|
||||
skip_context_files=True, skip_memory=True,
|
||||
session_id=SESSION_ID,
|
||||
)
|
||||
old._last_activity_ts = 0.0 # force idle
|
||||
runner._agent_cache["sKey"] = (old, "sig")
|
||||
|
||||
# Simulate the idle-TTL sweep firing.
|
||||
runner._sweep_idle_cached_agents()
|
||||
assert "sKey" not in runner._agent_cache
|
||||
|
||||
# Wait for the daemon thread doing release_clients() to finish.
|
||||
import time as _t
|
||||
_t.sleep(0.3)
|
||||
|
||||
# Old agent's client is gone (soft cleanup fired).
|
||||
assert old.client is None
|
||||
|
||||
# User comes back — new agent built for the SAME session_id.
|
||||
new_agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True,
|
||||
skip_context_files=True, skip_memory=True,
|
||||
session_id=SESSION_ID,
|
||||
)
|
||||
|
||||
# Same session_id means same task_id routed to tools. The new
|
||||
# agent inherits any per-task state (terminal sandbox etc.) that
|
||||
# was preserved across eviction.
|
||||
assert new_agent.session_id == old.session_id == SESSION_ID
|
||||
# And it has a fresh working client.
|
||||
assert new_agent.client is not None
|
||||
|
||||
try:
|
||||
new_agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -71,51 +71,6 @@ class TestGetConnectedPlatforms:
|
||||
config = GatewayConfig()
|
||||
assert config.get_connected_platforms() == []
|
||||
|
||||
def test_dingtalk_recognised_via_extras(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DINGTALK: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"client_id": "cid", "client_secret": "sec"},
|
||||
),
|
||||
},
|
||||
)
|
||||
assert Platform.DINGTALK in config.get_connected_platforms()
|
||||
|
||||
def test_dingtalk_recognised_via_env_vars(self, monkeypatch):
|
||||
"""DingTalk configured via env vars (no extras) should still be
|
||||
recognised as connected — covers the case where _apply_env_overrides
|
||||
hasn't populated extras yet."""
|
||||
monkeypatch.setenv("DINGTALK_CLIENT_ID", "env_cid")
|
||||
monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "env_sec")
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DINGTALK: PlatformConfig(enabled=True, extra={}),
|
||||
},
|
||||
)
|
||||
assert Platform.DINGTALK in config.get_connected_platforms()
|
||||
|
||||
def test_dingtalk_missing_creds_not_connected(self, monkeypatch):
|
||||
monkeypatch.delenv("DINGTALK_CLIENT_ID", raising=False)
|
||||
monkeypatch.delenv("DINGTALK_CLIENT_SECRET", raising=False)
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DINGTALK: PlatformConfig(enabled=True, extra={}),
|
||||
},
|
||||
)
|
||||
assert Platform.DINGTALK not in config.get_connected_platforms()
|
||||
|
||||
def test_dingtalk_disabled_not_connected(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DINGTALK: PlatformConfig(
|
||||
enabled=False,
|
||||
extra={"client_id": "cid", "client_secret": "sec"},
|
||||
),
|
||||
},
|
||||
)
|
||||
assert Platform.DINGTALK not in config.get_connected_platforms()
|
||||
|
||||
|
||||
class TestSessionResetPolicy:
|
||||
def test_roundtrip(self):
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
@@ -231,29 +230,6 @@ class TestSend:
|
||||
|
||||
class TestConnect:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_closes_session_websocket(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
websocket = AsyncMock()
|
||||
blocker = asyncio.Event()
|
||||
|
||||
async def _run_forever():
|
||||
try:
|
||||
await blocker.wait()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
adapter._stream_client = SimpleNamespace(websocket=websocket)
|
||||
adapter._stream_task = asyncio.create_task(_run_forever())
|
||||
adapter._running = True
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
websocket.close.assert_awaited_once()
|
||||
assert adapter._stream_task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_fails_without_sdk(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
@@ -421,263 +397,3 @@ class TestExtractText:
|
||||
msg.rich_text = None
|
||||
assert DingTalkAdapter._extract_text(msg) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group gating — require_mention + allowed_users (parity with other platforms)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_gating_adapter(monkeypatch, *, extra=None, env=None):
|
||||
"""Build a DingTalkAdapter with only the gating fields populated.
|
||||
|
||||
Clears every DINGTALK_* gating env var before applying the caller's
|
||||
overrides so individual tests stay isolated.
|
||||
"""
|
||||
for key in (
|
||||
"DINGTALK_REQUIRE_MENTION",
|
||||
"DINGTALK_MENTION_PATTERNS",
|
||||
"DINGTALK_FREE_RESPONSE_CHATS",
|
||||
"DINGTALK_ALLOWED_USERS",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
for key, value in (env or {}).items():
|
||||
monkeypatch.setenv(key, value)
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
return DingTalkAdapter(PlatformConfig(enabled=True, extra=extra or {}))
|
||||
|
||||
|
||||
class TestAllowedUsersGate:
|
||||
|
||||
def test_empty_allowlist_allows_everyone(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(monkeypatch)
|
||||
assert adapter._is_user_allowed("anyone", "any-staff") is True
|
||||
|
||||
def test_wildcard_allowlist_allows_everyone(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(monkeypatch, extra={"allowed_users": ["*"]})
|
||||
assert adapter._is_user_allowed("anyone", "any-staff") is True
|
||||
|
||||
def test_matches_sender_id_case_insensitive(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, extra={"allowed_users": ["SenderABC"]}
|
||||
)
|
||||
assert adapter._is_user_allowed("senderabc", "") is True
|
||||
|
||||
def test_matches_staff_id(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, extra={"allowed_users": ["staff_1234"]}
|
||||
)
|
||||
assert adapter._is_user_allowed("", "staff_1234") is True
|
||||
|
||||
def test_rejects_unknown_user(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, extra={"allowed_users": ["staff_1234"]}
|
||||
)
|
||||
assert adapter._is_user_allowed("other-sender", "other-staff") is False
|
||||
|
||||
def test_env_var_csv_populates_allowlist(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, env={"DINGTALK_ALLOWED_USERS": "alice,bob,carol"}
|
||||
)
|
||||
assert adapter._is_user_allowed("alice", "") is True
|
||||
assert adapter._is_user_allowed("dave", "") is False
|
||||
|
||||
|
||||
class TestMentionPatterns:
|
||||
|
||||
def test_empty_patterns_list(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(monkeypatch)
|
||||
assert adapter._mention_patterns == []
|
||||
assert adapter._message_matches_mention_patterns("anything") is False
|
||||
|
||||
def test_pattern_matches_text(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, extra={"mention_patterns": ["^hermes"]}
|
||||
)
|
||||
assert adapter._message_matches_mention_patterns("hermes please help") is True
|
||||
assert adapter._message_matches_mention_patterns("please hermes help") is False
|
||||
|
||||
def test_pattern_is_case_insensitive(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, extra={"mention_patterns": ["^hermes"]}
|
||||
)
|
||||
assert adapter._message_matches_mention_patterns("HERMES help") is True
|
||||
|
||||
def test_invalid_regex_is_skipped_not_raised(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch,
|
||||
extra={"mention_patterns": ["[unclosed", "^valid"]},
|
||||
)
|
||||
# Invalid pattern dropped, valid one kept
|
||||
assert len(adapter._mention_patterns) == 1
|
||||
assert adapter._message_matches_mention_patterns("valid trigger") is True
|
||||
|
||||
def test_env_var_json_populates_patterns(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch,
|
||||
env={"DINGTALK_MENTION_PATTERNS": '["^bot", "^assistant"]'},
|
||||
)
|
||||
assert len(adapter._mention_patterns) == 2
|
||||
assert adapter._message_matches_mention_patterns("bot ping") is True
|
||||
|
||||
def test_env_var_newline_fallback_when_not_json(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch,
|
||||
env={"DINGTALK_MENTION_PATTERNS": "^bot\n^assistant"},
|
||||
)
|
||||
assert len(adapter._mention_patterns) == 2
|
||||
|
||||
|
||||
class TestShouldProcessMessage:
|
||||
|
||||
def test_dm_always_accepted(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, extra={"require_mention": True}
|
||||
)
|
||||
msg = MagicMock(is_in_at_list=False)
|
||||
assert adapter._should_process_message(msg, "hi", is_group=False, chat_id="dm1") is True
|
||||
|
||||
def test_group_rejected_when_require_mention_and_no_trigger(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, extra={"require_mention": True}
|
||||
)
|
||||
msg = MagicMock(is_in_at_list=False)
|
||||
assert adapter._should_process_message(msg, "hi", is_group=True, chat_id="grp1") is False
|
||||
|
||||
def test_group_accepted_when_require_mention_disabled(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, extra={"require_mention": False}
|
||||
)
|
||||
msg = MagicMock(is_in_at_list=False)
|
||||
assert adapter._should_process_message(msg, "hi", is_group=True, chat_id="grp1") is True
|
||||
|
||||
def test_group_accepted_when_bot_is_mentioned(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch, extra={"require_mention": True}
|
||||
)
|
||||
msg = MagicMock(is_in_at_list=True)
|
||||
assert adapter._should_process_message(msg, "hi", is_group=True, chat_id="grp1") is True
|
||||
|
||||
def test_group_accepted_when_text_matches_wake_word(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch,
|
||||
extra={"require_mention": True, "mention_patterns": ["^hermes"]},
|
||||
)
|
||||
msg = MagicMock(is_in_at_list=False)
|
||||
assert adapter._should_process_message(msg, "hermes help", is_group=True, chat_id="grp1") is True
|
||||
|
||||
def test_group_accepted_when_chat_in_free_response_list(self, monkeypatch):
|
||||
adapter = _make_gating_adapter(
|
||||
monkeypatch,
|
||||
extra={"require_mention": True, "free_response_chats": ["grp1"]},
|
||||
)
|
||||
msg = MagicMock(is_in_at_list=False)
|
||||
assert adapter._should_process_message(msg, "hi", is_group=True, chat_id="grp1") is True
|
||||
# Different group still blocked
|
||||
assert adapter._should_process_message(msg, "hi", is_group=True, chat_id="grp2") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _IncomingHandler.process — session_webhook extraction & fire-and-forget
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIncomingHandlerProcess:
|
||||
"""Verify that _IncomingHandler.process correctly converts callback data
|
||||
and dispatches message processing as a background task (fire-and-forget)
|
||||
so the SDK ACK is returned immediately."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_extracts_session_webhook(self):
|
||||
"""session_webhook must be populated from callback data."""
|
||||
from gateway.platforms.dingtalk import _IncomingHandler, DingTalkAdapter
|
||||
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._on_message = AsyncMock()
|
||||
handler = _IncomingHandler(adapter, asyncio.get_running_loop())
|
||||
|
||||
callback = MagicMock()
|
||||
callback.data = {
|
||||
"msgtype": "text",
|
||||
"text": {"content": "hello"},
|
||||
"senderId": "user1",
|
||||
"conversationId": "conv1",
|
||||
"sessionWebhook": "https://oapi.dingtalk.com/robot/sendBySession?session=abc",
|
||||
"msgId": "msg-001",
|
||||
}
|
||||
|
||||
result = await handler.process(callback)
|
||||
# Should return ACK immediately (STATUS_OK = 200)
|
||||
assert result[0] == 200
|
||||
|
||||
# Let the background task run
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# _on_message should have been called with a ChatbotMessage
|
||||
adapter._on_message.assert_called_once()
|
||||
chatbot_msg = adapter._on_message.call_args[0][0]
|
||||
assert chatbot_msg.session_webhook == "https://oapi.dingtalk.com/robot/sendBySession?session=abc"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_fallback_session_webhook_when_from_dict_misses_it(self):
|
||||
"""If ChatbotMessage.from_dict does not map sessionWebhook (e.g. SDK
|
||||
version mismatch), the handler should fall back to extracting it
|
||||
directly from the raw data dict."""
|
||||
from gateway.platforms.dingtalk import _IncomingHandler, DingTalkAdapter
|
||||
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._on_message = AsyncMock()
|
||||
handler = _IncomingHandler(adapter, asyncio.get_running_loop())
|
||||
|
||||
callback = MagicMock()
|
||||
# Use a key that from_dict might not recognise in some SDK versions
|
||||
callback.data = {
|
||||
"msgtype": "text",
|
||||
"text": {"content": "hi"},
|
||||
"senderId": "user2",
|
||||
"conversationId": "conv2",
|
||||
"session_webhook": "https://oapi.dingtalk.com/robot/sendBySession?session=def",
|
||||
"msgId": "msg-002",
|
||||
}
|
||||
|
||||
await handler.process(callback)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
adapter._on_message.assert_called_once()
|
||||
chatbot_msg = adapter._on_message.call_args[0][0]
|
||||
assert chatbot_msg.session_webhook == "https://oapi.dingtalk.com/robot/sendBySession?session=def"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_returns_ack_immediately(self):
|
||||
"""process() must not block on _on_message — it should return
|
||||
the ACK tuple before the message is fully processed."""
|
||||
from gateway.platforms.dingtalk import _IncomingHandler, DingTalkAdapter
|
||||
|
||||
processing_started = asyncio.Event()
|
||||
processing_gate = asyncio.Event()
|
||||
|
||||
async def slow_on_message(msg):
|
||||
processing_started.set()
|
||||
await processing_gate.wait() # Block until we release
|
||||
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._on_message = slow_on_message
|
||||
handler = _IncomingHandler(adapter, asyncio.get_running_loop())
|
||||
|
||||
callback = MagicMock()
|
||||
callback.data = {
|
||||
"msgtype": "text",
|
||||
"text": {"content": "test"},
|
||||
"senderId": "u",
|
||||
"conversationId": "c",
|
||||
"sessionWebhook": "https://oapi.dingtalk.com/x",
|
||||
"msgId": "m",
|
||||
}
|
||||
|
||||
# process() should return immediately even though _on_message blocks
|
||||
result = await handler.process(callback)
|
||||
assert result[0] == 200
|
||||
|
||||
# Clean up: release the gate so the background task finishes
|
||||
processing_gate.set()
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
"""Tests for the Discord ``allowed_mentions`` safe-default helper.
|
||||
|
||||
Ensures the bot defaults to blocking ``@everyone`` / ``@here`` / role pings
|
||||
so an LLM response (or echoed user content) can't spam a whole server —
|
||||
and that the four ``DISCORD_ALLOW_MENTION_*`` env vars correctly opt back
|
||||
in when an operator explicitly wants a different policy.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _FakeAllowedMentions:
|
||||
"""Stand-in for ``discord.AllowedMentions`` that exposes the same four
|
||||
boolean flags as real attributes so the test can assert on them.
|
||||
"""
|
||||
|
||||
def __init__(self, *, everyone=True, roles=True, users=True, replied_user=True):
|
||||
self.everyone = everyone
|
||||
self.roles = roles
|
||||
self.users = users
|
||||
self.replied_user = replied_user
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover - debug helper
|
||||
return (
|
||||
f"AllowedMentions(everyone={self.everyone}, roles={self.roles}, "
|
||||
f"users={self.users}, replied_user={self.replied_user})"
|
||||
)
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install (or augment) a mock ``discord`` module.
|
||||
|
||||
Other test modules in this directory stub ``discord`` via
|
||||
``sys.modules.setdefault`` — whichever test file imports first wins and
|
||||
our full module is then silently dropped. We therefore ALWAYS force
|
||||
``AllowedMentions`` onto whatever is currently in ``sys.modules["discord"]``;
|
||||
that's the only attribute this test file actually needs real behavior from.
|
||||
"""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
sys.modules["discord"].AllowedMentions = _FakeAllowedMentions
|
||||
return
|
||||
|
||||
if sys.modules.get("discord") is None:
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3, grey=4, secondary=5)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
discord_mod.opus = SimpleNamespace(is_loaded=lambda: True)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules["discord"] = discord_mod
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
# Whether we just installed the mock OR the mock was already installed
|
||||
# by another test's _ensure_discord_mock, force the AllowedMentions
|
||||
# stand-in onto it — _build_allowed_mentions() reads this attribute.
|
||||
sys.modules["discord"].AllowedMentions = _FakeAllowedMentions
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import _build_allowed_mentions # noqa: E402
|
||||
|
||||
|
||||
# The four DISCORD_ALLOW_MENTION_* env vars that _build_allowed_mentions reads.
|
||||
# Cleared before each test so env leakage from other tests never masks a regression.
|
||||
_ENV_VARS = (
|
||||
"DISCORD_ALLOW_MENTION_EVERYONE",
|
||||
"DISCORD_ALLOW_MENTION_ROLES",
|
||||
"DISCORD_ALLOW_MENTION_USERS",
|
||||
"DISCORD_ALLOW_MENTION_REPLIED_USER",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_allowed_mention_env(monkeypatch):
|
||||
for name in _ENV_VARS:
|
||||
monkeypatch.delenv(name, raising=False)
|
||||
|
||||
|
||||
def test_safe_defaults_block_everyone_and_roles():
|
||||
am = _build_allowed_mentions()
|
||||
assert am.everyone is False, "default must NOT allow @everyone/@here pings"
|
||||
assert am.roles is False, "default must NOT allow role pings"
|
||||
assert am.users is True, "default must allow user pings so replies work"
|
||||
assert am.replied_user is True, "default must allow reply-reference pings"
|
||||
|
||||
|
||||
def test_env_var_opts_back_into_everyone(monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_ALLOW_MENTION_EVERYONE", "true")
|
||||
am = _build_allowed_mentions()
|
||||
assert am.everyone is True
|
||||
# other defaults unaffected
|
||||
assert am.roles is False
|
||||
assert am.users is True
|
||||
assert am.replied_user is True
|
||||
|
||||
|
||||
def test_env_var_can_disable_users(monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_ALLOW_MENTION_USERS", "false")
|
||||
am = _build_allowed_mentions()
|
||||
assert am.users is False
|
||||
# safe defaults elsewhere remain
|
||||
assert am.everyone is False
|
||||
assert am.roles is False
|
||||
assert am.replied_user is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raw, expected", [
|
||||
("true", True), ("True", True), ("TRUE", True),
|
||||
("1", True), ("yes", True), ("YES", True), ("on", True),
|
||||
("false", False), ("False", False), ("0", False),
|
||||
("no", False), ("off", False),
|
||||
("", False), # empty falls back to default (False for everyone)
|
||||
("garbage", False), # unknown falls back to default
|
||||
(" true ", True), # whitespace tolerated
|
||||
])
|
||||
def test_everyone_boolean_parsing(monkeypatch, raw, expected):
|
||||
monkeypatch.setenv("DISCORD_ALLOW_MENTION_EVERYONE", raw)
|
||||
am = _build_allowed_mentions()
|
||||
assert am.everyone is expected
|
||||
|
||||
|
||||
def test_all_four_knobs_together(monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_ALLOW_MENTION_EVERYONE", "true")
|
||||
monkeypatch.setenv("DISCORD_ALLOW_MENTION_ROLES", "true")
|
||||
monkeypatch.setenv("DISCORD_ALLOW_MENTION_USERS", "false")
|
||||
monkeypatch.setenv("DISCORD_ALLOW_MENTION_REPLIED_USER", "false")
|
||||
am = _build_allowed_mentions()
|
||||
assert am.everyone is True
|
||||
assert am.roles is True
|
||||
assert am.users is False
|
||||
assert am.replied_user is False
|
||||
@@ -1,360 +0,0 @@
|
||||
"""Tests for Discord attachment downloads via the authenticated bot session.
|
||||
|
||||
Covers the three download paths (image / audio / document) in
|
||||
``DiscordAdapter._handle_message()`` and the shared ``_cache_discord_*``
|
||||
helpers. Verifies that:
|
||||
|
||||
- ``att.read()`` is preferred over the legacy URL-based downloaders so
|
||||
that Discord's CDN auth (and user-environment DNS quirks) can't block
|
||||
media caching. (issues #8242 image 403s, #6587 CDN SSRF false-positives)
|
||||
- Falls back cleanly to the SSRF-gated ``cache_*_from_url`` helpers
|
||||
(image/audio) or SSRF-gated aiohttp (documents) when ``att.read()``
|
||||
isn't available or fails.
|
||||
- The document fallback path now runs through the SSRF gate for
|
||||
defense-in-depth. (issue #11345)
|
||||
"""
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
# Minimal valid image / audio / PDF bytes so the cache_*_from_bytes
|
||||
# validators accept them. cache_image_from_bytes runs _looks_like_image()
|
||||
# which checks for magic bytes; PNG's magic is sufficient.
|
||||
_PNG_BYTES = b"\x89PNG\r\n\x1a\n" + b"\x00" * 64
|
||||
_OGG_BYTES = b"OggS" + b"\x00" * 60
|
||||
_PDF_BYTES = b"%PDF-1.4\n" + b"fake pdf body" + b"\n%%EOF"
|
||||
|
||||
|
||||
def _make_adapter() -> DiscordAdapter:
|
||||
return DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
|
||||
def _make_attachment_with_read(payload: bytes) -> SimpleNamespace:
|
||||
"""Attachment stub that exposes .read() — the happy-path primary."""
|
||||
return SimpleNamespace(
|
||||
url="https://cdn.discordapp.com/attachments/fake/file.png",
|
||||
filename="file.png",
|
||||
size=len(payload),
|
||||
read=AsyncMock(return_value=payload),
|
||||
)
|
||||
|
||||
|
||||
def _make_attachment_without_read() -> SimpleNamespace:
|
||||
"""Attachment stub that has no .read() — exercises the URL fallback."""
|
||||
return SimpleNamespace(
|
||||
url="https://cdn.discordapp.com/attachments/fake/file.png",
|
||||
filename="file.png",
|
||||
size=1024,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_attachment_bytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadAttachmentBytes:
|
||||
"""Unit tests for the low-level att.read() wrapper."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_bytes_on_successful_read(self):
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(b"hello world")
|
||||
|
||||
result = await adapter._read_attachment_bytes(att)
|
||||
|
||||
assert result == b"hello world"
|
||||
att.read.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_read_missing(self):
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read()
|
||||
|
||||
result = await adapter._read_attachment_bytes(att)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_read_raises(self):
|
||||
"""Bot-session fetch failures are swallowed so callers fall back."""
|
||||
adapter = _make_adapter()
|
||||
att = SimpleNamespace(
|
||||
url="https://cdn.discordapp.com/attachments/fake/file.png",
|
||||
filename="file.png",
|
||||
read=AsyncMock(side_effect=RuntimeError("403 Forbidden")),
|
||||
)
|
||||
|
||||
result = await adapter._read_attachment_bytes(att)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cache_discord_image
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheDiscordImage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefers_att_read_over_url(self):
|
||||
"""Primary path: att.read() bytes → cache_image_from_bytes, no URL fetch."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(_PNG_BYTES)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
return_value="/tmp/cached.png",
|
||||
) as mock_bytes, patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_image(att, ".png")
|
||||
|
||||
assert result == "/tmp/cached.png"
|
||||
mock_bytes.assert_called_once_with(_PNG_BYTES, ext=".png")
|
||||
mock_url.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_url_when_no_read(self):
|
||||
"""No .read() → URL path is used (existing SSRF-gated behavior)."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
) as mock_bytes, patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/from_url.png",
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_image(att, ".png")
|
||||
|
||||
assert result == "/tmp/from_url.png"
|
||||
mock_bytes.assert_not_called()
|
||||
mock_url.assert_awaited_once_with(att.url, ext=".png")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_url_when_bytes_validator_rejects(self):
|
||||
"""If att.read() returns garbage that cache_image_from_bytes rejects
|
||||
(e.g. an HTML error page), fall back to the URL downloader instead
|
||||
of surfacing the validation error to the caller."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(b"<html>forbidden</html>")
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
side_effect=ValueError("not a valid image"),
|
||||
), patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/fallback.png",
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_image(att, ".png")
|
||||
|
||||
assert result == "/tmp/fallback.png"
|
||||
mock_url.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cache_discord_audio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheDiscordAudio:
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefers_att_read_over_url(self):
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(_OGG_BYTES)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_audio_from_bytes",
|
||||
return_value="/tmp/voice.ogg",
|
||||
) as mock_bytes, patch(
|
||||
"gateway.platforms.discord.cache_audio_from_url",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_audio(att, ".ogg")
|
||||
|
||||
assert result == "/tmp/voice.ogg"
|
||||
mock_bytes.assert_called_once_with(_OGG_BYTES, ext=".ogg")
|
||||
mock_url.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_url_when_no_read(self):
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_audio_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/from_url.ogg",
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_audio(att, ".ogg")
|
||||
|
||||
assert result == "/tmp/from_url.ogg"
|
||||
mock_url.assert_awaited_once_with(att.url, ext=".ogg")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cache_discord_document
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheDiscordDocument:
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefers_att_read_returns_bytes_directly(self):
|
||||
"""Primary path: att.read() → raw bytes, no aiohttp involvement."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(_PDF_BYTES)
|
||||
|
||||
with patch("aiohttp.ClientSession") as mock_session:
|
||||
result = await adapter._cache_discord_document(att, ".pdf")
|
||||
|
||||
assert result == _PDF_BYTES
|
||||
mock_session.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_blocked_by_ssrf_guard(self):
|
||||
"""Document fallback path now honors is_safe_url — was missing before.
|
||||
|
||||
Regression guard for #11345: the old aiohttp block skipped the
|
||||
SSRF check entirely; a non-CDN ``att.url`` could have reached
|
||||
internal-looking hosts. The fallback must now refuse unsafe URLs.
|
||||
"""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read() # no .read → forces fallback
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.is_safe_url", return_value=False
|
||||
) as mock_safe, patch("aiohttp.ClientSession") as mock_session:
|
||||
with pytest.raises(ValueError, match="SSRF"):
|
||||
await adapter._cache_discord_document(att, ".pdf")
|
||||
|
||||
mock_safe.assert_called_once_with(att.url)
|
||||
# aiohttp must NOT be contacted when the URL is blocked.
|
||||
mock_session.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_aiohttp_when_safe_url(self):
|
||||
"""Safe URL + no att.read() → aiohttp fallback executes."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read()
|
||||
|
||||
# Build an aiohttp session mock that returns 200 + payload.
|
||||
resp = AsyncMock()
|
||||
resp.status = 200
|
||||
resp.read = AsyncMock(return_value=_PDF_BYTES)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = AsyncMock()
|
||||
session.get = MagicMock(return_value=resp)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.is_safe_url", return_value=True
|
||||
), patch("aiohttp.ClientSession", return_value=session):
|
||||
result = await adapter._cache_discord_document(att, ".pdf")
|
||||
|
||||
assert result == _PDF_BYTES
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: end-to-end via _handle_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHandleMessageUsesAuthenticatedRead:
|
||||
"""E2E: verify _handle_message routes image/audio downloads through
|
||||
att.read() so cdn.discordapp.com 403s (#8242) and SSRF false-positives
|
||||
on mangled DNS (#6587) no longer block media caching.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_downloads_via_att_read_not_url(self, monkeypatch):
|
||||
"""Image attachments with .read() never call cache_image_from_url."""
|
||||
adapter = _make_adapter()
|
||||
adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
return_value="/tmp/img_from_read.png",
|
||||
), patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_url_download:
|
||||
att = SimpleNamespace(
|
||||
url="https://cdn.discordapp.com/attachments/fake/file.png",
|
||||
filename="file.png",
|
||||
content_type="image/png",
|
||||
size=len(_PNG_BYTES),
|
||||
read=AsyncMock(return_value=_PNG_BYTES),
|
||||
)
|
||||
# Minimal Discord message stub for _handle_message.
|
||||
from datetime import datetime, timezone
|
||||
|
||||
class _FakeDMChannel:
|
||||
id = 100
|
||||
name = "dm"
|
||||
|
||||
# Patch the DMChannel isinstance check so our fake counts as DM.
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.discord.discord.DMChannel",
|
||||
_FakeDMChannel,
|
||||
)
|
||||
chan = _FakeDMChannel()
|
||||
msg = SimpleNamespace(
|
||||
id=1, content="", attachments=[att], mentions=[],
|
||||
reference=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
channel=chan,
|
||||
author=SimpleNamespace(id=42, display_name="U", name="U"),
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
mock_url_download.assert_not_called()
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == ["/tmp/img_from_read.png"]
|
||||
assert event.media_types == ["image/png"]
|
||||
@@ -1,226 +0,0 @@
|
||||
"""Regression guard for #4466: DISCORD_ALLOW_BOTS works without DISCORD_ALLOWED_USERS.
|
||||
|
||||
The bug had two sequential gates both rejecting bot messages:
|
||||
|
||||
Gate 1 — `on_message` in gateway/platforms/discord.py ran the user-allowlist
|
||||
check BEFORE the bot filter, so bot senders were dropped with a warning
|
||||
before the DISCORD_ALLOW_BOTS policy was ever evaluated.
|
||||
|
||||
Gate 2 — `_is_user_authorized` in gateway/run.py rejected bots at the
|
||||
gateway level even if they somehow reached that layer.
|
||||
|
||||
These tests assert both gates now pass a bot message through when
|
||||
DISCORD_ALLOW_BOTS permits it AND no user allowlist entry exists.
|
||||
"""
|
||||
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.session import Platform, SessionSource
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_discord_env(monkeypatch):
|
||||
"""Make every test start with a clean Discord env so prior tests in the
|
||||
session (or CI setups) can't leak DISCORD_ALLOWED_ROLES / DISCORD_ALLOWED_USERS
|
||||
/ DISCORD_ALLOW_BOTS and silently flip the auth result.
|
||||
"""
|
||||
for var in (
|
||||
"DISCORD_ALLOW_BOTS",
|
||||
"DISCORD_ALLOWED_USERS",
|
||||
"DISCORD_ALLOWED_ROLES",
|
||||
"DISCORD_ALLOW_ALL_USERS",
|
||||
"GATEWAY_ALLOW_ALL_USERS",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Gate 2: _is_user_authorized bypasses allowlist for permitted bots
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bare_runner():
|
||||
"""Build a GatewayRunner skeleton with just enough wiring for the auth test.
|
||||
|
||||
Uses ``object.__new__`` to skip the heavy __init__ — many gateway tests
|
||||
use this pattern (see AGENTS.md pitfall #17).
|
||||
"""
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
# _is_user_authorized reads self.pairing_store.is_approved(...) before
|
||||
# any allowlist check succeeds; stub it to never approve so we exercise
|
||||
# the real allowlist path.
|
||||
runner.pairing_store = SimpleNamespace(is_approved=lambda *_a, **_kw: False)
|
||||
return runner
|
||||
|
||||
|
||||
def _make_discord_bot_source(bot_id: str = "999888777"):
|
||||
return SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="123",
|
||||
chat_type="channel",
|
||||
user_id=bot_id,
|
||||
user_name="SomeBot",
|
||||
is_bot=True,
|
||||
)
|
||||
|
||||
|
||||
def _make_discord_human_source(user_id: str = "100200300"):
|
||||
return SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="123",
|
||||
chat_type="channel",
|
||||
user_id=user_id,
|
||||
user_name="SomeHuman",
|
||||
is_bot=False,
|
||||
)
|
||||
|
||||
|
||||
def test_discord_bot_authorized_when_allow_bots_mentions(monkeypatch):
|
||||
"""DISCORD_ALLOW_BOTS=mentions must authorize a bot sender even when
|
||||
DISCORD_ALLOWED_USERS is set and the bot's ID is NOT in it.
|
||||
|
||||
This is the exact scenario from #4466 — a Cloudflare Worker webhook
|
||||
posts Notion events to Discord, the Hermes bot gets @mentioned, and
|
||||
the webhook's bot ID is not (and shouldn't be) on the human
|
||||
allowlist.
|
||||
"""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOW_BOTS", "mentions")
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_USERS", "100200300") # human-only allowlist
|
||||
|
||||
source = _make_discord_bot_source(bot_id="999888777")
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_discord_bot_authorized_when_allow_bots_all(monkeypatch):
|
||||
"""DISCORD_ALLOW_BOTS=all is a superset of =mentions — should also bypass."""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOW_BOTS", "all")
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_USERS", "100200300")
|
||||
|
||||
source = _make_discord_bot_source()
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_discord_bot_NOT_authorized_when_allow_bots_none(monkeypatch):
|
||||
"""DISCORD_ALLOW_BOTS=none (default) must still reject bots that aren't
|
||||
in DISCORD_ALLOWED_USERS — preserves the original security behavior.
|
||||
"""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOW_BOTS", "none")
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_USERS", "100200300")
|
||||
|
||||
source = _make_discord_bot_source(bot_id="999888777")
|
||||
assert runner._is_user_authorized(source) is False
|
||||
|
||||
|
||||
def test_discord_bot_NOT_authorized_when_allow_bots_unset(monkeypatch):
|
||||
"""Unset DISCORD_ALLOW_BOTS must behave like 'none'."""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.delenv("DISCORD_ALLOW_BOTS", raising=False)
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_USERS", "100200300")
|
||||
|
||||
source = _make_discord_bot_source(bot_id="999888777")
|
||||
assert runner._is_user_authorized(source) is False
|
||||
|
||||
|
||||
def test_discord_human_still_checked_against_allowlist_when_bot_policy_set(monkeypatch):
|
||||
"""DISCORD_ALLOW_BOTS=all must NOT open the gate for humans — they
|
||||
still need to be in DISCORD_ALLOWED_USERS (or a pairing approval).
|
||||
"""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOW_BOTS", "all")
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_USERS", "100200300")
|
||||
|
||||
# Human NOT on the allowlist → must be rejected.
|
||||
source = _make_discord_human_source(user_id="999999999")
|
||||
assert runner._is_user_authorized(source) is False
|
||||
|
||||
# Human ON the allowlist → accepted.
|
||||
source_allowed = _make_discord_human_source(user_id="100200300")
|
||||
assert runner._is_user_authorized(source_allowed) is True
|
||||
|
||||
|
||||
def test_bot_bypass_does_not_leak_to_other_platforms(monkeypatch):
|
||||
"""The is_bot bypass is Discord-specific — a Telegram bot source with
|
||||
is_bot=True must NOT be authorized just because DISCORD_ALLOW_BOTS=all.
|
||||
"""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOW_BOTS", "all")
|
||||
monkeypatch.setenv("TELEGRAM_ALLOWED_USERS", "100200300")
|
||||
|
||||
telegram_bot = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="channel",
|
||||
user_id="999888777",
|
||||
is_bot=True,
|
||||
)
|
||||
assert runner._is_user_authorized(telegram_bot) is False
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DISCORD_ALLOWED_ROLES gateway-layer bypass (#7871)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_discord_role_config_bypasses_gateway_allowlist(monkeypatch):
|
||||
"""When DISCORD_ALLOWED_ROLES is set, _is_user_authorized must trust
|
||||
the adapter's pre-filter and authorize. Without this, role-only setups
|
||||
(DISCORD_ALLOWED_ROLES populated, DISCORD_ALLOWED_USERS empty) would
|
||||
hit the 'no allowlists configured' branch and get rejected.
|
||||
"""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_ROLES", "1493705176387948674")
|
||||
# Note: DISCORD_ALLOWED_USERS is NOT set — the entire point.
|
||||
|
||||
source = _make_discord_human_source(user_id="999888777")
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_discord_role_config_still_authorizes_alongside_users(monkeypatch):
|
||||
"""Sanity: setting both DISCORD_ALLOWED_ROLES and DISCORD_ALLOWED_USERS
|
||||
doesn't break the user-id path. Users in the allowlist should still be
|
||||
authorized even if they don't have a role. (OR semantics.)
|
||||
"""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_ROLES", "1493705176387948674")
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_USERS", "100200300")
|
||||
|
||||
# User on the user allowlist, no role → still authorized at gateway
|
||||
# level via the role bypass (adapter already approved them).
|
||||
source = _make_discord_human_source(user_id="100200300")
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_discord_role_bypass_does_not_leak_to_other_platforms(monkeypatch):
|
||||
"""DISCORD_ALLOWED_ROLES must only affect Discord. Setting it should
|
||||
not suddenly start authorizing Telegram users whose platform has its
|
||||
own empty allowlist.
|
||||
"""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_ROLES", "1493705176387948674")
|
||||
# Telegram has its own empty allowlist and no allow-all flag.
|
||||
|
||||
telegram_user = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="channel",
|
||||
user_id="999888777",
|
||||
)
|
||||
assert runner._is_user_authorized(telegram_user) is False
|
||||
@@ -8,60 +8,37 @@ import pytest
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
class _FakeAllowedMentions:
|
||||
"""Stand-in for ``discord.AllowedMentions`` — exposes the same four
|
||||
boolean flags as real attributes so tests can assert on safe defaults.
|
||||
"""
|
||||
|
||||
def __init__(self, *, everyone=True, roles=True, users=True, replied_user=True):
|
||||
self.everyone = everyone
|
||||
self.roles = roles
|
||||
self.users = users
|
||||
self.replied_user = replied_user
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install (or augment) a mock ``discord`` module.
|
||||
|
||||
Always force ``AllowedMentions`` onto whatever is in ``sys.modules`` —
|
||||
other test files also stub the module via ``setdefault``, and we need
|
||||
``_build_allowed_mentions()``'s return value to have real attribute
|
||||
access regardless of which file loaded first.
|
||||
"""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
sys.modules["discord"].AllowedMentions = _FakeAllowedMentions
|
||||
return
|
||||
|
||||
if sys.modules.get("discord") is None:
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3, grey=4, secondary=5)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
discord_mod.opus = SimpleNamespace(is_loaded=lambda: True)
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3, grey=4, secondary=5)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
discord_mod.opus = SimpleNamespace(is_loaded=lambda: True)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules["discord"] = discord_mod
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
sys.modules["discord"].AllowedMentions = _FakeAllowedMentions
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
@@ -79,9 +56,8 @@ class FakeTree:
|
||||
|
||||
|
||||
class FakeBot:
|
||||
def __init__(self, *, intents, proxy=None, allowed_mentions=None, **_):
|
||||
def __init__(self, *, intents, proxy=None):
|
||||
self.intents = intents
|
||||
self.allowed_mentions = allowed_mentions
|
||||
self.user = SimpleNamespace(id=999, name="Hermes")
|
||||
self._events = {}
|
||||
self.tree = FakeTree()
|
||||
@@ -139,8 +115,8 @@ async def test_connect_only_requests_members_intent_when_needed(monkeypatch, all
|
||||
|
||||
created = {}
|
||||
|
||||
def fake_bot_factory(*, command_prefix, intents, proxy=None, allowed_mentions=None, **_):
|
||||
created["bot"] = FakeBot(intents=intents, allowed_mentions=allowed_mentions)
|
||||
def fake_bot_factory(*, command_prefix, intents, proxy=None):
|
||||
created["bot"] = FakeBot(intents=intents)
|
||||
return created["bot"]
|
||||
|
||||
monkeypatch.setattr(discord_platform.commands, "Bot", fake_bot_factory)
|
||||
@@ -150,13 +126,6 @@ async def test_connect_only_requests_members_intent_when_needed(monkeypatch, all
|
||||
|
||||
assert ok is True
|
||||
assert created["bot"].intents.members is expected_members_intent
|
||||
# Safe-default AllowedMentions must be applied on every connect so the
|
||||
# bot cannot @everyone from LLM output. Granular overrides live in the
|
||||
# dedicated test_discord_allowed_mentions.py module.
|
||||
am = created["bot"].allowed_mentions
|
||||
assert am is not None, "connect() must pass an AllowedMentions to commands.Bot"
|
||||
assert am.everyone is False
|
||||
assert am.roles is False
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
@@ -175,11 +144,7 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
discord_platform.commands,
|
||||
"Bot",
|
||||
lambda **kwargs: FakeBot(
|
||||
intents=kwargs["intents"],
|
||||
proxy=kwargs.get("proxy"),
|
||||
allowed_mentions=kwargs.get("allowed_mentions"),
|
||||
),
|
||||
lambda **kwargs: FakeBot(intents=kwargs["intents"], proxy=kwargs.get("proxy")),
|
||||
)
|
||||
|
||||
async def fake_wait_for(awaitable, timeout):
|
||||
@@ -207,7 +172,7 @@ async def test_connect_does_not_wait_for_slash_sync(monkeypatch):
|
||||
|
||||
created = {}
|
||||
|
||||
def fake_bot_factory(*, command_prefix, intents, proxy=None, allowed_mentions=None, **_):
|
||||
def fake_bot_factory(*, command_prefix, intents, proxy=None):
|
||||
bot = SlowSyncBot(intents=intents, proxy=proxy)
|
||||
created["bot"] = bot
|
||||
return bot
|
||||
|
||||
@@ -96,7 +96,7 @@ def adapter(monkeypatch):
|
||||
return adapter
|
||||
|
||||
|
||||
def make_message(*, channel, content: str, mentions=None, msg_type=None):
|
||||
def make_message(*, channel, content: str, mentions=None):
|
||||
author = SimpleNamespace(id=42, display_name="Jezza", name="Jezza")
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
@@ -107,7 +107,6 @@ def make_message(*, channel, content: str, mentions=None, msg_type=None):
|
||||
created_at=datetime.now(timezone.utc),
|
||||
channel=channel,
|
||||
author=author,
|
||||
type=msg_type if msg_type is not None else discord_platform.discord.MessageType.default,
|
||||
)
|
||||
|
||||
|
||||
@@ -205,21 +204,6 @@ async def test_discord_free_response_channel_overrides_mention_requirement(adapt
|
||||
assert event.text == "allowed without mention"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_free_response_channel_can_come_from_config_extra(adapter, monkeypatch):
|
||||
monkeypatch.delenv("DISCORD_REQUIRE_MENTION", raising=False)
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
adapter.config.extra["free_response_channels"] = ["789", "999"]
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=789), content="allowed from config")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "allowed from config"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_forum_parent_in_free_response_list_allows_forum_thread(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
@@ -292,31 +276,6 @@ async def test_discord_auto_thread_enabled_by_default(adapter, monkeypatch):
|
||||
assert event.source.thread_id == "999"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_reply_message_skips_auto_thread(adapter, monkeypatch):
|
||||
"""Quote-replies should stay in-channel instead of trying to create a thread."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "123")
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
message = make_message(
|
||||
channel=FakeTextChannel(channel_id=123),
|
||||
content="reply without mention",
|
||||
msg_type=discord_platform.discord.MessageType.reply,
|
||||
)
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "reply without mention"
|
||||
assert event.source.chat_id == "123"
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_can_be_disabled(adapter, monkeypatch):
|
||||
"""Setting auto_thread to false skips thread creation."""
|
||||
@@ -426,33 +385,6 @@ async def test_discord_voice_linked_channel_skips_mention_requirement_and_auto_t
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_free_channel_skips_auto_thread(adapter, monkeypatch):
|
||||
"""Free-response channels must NOT auto-create threads — bot replies inline.
|
||||
|
||||
Without this, every message in a free-response channel would spin off a
|
||||
thread (since the channel bypasses the @mention gate), defeating the
|
||||
lightweight-chat purpose of free-response mode.
|
||||
"""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "789")
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False) # default true
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
message = make_message(
|
||||
channel=FakeTextChannel(channel_id=789),
|
||||
content="free chat message",
|
||||
)
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_voice_linked_parent_thread_still_requires_mention(adapter, monkeypatch):
|
||||
"""Threads under a voice-linked channel should still require @mention."""
|
||||
|
||||
@@ -105,14 +105,9 @@ def _make_discord_adapter(reply_to_mode: str = "first"):
|
||||
config = PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
|
||||
adapter = DiscordAdapter(config)
|
||||
|
||||
# Mock the Discord client and channel.
|
||||
# ref_message.to_reference() → a distinct sentinel: the adapter now wraps
|
||||
# the fetched Message via to_reference(fail_if_not_exists=False) so a
|
||||
# deleted target degrades to "send without reply chip" instead of a 400.
|
||||
# Mock the Discord client and channel
|
||||
mock_channel = AsyncMock()
|
||||
ref_message = MagicMock()
|
||||
ref_reference = MagicMock(name="MessageReference")
|
||||
ref_message.to_reference = MagicMock(return_value=ref_reference)
|
||||
mock_channel.fetch_message = AsyncMock(return_value=ref_message)
|
||||
|
||||
sent_msg = MagicMock()
|
||||
@@ -123,9 +118,7 @@ def _make_discord_adapter(reply_to_mode: str = "first"):
|
||||
mock_client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
||||
adapter._client = mock_client
|
||||
# Return the reference sentinel alongside so tests can assert identity.
|
||||
adapter._test_expected_reference = ref_reference
|
||||
return adapter, mock_channel, ref_reference
|
||||
return adapter, mock_channel, ref_message
|
||||
|
||||
|
||||
class TestSendWithReplyToMode:
|
||||
|
||||
@@ -48,8 +48,7 @@ from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
async def test_send_retries_without_reference_when_reply_target_is_system_message():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
reference_obj = object()
|
||||
ref_msg = SimpleNamespace(id=99, to_reference=MagicMock(return_value=reference_obj))
|
||||
ref_msg = SimpleNamespace(id=99)
|
||||
sent_msg = SimpleNamespace(id=1234)
|
||||
send_calls = []
|
||||
|
||||
@@ -77,83 +76,5 @@ async def test_send_retries_without_reference_when_reply_target_is_system_messag
|
||||
assert result.message_id == "1234"
|
||||
assert channel.fetch_message.await_count == 1
|
||||
assert channel.send.await_count == 2
|
||||
ref_msg.to_reference.assert_called_once_with(fail_if_not_exists=False)
|
||||
assert send_calls[0]["reference"] is reference_obj
|
||||
assert send_calls[0]["reference"] is ref_msg
|
||||
assert send_calls[1]["reference"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_without_reference_when_reply_target_is_deleted():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
reference_obj = object()
|
||||
ref_msg = SimpleNamespace(id=99, to_reference=MagicMock(return_value=reference_obj))
|
||||
sent_msgs = [SimpleNamespace(id=1001), SimpleNamespace(id=1002)]
|
||||
send_calls = []
|
||||
|
||||
async def fake_send(*, content, reference=None):
|
||||
send_calls.append({"content": content, "reference": reference})
|
||||
if len(send_calls) == 1:
|
||||
raise RuntimeError(
|
||||
"400 Bad Request (error code: 10008): Unknown Message"
|
||||
)
|
||||
return sent_msgs[len(send_calls) - 2]
|
||||
|
||||
channel = SimpleNamespace(
|
||||
fetch_message=AsyncMock(return_value=ref_msg),
|
||||
send=AsyncMock(side_effect=fake_send),
|
||||
)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
long_text = "A" * (adapter.MAX_MESSAGE_LENGTH + 50)
|
||||
result = await adapter.send("555", long_text, reply_to="99")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "1001"
|
||||
assert channel.fetch_message.await_count == 1
|
||||
assert channel.send.await_count == 3
|
||||
ref_msg.to_reference.assert_called_once_with(fail_if_not_exists=False)
|
||||
assert send_calls[0]["reference"] is reference_obj
|
||||
assert send_calls[1]["reference"] is None
|
||||
assert send_calls[2]["reference"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_does_not_retry_on_unrelated_errors():
|
||||
"""Regression guard: errors unrelated to the reply reference (e.g. 50013
|
||||
Missing Permissions) must NOT trigger the no-reference retry path — they
|
||||
should propagate out of the per-chunk loop and surface as a failed
|
||||
SendResult so the caller sees the real problem instead of a silent retry.
|
||||
"""
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
reference_obj = object()
|
||||
ref_msg = SimpleNamespace(id=99, to_reference=MagicMock(return_value=reference_obj))
|
||||
send_calls = []
|
||||
|
||||
async def fake_send(*, content, reference=None):
|
||||
send_calls.append({"content": content, "reference": reference})
|
||||
raise RuntimeError(
|
||||
"403 Forbidden (error code: 50013): Missing Permissions"
|
||||
)
|
||||
|
||||
channel = SimpleNamespace(
|
||||
fetch_message=AsyncMock(return_value=ref_msg),
|
||||
send=AsyncMock(side_effect=fake_send),
|
||||
)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
result = await adapter.send("555", "hello", reply_to="99")
|
||||
|
||||
# Outer except in adapter.send() wraps propagated errors as SendResult.
|
||||
assert result.success is False
|
||||
assert "50013" in (result.error or "")
|
||||
# Only the first attempt happens — no reference-retry replay.
|
||||
assert channel.send.await_count == 1
|
||||
assert send_calls[0]["reference"] is reference_obj
|
||||
|
||||
@@ -11,66 +11,52 @@ from gateway.config import PlatformConfig
|
||||
|
||||
def _ensure_discord_mock():
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
# Real discord is installed — nothing to do.
|
||||
return
|
||||
|
||||
if sys.modules.get("discord") is None:
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.Interaction = object
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.Interaction = object
|
||||
|
||||
# Lightweight mock for app_commands.Group and Command used by
|
||||
# _register_skill_group.
|
||||
class _FakeGroup:
|
||||
def __init__(self, *, name, description, parent=None):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.parent = parent
|
||||
self._children: dict[str, object] = {}
|
||||
if parent is not None:
|
||||
parent.add_command(self)
|
||||
# Lightweight mock for app_commands.Group and Command used by
|
||||
# _register_skill_group.
|
||||
class _FakeGroup:
|
||||
def __init__(self, *, name, description, parent=None):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.parent = parent
|
||||
self._children: dict[str, object] = {}
|
||||
if parent is not None:
|
||||
parent.add_command(self)
|
||||
|
||||
def add_command(self, cmd):
|
||||
self._children[cmd.name] = cmd
|
||||
def add_command(self, cmd):
|
||||
self._children[cmd.name] = cmd
|
||||
|
||||
class _FakeCommand:
|
||||
def __init__(self, *, name, description, callback, parent=None):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.callback = callback
|
||||
self.parent = parent
|
||||
class _FakeCommand:
|
||||
def __init__(self, *, name, description, callback, parent=None):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.callback = callback
|
||||
self.parent = parent
|
||||
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
autocomplete=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
Group=_FakeGroup,
|
||||
Command=_FakeCommand,
|
||||
)
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
Group=_FakeGroup,
|
||||
Command=_FakeCommand,
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules["discord"] = discord_mod
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
# Whether we just installed the mock OR another test module installed
|
||||
# it first via its own _ensure_discord_mock, force the decorators we
|
||||
# need onto discord.app_commands — the flat /skill command uses
|
||||
# @app_commands.autocomplete and not every other mock stub exposes it.
|
||||
_app = getattr(sys.modules["discord"], "app_commands", None)
|
||||
if _app is not None and not hasattr(_app, "autocomplete"):
|
||||
try:
|
||||
_app.autocomplete = lambda **kwargs: (lambda fn: fn)
|
||||
except Exception:
|
||||
pass
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
@@ -401,8 +387,6 @@ async def test_auto_create_thread_uses_message_content_as_name(adapter):
|
||||
message = SimpleNamespace(
|
||||
content="Hello world, how are you?",
|
||||
create_thread=AsyncMock(return_value=thread),
|
||||
channel=SimpleNamespace(send=AsyncMock()),
|
||||
author=SimpleNamespace(display_name="Jezza"),
|
||||
)
|
||||
|
||||
result = await adapter._auto_create_thread(message)
|
||||
@@ -414,48 +398,6 @@ async def test_auto_create_thread_uses_message_content_as_name(adapter):
|
||||
assert call_kwargs["auto_archive_duration"] == 1440
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_create_thread_strips_mention_syntax_from_name(adapter):
|
||||
"""Thread names must not contain raw <@id>, <@&id>, or <#id> markers.
|
||||
|
||||
Regression guard for #6336 — previously a message like
|
||||
``<@&1490963422786093149> help`` would spawn a thread literally
|
||||
named ``<@&1490963422786093149> help``.
|
||||
"""
|
||||
thread = SimpleNamespace(id=999, name="help")
|
||||
message = SimpleNamespace(
|
||||
content="<@&1490963422786093149> <@555> please help <#123>",
|
||||
create_thread=AsyncMock(return_value=thread),
|
||||
channel=SimpleNamespace(send=AsyncMock()),
|
||||
author=SimpleNamespace(display_name="Jezza"),
|
||||
)
|
||||
|
||||
await adapter._auto_create_thread(message)
|
||||
|
||||
name = message.create_thread.await_args[1]["name"]
|
||||
assert "<@" not in name, f"role/user mention leaked: {name!r}"
|
||||
assert "<#" not in name, f"channel mention leaked: {name!r}"
|
||||
assert name == "please help"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_create_thread_falls_back_to_hermes_when_only_mentions(adapter):
|
||||
"""If a message contains only mention syntax, the stripped content is
|
||||
empty — fall back to the 'Hermes' default rather than ''."""
|
||||
thread = SimpleNamespace(id=999, name="Hermes")
|
||||
message = SimpleNamespace(
|
||||
content="<@&1490963422786093149>",
|
||||
create_thread=AsyncMock(return_value=thread),
|
||||
channel=SimpleNamespace(send=AsyncMock()),
|
||||
author=SimpleNamespace(display_name="Jezza"),
|
||||
)
|
||||
|
||||
await adapter._auto_create_thread(message)
|
||||
|
||||
name = message.create_thread.await_args[1]["name"]
|
||||
assert name == "Hermes"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_create_thread_truncates_long_names(adapter):
|
||||
long_text = "a" * 200
|
||||
@@ -463,8 +405,6 @@ async def test_auto_create_thread_truncates_long_names(adapter):
|
||||
message = SimpleNamespace(
|
||||
content=long_text,
|
||||
create_thread=AsyncMock(return_value=thread),
|
||||
channel=SimpleNamespace(send=AsyncMock()),
|
||||
author=SimpleNamespace(display_name="Jezza"),
|
||||
)
|
||||
|
||||
result = await adapter._auto_create_thread(message)
|
||||
@@ -476,33 +416,10 @@ async def test_auto_create_thread_truncates_long_names(adapter):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_create_thread_falls_back_to_seed_message(adapter):
|
||||
thread = SimpleNamespace(id=555, name="Hello")
|
||||
seed_message = SimpleNamespace(create_thread=AsyncMock(return_value=thread))
|
||||
async def test_auto_create_thread_returns_none_on_failure(adapter):
|
||||
message = SimpleNamespace(
|
||||
content="Hello",
|
||||
create_thread=AsyncMock(side_effect=RuntimeError("no perms")),
|
||||
channel=SimpleNamespace(send=AsyncMock(return_value=seed_message)),
|
||||
author=SimpleNamespace(display_name="Jezza"),
|
||||
)
|
||||
|
||||
result = await adapter._auto_create_thread(message)
|
||||
assert result is thread
|
||||
message.channel.send.assert_awaited_once_with("🧵 Thread created by Hermes: **Hello**")
|
||||
seed_message.create_thread.assert_awaited_once_with(
|
||||
name="Hello",
|
||||
auto_archive_duration=1440,
|
||||
reason="Auto-threaded from mention by Jezza",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_create_thread_returns_none_when_direct_and_fallback_fail(adapter):
|
||||
message = SimpleNamespace(
|
||||
content="Hello",
|
||||
create_thread=AsyncMock(side_effect=RuntimeError("no perms")),
|
||||
channel=SimpleNamespace(send=AsyncMock(side_effect=RuntimeError("send failed"))),
|
||||
author=SimpleNamespace(display_name="Jezza"),
|
||||
)
|
||||
|
||||
result = await adapter._auto_create_thread(message)
|
||||
@@ -682,19 +599,12 @@ def test_discord_auto_thread_config_bridge(monkeypatch, tmp_path):
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /skill command registration (flat + autocomplete)
|
||||
# /skill group registration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_register_skill_command_is_flat_not_nested(adapter):
|
||||
"""_register_skill_group should register a single flat ``/skill`` command.
|
||||
|
||||
The older layout nested categories as subcommand groups under ``/skill``.
|
||||
That registered as one giant command whose serialized payload exceeded
|
||||
Discord's 8KB per-command limit with the default skill catalog. The
|
||||
flat layout sidesteps the limit — autocomplete options are fetched
|
||||
dynamically by Discord and don't count against the registration budget.
|
||||
"""
|
||||
def test_register_skill_group_creates_group(adapter):
|
||||
"""_register_skill_group should register a '/skill' Group on the tree."""
|
||||
mock_categories = {
|
||||
"creative": [
|
||||
("ascii-art", "Generate ASCII art", "/ascii-art"),
|
||||
@@ -715,17 +625,22 @@ def test_register_skill_command_is_flat_not_nested(adapter):
|
||||
adapter._register_slash_commands()
|
||||
|
||||
tree = adapter._client.tree
|
||||
assert "skill" in tree.commands, "Expected /skill command to be registered"
|
||||
skill_cmd = tree.commands["skill"]
|
||||
assert skill_cmd.name == "skill"
|
||||
# Flat command — NOT a Group — so it has no _children of category subgroups
|
||||
assert not hasattr(skill_cmd, "_children") or not getattr(skill_cmd, "_children", {}), (
|
||||
"Flat /skill command should not have subcommand children"
|
||||
)
|
||||
assert "skill" in tree.commands, "Expected /skill group to be registered"
|
||||
skill_group = tree.commands["skill"]
|
||||
assert skill_group.name == "skill"
|
||||
# Should have 2 category subgroups + 1 uncategorized subcommand
|
||||
children = skill_group._children
|
||||
assert "creative" in children
|
||||
assert "media" in children
|
||||
assert "dogfood" in children
|
||||
# Category groups should have their skills
|
||||
assert "ascii-art" in children["creative"]._children
|
||||
assert "excalidraw" in children["creative"]._children
|
||||
assert "gif-search" in children["media"]._children
|
||||
|
||||
|
||||
def test_register_skill_command_empty_skills_no_command(adapter):
|
||||
"""No /skill command should be registered when there are zero skills."""
|
||||
def test_register_skill_group_empty_skills_no_group(adapter):
|
||||
"""No /skill group should be added when there are zero skills."""
|
||||
with patch(
|
||||
"hermes_cli.commands.discord_skill_commands_by_category",
|
||||
return_value=({}, [], 0),
|
||||
@@ -736,134 +651,13 @@ def test_register_skill_command_empty_skills_no_command(adapter):
|
||||
assert "skill" not in tree.commands
|
||||
|
||||
|
||||
def test_register_skill_command_callback_dispatches_by_name(adapter):
|
||||
"""The /skill callback should look up the skill by ``name`` and
|
||||
dispatch via ``_run_simple_slash`` with the real command key.
|
||||
"""
|
||||
def test_register_skill_group_handler_dispatches_command(adapter):
|
||||
"""Skill subcommand handlers should dispatch the correct /cmd-key text."""
|
||||
mock_categories = {
|
||||
"media": [
|
||||
("gif-search", "Search for GIFs", "/gif-search"),
|
||||
],
|
||||
}
|
||||
mock_uncategorized = [
|
||||
("dogfood", "QA testing", "/dogfood"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"hermes_cli.commands.discord_skill_commands_by_category",
|
||||
return_value=(mock_categories, mock_uncategorized, 0),
|
||||
):
|
||||
adapter._register_slash_commands()
|
||||
|
||||
skill_cmd = adapter._client.tree.commands["skill"]
|
||||
assert skill_cmd.callback is not None
|
||||
|
||||
# Stub out _run_simple_slash so we can verify the dispatched text.
|
||||
dispatched: list[str] = []
|
||||
|
||||
async def fake_run(_interaction, text):
|
||||
dispatched.append(text)
|
||||
|
||||
adapter._run_simple_slash = fake_run
|
||||
|
||||
import asyncio
|
||||
|
||||
fake_interaction = SimpleNamespace()
|
||||
# gif-search → /gif-search with no args
|
||||
asyncio.run(skill_cmd.callback(fake_interaction, name="gif-search"))
|
||||
# dogfood with args
|
||||
asyncio.run(skill_cmd.callback(fake_interaction, name="dogfood", args="my test"))
|
||||
|
||||
assert dispatched == ["/gif-search", "/dogfood my test"]
|
||||
|
||||
|
||||
def test_register_skill_command_handles_unknown_skill_gracefully(adapter):
|
||||
"""Passing a name that isn't a registered skill should respond with
|
||||
an ephemeral error message, NOT crash the callback.
|
||||
"""
|
||||
with patch(
|
||||
"hermes_cli.commands.discord_skill_commands_by_category",
|
||||
return_value=({"media": [("gif-search", "GIFs", "/gif-search")]}, [], 0),
|
||||
):
|
||||
adapter._register_slash_commands()
|
||||
|
||||
skill_cmd = adapter._client.tree.commands["skill"]
|
||||
|
||||
sent: list[dict] = []
|
||||
|
||||
async def fake_send(text, ephemeral=False):
|
||||
sent.append({"text": text, "ephemeral": ephemeral})
|
||||
|
||||
interaction = SimpleNamespace(
|
||||
response=SimpleNamespace(send_message=fake_send),
|
||||
)
|
||||
|
||||
import asyncio
|
||||
asyncio.run(skill_cmd.callback(interaction, name="does-not-exist"))
|
||||
|
||||
assert len(sent) == 1
|
||||
assert "Unknown skill" in sent[0]["text"]
|
||||
assert "does-not-exist" in sent[0]["text"]
|
||||
assert sent[0]["ephemeral"] is True
|
||||
|
||||
|
||||
def test_register_skill_command_payload_fits_discord_8kb_limit(adapter):
|
||||
"""The /skill command registration payload must stay under Discord's
|
||||
~8000-byte per-command limit even with a large skill catalog.
|
||||
|
||||
This is the regression guard for #11321 / #10259. Simulates 500 skills
|
||||
(20 categories × 25 — the hard cap per category in the collector) and
|
||||
confirms the serialized command still fits. Autocomplete options are
|
||||
not part of this payload, so the budget is essentially constant.
|
||||
"""
|
||||
import json
|
||||
|
||||
# Simulate the largest catalog the collector will ever produce:
|
||||
# 20 categories × 25 skills each, with verbose 100-char descriptions.
|
||||
large_categories: dict[str, list[tuple[str, str, str]]] = {}
|
||||
long_desc = "A verbose description padded to approximately 100 chars " + "." * 42
|
||||
for i in range(20):
|
||||
cat = f"cat{i:02d}"
|
||||
large_categories[cat] = [
|
||||
(f"skill-{i:02d}-{j:02d}", long_desc, f"/skill-{i:02d}-{j:02d}")
|
||||
for j in range(25)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"hermes_cli.commands.discord_skill_commands_by_category",
|
||||
return_value=(large_categories, [], 0),
|
||||
):
|
||||
adapter._register_slash_commands()
|
||||
|
||||
skill_cmd = adapter._client.tree.commands["skill"]
|
||||
# Approximate the serialized registration payload (name + description only).
|
||||
# Autocomplete options are NOT registered — they're fetched dynamically.
|
||||
payload = json.dumps({
|
||||
"name": skill_cmd.name,
|
||||
"description": skill_cmd.description,
|
||||
"options": [
|
||||
{"name": "name", "description": "Which skill to run", "type": 3, "required": True},
|
||||
{"name": "args", "description": "Optional arguments for the skill", "type": 3, "required": False},
|
||||
],
|
||||
})
|
||||
assert len(payload) < 500, (
|
||||
f"Flat /skill command payload is ~{len(payload)} bytes — the whole "
|
||||
f"point of this design is that it stays small regardless of skill count"
|
||||
)
|
||||
|
||||
|
||||
def test_register_skill_command_autocomplete_filters_by_name_and_description(adapter):
|
||||
"""The autocomplete callback should match on both skill name and
|
||||
description so the user can search by either.
|
||||
"""
|
||||
mock_categories = {
|
||||
"ocr": [
|
||||
("ocr-and-documents", "Extract text from PDFs and scanned documents", "/ocr-and-documents"),
|
||||
],
|
||||
"media": [
|
||||
("gif-search", "Search and download GIFs from Tenor", "/gif-search"),
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"hermes_cli.commands.discord_skill_commands_by_category",
|
||||
@@ -871,15 +665,10 @@ def test_register_skill_command_autocomplete_filters_by_name_and_description(ada
|
||||
):
|
||||
adapter._register_slash_commands()
|
||||
|
||||
skill_cmd = adapter._client.tree.commands["skill"]
|
||||
# The callback has been wrapped with @autocomplete(name=...) — in our mock
|
||||
# the decorator is pass-through, so we inspect the closed-over list by
|
||||
# invoking the registered autocomplete function directly through the
|
||||
# test API. Since the mock doesn't preserve the autocomplete binding,
|
||||
# we re-derive the filter by building the same entries list.
|
||||
#
|
||||
# What we CAN verify at this layer: the callback dispatches correctly
|
||||
# (covered in other tests). The autocomplete filter itself is exercised
|
||||
# via direct function call in the real-discord integration path.
|
||||
assert skill_cmd.callback is not None
|
||||
skill_group = adapter._client.tree.commands["skill"]
|
||||
media_group = skill_group._children["media"]
|
||||
gif_cmd = media_group._children["gif-search"]
|
||||
assert gif_cmd.callback is not None
|
||||
# The callback name should reflect the skill
|
||||
assert "gif_search" in gif_cmd.callback.__name__
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tests for the QQ Bot platform adapter."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@@ -150,47 +149,6 @@ class TestIsVoiceContentType:
|
||||
assert self._fn("", "recording.amr") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Voice attachment SSRF protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVoiceAttachmentSSRFProtection:
|
||||
def _make_adapter(self, **extra):
|
||||
from gateway.platforms.qqbot import QQAdapter
|
||||
return QQAdapter(_make_config(**extra))
|
||||
|
||||
def test_stt_blocks_unsafe_download_url(self):
|
||||
adapter = self._make_adapter(app_id="a", client_secret="b")
|
||||
adapter._http_client = mock.AsyncMock()
|
||||
|
||||
with mock.patch("tools.url_safety.is_safe_url", return_value=False):
|
||||
transcript = asyncio.run(
|
||||
adapter._stt_voice_attachment(
|
||||
"http://127.0.0.1/voice.silk",
|
||||
"audio/silk",
|
||||
"voice.silk",
|
||||
)
|
||||
)
|
||||
|
||||
assert transcript is None
|
||||
adapter._http_client.get.assert_not_called()
|
||||
|
||||
def test_connect_uses_redirect_guard_hook(self):
|
||||
from gateway.platforms.qqbot import QQAdapter, _ssrf_redirect_guard
|
||||
|
||||
client = mock.AsyncMock()
|
||||
with mock.patch("gateway.platforms.qqbot.httpx.AsyncClient", return_value=client) as async_client_cls:
|
||||
adapter = QQAdapter(_make_config(app_id="a", client_secret="b"))
|
||||
adapter._ensure_token = mock.AsyncMock(side_effect=RuntimeError("stop after client creation"))
|
||||
|
||||
connected = asyncio.run(adapter.connect())
|
||||
|
||||
assert connected is False
|
||||
assert async_client_cls.call_count == 1
|
||||
kwargs = async_client_cls.call_args.kwargs
|
||||
assert kwargs.get("follow_redirects") is True
|
||||
assert kwargs.get("event_hooks", {}).get("response") == [_ssrf_redirect_guard]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _strip_at_mention
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -500,85 +458,3 @@ class TestBuildTextBody:
|
||||
adapter = self._make_adapter(app_id="a", client_secret="b", markdown_support=False)
|
||||
body = adapter._build_text_body("reply text", reply_to="msg_123")
|
||||
assert body.get("message_reference", {}).get("message_id") == "msg_123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _wait_for_reconnection / send reconnection wait
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWaitForReconnection:
|
||||
"""Test that send() waits for reconnection instead of silently dropping."""
|
||||
|
||||
def _make_adapter(self, **extra):
|
||||
from gateway.platforms.qqbot import QQAdapter
|
||||
return QQAdapter(_make_config(**extra))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_waits_and_succeeds_on_reconnect(self):
|
||||
"""send() should wait for reconnection and then deliver the message."""
|
||||
adapter = self._make_adapter(app_id="a", client_secret="b")
|
||||
# Initially disconnected
|
||||
adapter._running = False
|
||||
adapter._http_client = mock.MagicMock()
|
||||
|
||||
# Simulate reconnection after 0.3s (faster than real interval)
|
||||
async def fake_api_request(*args, **kwargs):
|
||||
return {"id": "msg_123"}
|
||||
|
||||
adapter._api_request = fake_api_request
|
||||
adapter._ensure_token = mock.AsyncMock()
|
||||
adapter._RECONNECT_POLL_INTERVAL = 0.1
|
||||
adapter._RECONNECT_WAIT_SECONDS = 5.0
|
||||
|
||||
# Schedule reconnection after a short delay
|
||||
async def reconnect_after_delay():
|
||||
await asyncio.sleep(0.3)
|
||||
adapter._running = True
|
||||
|
||||
asyncio.get_event_loop().create_task(reconnect_after_delay())
|
||||
|
||||
result = await adapter.send("test_openid", "Hello, world!")
|
||||
assert result.success
|
||||
assert result.message_id == "msg_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_returns_retryable_after_timeout(self):
|
||||
"""send() should return retryable=True if reconnection takes too long."""
|
||||
adapter = self._make_adapter(app_id="a", client_secret="b")
|
||||
adapter._running = False
|
||||
adapter._RECONNECT_POLL_INTERVAL = 0.05
|
||||
adapter._RECONNECT_WAIT_SECONDS = 0.2
|
||||
|
||||
result = await adapter.send("test_openid", "Hello, world!")
|
||||
assert not result.success
|
||||
assert result.retryable is True
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_succeeds_immediately_when_connected(self):
|
||||
"""send() should not wait when already connected."""
|
||||
adapter = self._make_adapter(app_id="a", client_secret="b")
|
||||
adapter._running = True
|
||||
adapter._http_client = mock.MagicMock()
|
||||
|
||||
async def fake_api_request(*args, **kwargs):
|
||||
return {"id": "msg_immediate"}
|
||||
|
||||
adapter._api_request = fake_api_request
|
||||
|
||||
result = await adapter.send("test_openid", "Hello!")
|
||||
assert result.success
|
||||
assert result.message_id == "msg_immediate"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_waits_for_reconnect(self):
|
||||
"""_send_media should also wait for reconnection."""
|
||||
adapter = self._make_adapter(app_id="a", client_secret="b")
|
||||
adapter._running = False
|
||||
adapter._RECONNECT_POLL_INTERVAL = 0.05
|
||||
adapter._RECONNECT_WAIT_SECONDS = 0.2
|
||||
|
||||
result = await adapter._send_media("test_openid", "http://example.com/img.jpg", 1, "image")
|
||||
assert not result.success
|
||||
assert result.retryable is True
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@@ -21,7 +21,6 @@ def _clear_auth_env(monkeypatch) -> None:
|
||||
"MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS",
|
||||
"DINGTALK_ALLOWED_USERS", "FEISHU_ALLOWED_USERS", "WECOM_ALLOWED_USERS",
|
||||
"QQ_ALLOWED_USERS", "QQ_GROUP_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
"TELEGRAM_ALLOW_ALL_USERS",
|
||||
"DISCORD_ALLOW_ALL_USERS",
|
||||
@@ -33,7 +32,6 @@ def _clear_auth_env(monkeypatch) -> None:
|
||||
"MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS",
|
||||
"DINGTALK_ALLOW_ALL_USERS", "FEISHU_ALLOW_ALL_USERS", "WECOM_ALLOW_ALL_USERS",
|
||||
"QQ_ALLOW_ALL_USERS",
|
||||
"GATEWAY_ALLOW_ALL_USERS",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
@@ -132,46 +130,6 @@ def test_star_wildcard_works_for_any_platform(monkeypatch):
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_qq_group_allowlist_authorizes_group_chat_without_user_allowlist(monkeypatch):
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("QQ_GROUP_ALLOWED_USERS", "group-openid-1")
|
||||
|
||||
runner, _adapter = _make_runner(
|
||||
Platform.QQBOT,
|
||||
GatewayConfig(platforms={Platform.QQBOT: PlatformConfig(enabled=True)}),
|
||||
)
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.QQBOT,
|
||||
user_id="member-openid-999",
|
||||
chat_id="group-openid-1",
|
||||
user_name="tester",
|
||||
chat_type="group",
|
||||
)
|
||||
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_qq_group_allowlist_does_not_authorize_other_groups(monkeypatch):
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("QQ_GROUP_ALLOWED_USERS", "group-openid-1")
|
||||
|
||||
runner, _adapter = _make_runner(
|
||||
Platform.QQBOT,
|
||||
GatewayConfig(platforms={Platform.QQBOT: PlatformConfig(enabled=True)}),
|
||||
)
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.QQBOT,
|
||||
user_id="member-openid-999",
|
||||
chat_id="group-openid-2",
|
||||
user_name="tester",
|
||||
chat_type="group",
|
||||
)
|
||||
|
||||
assert runner._is_user_authorized(source) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_dm_pairs_by_default(monkeypatch):
|
||||
_clear_auth_env(monkeypatch)
|
||||
|
||||
+24
-262
@@ -1,15 +1,12 @@
|
||||
"""Tests for the Weixin platform adapter."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides
|
||||
from gateway.platforms.base import SendResult
|
||||
from gateway.platforms import weixin
|
||||
from gateway.platforms.weixin import ContextTokenStore, WeixinAdapter
|
||||
from tools.send_message_tool import _parse_target_ref, _send_to_platform
|
||||
@@ -26,14 +23,17 @@ def _make_adapter() -> WeixinAdapter:
|
||||
|
||||
|
||||
class TestWeixinFormatting:
|
||||
def test_format_message_preserves_markdown(self):
|
||||
def test_format_message_preserves_markdown_and_rewrites_headers(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = "# Title\n\n## Plan\n\nUse **bold** and [docs](https://example.com)."
|
||||
|
||||
assert adapter.format_message(content) == content
|
||||
assert (
|
||||
adapter.format_message(content)
|
||||
== "【Title】\n\n**Plan**\n\nUse **bold** and docs (https://example.com)."
|
||||
)
|
||||
|
||||
def test_format_message_preserves_markdown_tables(self):
|
||||
def test_format_message_rewrites_markdown_tables(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = (
|
||||
@@ -43,14 +43,19 @@ class TestWeixinFormatting:
|
||||
"| Retries | 3 |\n"
|
||||
)
|
||||
|
||||
assert adapter.format_message(content) == content.strip()
|
||||
assert adapter.format_message(content) == (
|
||||
"- Setting: Timeout\n"
|
||||
" Value: 30s\n"
|
||||
"- Setting: Retries\n"
|
||||
" Value: 3"
|
||||
)
|
||||
|
||||
def test_format_message_preserves_fenced_code_blocks(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = "## Snippet\n\n```python\nprint('hi')\n```"
|
||||
|
||||
assert adapter.format_message(content) == content
|
||||
assert adapter.format_message(content) == "**Snippet**\n\n```python\nprint('hi')\n```"
|
||||
|
||||
def test_format_message_returns_empty_string_for_none(self):
|
||||
adapter = _make_adapter()
|
||||
@@ -96,7 +101,7 @@ class TestWeixinChunking:
|
||||
content = adapter.format_message("## 结论\n这是正文")
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == ["## 结论\n这是正文"]
|
||||
assert chunks == ["**结论**\n这是正文"]
|
||||
|
||||
def test_split_text_keeps_short_reformatted_table_in_single_chunk(self):
|
||||
adapter = _make_adapter()
|
||||
@@ -313,7 +318,6 @@ class TestWeixinChunkDelivery:
|
||||
def _connected_adapter(self) -> WeixinAdapter:
|
||||
adapter = _make_adapter()
|
||||
adapter._session = object()
|
||||
adapter._send_session = adapter._session
|
||||
adapter._token = "test-token"
|
||||
adapter._base_url = "https://weixin.example.com"
|
||||
adapter._token_store.get = lambda account_id, chat_id: "ctx-token"
|
||||
@@ -359,115 +363,6 @@ class TestWeixinChunkDelivery:
|
||||
assert first_try["client_id"] == retry["client_id"]
|
||||
|
||||
|
||||
class TestWeixinOutboundMedia:
|
||||
def test_send_image_file_accepts_keyword_image_path(self):
|
||||
adapter = _make_adapter()
|
||||
expected = SendResult(success=True, message_id="msg-1")
|
||||
adapter.send_document = AsyncMock(return_value=expected)
|
||||
|
||||
result = asyncio.run(
|
||||
adapter.send_image_file(
|
||||
chat_id="wxid_test123",
|
||||
image_path="/tmp/demo.png",
|
||||
caption="截图说明",
|
||||
reply_to="reply-1",
|
||||
metadata={"thread_id": "t-1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
adapter.send_document.assert_awaited_once_with(
|
||||
chat_id="wxid_test123",
|
||||
file_path="/tmp/demo.png",
|
||||
caption="截图说明",
|
||||
metadata={"thread_id": "t-1"},
|
||||
)
|
||||
|
||||
def test_send_document_accepts_keyword_file_path(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._session = object()
|
||||
adapter._send_session = adapter._session
|
||||
adapter._token = "test-token"
|
||||
adapter._send_file = AsyncMock(return_value="msg-2")
|
||||
|
||||
result = asyncio.run(
|
||||
adapter.send_document(
|
||||
chat_id="wxid_test123",
|
||||
file_path="/tmp/report.pdf",
|
||||
caption="报告请看",
|
||||
file_name="renamed.pdf",
|
||||
reply_to="reply-1",
|
||||
metadata={"thread_id": "t-1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "msg-2"
|
||||
adapter._send_file.assert_awaited_once_with("wxid_test123", "/tmp/report.pdf", "报告请看")
|
||||
|
||||
def test_send_file_uses_post_for_upload_full_url_and_hex_encoded_aes_key(self, tmp_path):
|
||||
class _UploadResponse:
|
||||
def __init__(self):
|
||||
self.status = 200
|
||||
self.headers = {"x-encrypted-param": "enc-param"}
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def read(self):
|
||||
return b""
|
||||
|
||||
async def text(self):
|
||||
return ""
|
||||
|
||||
class _RecordingSession:
|
||||
def __init__(self):
|
||||
self.post_calls = []
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.post_calls.append((url, kwargs))
|
||||
return _UploadResponse()
|
||||
|
||||
def put(self, *_args, **_kwargs):
|
||||
raise AssertionError("upload_full_url branch should use POST")
|
||||
|
||||
image_path = tmp_path / "demo.png"
|
||||
image_path.write_bytes(b"fake-png-bytes")
|
||||
|
||||
adapter = _make_adapter()
|
||||
session = _RecordingSession()
|
||||
adapter._session = session
|
||||
adapter._send_session = session
|
||||
adapter._token = "test-token"
|
||||
adapter._base_url = "https://weixin.example.com"
|
||||
adapter._cdn_base_url = "https://cdn.example.com/c2c"
|
||||
adapter._token_store.get = lambda account_id, chat_id: None
|
||||
|
||||
aes_key = bytes(range(16))
|
||||
expected_aes_key = base64.b64encode(aes_key.hex().encode("ascii")).decode("ascii")
|
||||
|
||||
with patch("gateway.platforms.weixin._get_upload_url", new=AsyncMock(return_value={"upload_full_url": "https://upload.example.com/media"})), \
|
||||
patch("gateway.platforms.weixin._api_post", new_callable=AsyncMock) as api_post_mock, \
|
||||
patch("gateway.platforms.weixin.secrets.token_hex", return_value="filekey-123"), \
|
||||
patch("gateway.platforms.weixin.secrets.token_bytes", return_value=aes_key):
|
||||
message_id = asyncio.run(adapter._send_file("wxid_test123", str(image_path), ""))
|
||||
|
||||
assert message_id.startswith("hermes-weixin-")
|
||||
assert len(session.post_calls) == 1
|
||||
upload_url, upload_kwargs = session.post_calls[0]
|
||||
assert upload_url == "https://upload.example.com/media"
|
||||
assert upload_kwargs["headers"] == {"Content-Type": "application/octet-stream"}
|
||||
assert upload_kwargs["data"]
|
||||
assert upload_kwargs["timeout"].total == 120
|
||||
payload = api_post_mock.await_args.kwargs["payload"]
|
||||
media = payload["msg"]["item_list"][0]["image_item"]["media"]
|
||||
assert media["encrypt_query_param"] == "enc-param"
|
||||
assert media["aes_key"] == expected_aes_key
|
||||
|
||||
|
||||
class TestWeixinRemoteMediaSafety:
|
||||
def test_download_remote_media_blocks_unsafe_urls(self):
|
||||
adapter = _make_adapter()
|
||||
@@ -482,13 +377,16 @@ class TestWeixinRemoteMediaSafety:
|
||||
|
||||
|
||||
class TestWeixinMarkdownLinks:
|
||||
"""Markdown links should be preserved so WeChat can render them natively."""
|
||||
"""Markdown links should be converted to plaintext since WeChat can't render them."""
|
||||
|
||||
def test_format_message_preserves_markdown_links(self):
|
||||
def test_format_message_converts_markdown_links_to_plain_text(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = "Check [the docs](https://example.com) and [GitHub](https://github.com) for details"
|
||||
assert adapter.format_message(content) == content
|
||||
assert (
|
||||
adapter.format_message(content)
|
||||
== "Check the docs (https://example.com) and GitHub (https://github.com) for details"
|
||||
)
|
||||
|
||||
def test_format_message_preserves_links_inside_code_blocks(self):
|
||||
adapter = _make_adapter()
|
||||
@@ -532,7 +430,6 @@ class TestWeixinBlankMessagePrevention:
|
||||
def test_send_empty_content_does_not_call_send_message(self, send_message_mock):
|
||||
adapter = _make_adapter()
|
||||
adapter._session = object()
|
||||
adapter._send_session = adapter._session
|
||||
adapter._token = "test-token"
|
||||
adapter._base_url = "https://weixin.example.com"
|
||||
adapter._token_store.get = lambda account_id, chat_id: "ctx-token"
|
||||
@@ -603,10 +500,10 @@ class TestWeixinMediaBuilder:
|
||||
)
|
||||
assert item["video_item"]["video_md5"] == "deadbeef"
|
||||
|
||||
def test_voice_builder_for_audio_files_uses_file_attachment_type(self):
|
||||
def test_voice_builder_for_audio_files(self):
|
||||
adapter = _make_adapter()
|
||||
media_type, builder = adapter._outbound_media_builder("note.mp3")
|
||||
assert media_type == weixin.MEDIA_FILE
|
||||
assert media_type == weixin.MEDIA_VOICE
|
||||
|
||||
item = builder(
|
||||
encrypt_query_param="eq",
|
||||
@@ -616,145 +513,10 @@ class TestWeixinMediaBuilder:
|
||||
filename="note.mp3",
|
||||
rawfilemd5="abc",
|
||||
)
|
||||
assert item["type"] == weixin.ITEM_FILE
|
||||
assert item["file_item"]["file_name"] == "note.mp3"
|
||||
assert item["type"] == weixin.ITEM_VOICE
|
||||
assert "voice_item" in item
|
||||
|
||||
def test_voice_builder_for_silk_files(self):
|
||||
adapter = _make_adapter()
|
||||
media_type, builder = adapter._outbound_media_builder("recording.silk")
|
||||
assert media_type == weixin.MEDIA_VOICE
|
||||
|
||||
|
||||
class TestWeixinSendImageFileParameterName:
|
||||
"""Regression test for send_image_file parameter name mismatch.
|
||||
|
||||
The gateway calls send_image_file(chat_id=..., image_path=...) but the
|
||||
WeixinAdapter previously used 'path' as the parameter name, causing
|
||||
image sending to fail. This test ensures the interface stays correct.
|
||||
"""
|
||||
|
||||
@patch.object(WeixinAdapter, "send_document", new_callable=AsyncMock)
|
||||
def test_send_image_file_uses_image_path_parameter(self, send_document_mock):
|
||||
"""Verify send_image_file accepts image_path and forwards to send_document."""
|
||||
adapter = _make_adapter()
|
||||
adapter._session = object()
|
||||
adapter._send_session = adapter._session
|
||||
adapter._token = "test-token"
|
||||
|
||||
send_document_mock.return_value = weixin.SendResult(success=True, message_id="test-id")
|
||||
|
||||
# This is the call pattern used by gateway/run.py extract_media
|
||||
result = asyncio.run(
|
||||
adapter.send_image_file(
|
||||
chat_id="wxid_test123",
|
||||
image_path="/tmp/test_image.png",
|
||||
caption="Test caption",
|
||||
metadata={"thread_id": "thread-123"},
|
||||
)
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
send_document_mock.assert_awaited_once_with(
|
||||
chat_id="wxid_test123",
|
||||
file_path="/tmp/test_image.png",
|
||||
caption="Test caption",
|
||||
metadata={"thread_id": "thread-123"},
|
||||
)
|
||||
|
||||
@patch.object(WeixinAdapter, "send_document", new_callable=AsyncMock)
|
||||
def test_send_image_file_works_without_optional_params(self, send_document_mock):
|
||||
"""Verify send_image_file works with minimal required params."""
|
||||
adapter = _make_adapter()
|
||||
adapter._session = object()
|
||||
adapter._send_session = adapter._session
|
||||
adapter._token = "test-token"
|
||||
|
||||
send_document_mock.return_value = weixin.SendResult(success=True, message_id="test-id")
|
||||
|
||||
result = asyncio.run(
|
||||
adapter.send_image_file(
|
||||
chat_id="wxid_test123",
|
||||
image_path="/tmp/test_image.jpg",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
send_document_mock.assert_awaited_once_with(
|
||||
chat_id="wxid_test123",
|
||||
file_path="/tmp/test_image.jpg",
|
||||
caption=None,
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
|
||||
class TestWeixinVoiceSending:
|
||||
def _connected_adapter(self) -> WeixinAdapter:
|
||||
adapter = _make_adapter()
|
||||
adapter._session = object()
|
||||
adapter._send_session = adapter._session
|
||||
adapter._token = "test-token"
|
||||
adapter._base_url = "https://weixin.example.com"
|
||||
adapter._token_store.get = lambda account_id, chat_id: "ctx-token"
|
||||
return adapter
|
||||
|
||||
@patch.object(WeixinAdapter, "_send_file", new_callable=AsyncMock)
|
||||
def test_send_voice_downgrades_to_document_attachment(self, send_file_mock, tmp_path):
|
||||
adapter = self._connected_adapter()
|
||||
source = tmp_path / "voice.ogg"
|
||||
source.write_bytes(b"ogg")
|
||||
send_file_mock.return_value = "msg-1"
|
||||
|
||||
result = asyncio.run(adapter.send_voice("wxid_test123", str(source)))
|
||||
|
||||
assert result.success is True
|
||||
send_file_mock.assert_awaited_once_with(
|
||||
"wxid_test123",
|
||||
str(source),
|
||||
"[voice message as attachment]",
|
||||
force_file_attachment=True,
|
||||
)
|
||||
|
||||
def test_voice_builder_for_silk_files_can_be_forced_to_file_attachment(self):
|
||||
adapter = _make_adapter()
|
||||
media_type, builder = adapter._outbound_media_builder(
|
||||
"recording.silk",
|
||||
force_file_attachment=True,
|
||||
)
|
||||
assert media_type == weixin.MEDIA_FILE
|
||||
|
||||
item = builder(
|
||||
encrypt_query_param="eq",
|
||||
aes_key_for_api="fakekey",
|
||||
ciphertext_size=512,
|
||||
plaintext_size=500,
|
||||
filename="recording.silk",
|
||||
rawfilemd5="abc",
|
||||
)
|
||||
assert item["type"] == weixin.ITEM_FILE
|
||||
assert item["file_item"]["file_name"] == "recording.silk"
|
||||
|
||||
@patch.object(weixin, "_api_post", new_callable=AsyncMock)
|
||||
@patch.object(weixin, "_upload_ciphertext", new_callable=AsyncMock)
|
||||
@patch.object(weixin, "_get_upload_url", new_callable=AsyncMock)
|
||||
def test_send_file_sets_voice_metadata_for_silk_payload(
|
||||
self,
|
||||
get_upload_url_mock,
|
||||
upload_ciphertext_mock,
|
||||
api_post_mock,
|
||||
tmp_path,
|
||||
):
|
||||
adapter = self._connected_adapter()
|
||||
silk = tmp_path / "voice.silk"
|
||||
silk.write_bytes(b"\x02#!SILK_V3\x01\x00")
|
||||
get_upload_url_mock.return_value = {"upload_full_url": "https://cdn.example.com/upload"}
|
||||
upload_ciphertext_mock.return_value = "enc-q"
|
||||
api_post_mock.return_value = {"success": True}
|
||||
|
||||
asyncio.run(adapter._send_file("wxid_test123", str(silk), ""))
|
||||
|
||||
payload = api_post_mock.await_args.kwargs["payload"]
|
||||
voice_item = payload["msg"]["item_list"][0]["voice_item"]
|
||||
assert voice_item.get("playtime", 0) == 0
|
||||
assert voice_item["encode_type"] == 6
|
||||
assert voice_item["sample_rate"] == 24000
|
||||
assert voice_item["bits_per_sample"] == 16
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
"""Tests for API-key provider support (z.ai/GLM, Kimi, MiniMax, AI Gateway)."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure dotenv doesn't interfere
|
||||
if "dotenv" not in sys.modules:
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
sys.modules["dotenv"] = fake_dotenv
|
||||
|
||||
from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY,
|
||||
ProviderConfig,
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
"""Tests for Arcee AI provider support — standard direct API provider."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
if "dotenv" not in sys.modules:
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
sys.modules["dotenv"] = fake_dotenv
|
||||
|
||||
from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY,
|
||||
resolve_provider,
|
||||
|
||||
@@ -703,231 +703,3 @@ def test_auth_remove_claude_code_suppresses_reseed(tmp_path, monkeypatch):
|
||||
suppressed = updated.get("suppressed_sources", {})
|
||||
assert "anthropic" in suppressed
|
||||
assert "claude_code" in suppressed["anthropic"]
|
||||
|
||||
|
||||
def test_unsuppress_credential_source_clears_marker(tmp_path, monkeypatch):
|
||||
"""unsuppress_credential_source() removes a previously-set marker."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {"version": 1})
|
||||
|
||||
from hermes_cli.auth import suppress_credential_source, unsuppress_credential_source, is_source_suppressed
|
||||
|
||||
suppress_credential_source("openai-codex", "device_code")
|
||||
assert is_source_suppressed("openai-codex", "device_code") is True
|
||||
|
||||
cleared = unsuppress_credential_source("openai-codex", "device_code")
|
||||
assert cleared is True
|
||||
assert is_source_suppressed("openai-codex", "device_code") is False
|
||||
|
||||
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
# Empty suppressed_sources dict should be cleaned up entirely
|
||||
assert "suppressed_sources" not in payload
|
||||
|
||||
|
||||
def test_unsuppress_credential_source_returns_false_when_absent(tmp_path, monkeypatch):
|
||||
"""unsuppress_credential_source() returns False if no marker exists."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {"version": 1})
|
||||
|
||||
from hermes_cli.auth import unsuppress_credential_source
|
||||
|
||||
assert unsuppress_credential_source("openai-codex", "device_code") is False
|
||||
assert unsuppress_credential_source("nonexistent", "whatever") is False
|
||||
|
||||
|
||||
def test_unsuppress_credential_source_preserves_other_markers(tmp_path, monkeypatch):
|
||||
"""Clearing one marker must not affect unrelated markers."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {"version": 1})
|
||||
|
||||
from hermes_cli.auth import (
|
||||
suppress_credential_source,
|
||||
unsuppress_credential_source,
|
||||
is_source_suppressed,
|
||||
)
|
||||
|
||||
suppress_credential_source("openai-codex", "device_code")
|
||||
suppress_credential_source("anthropic", "claude_code")
|
||||
|
||||
assert unsuppress_credential_source("openai-codex", "device_code") is True
|
||||
assert is_source_suppressed("anthropic", "claude_code") is True
|
||||
|
||||
|
||||
def test_auth_remove_codex_device_code_suppresses_reseed(tmp_path, monkeypatch):
|
||||
"""Removing an auto-seeded openai-codex credential must mark the source as suppressed."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_singletons",
|
||||
lambda provider, entries: (False, {"device_code"}),
|
||||
)
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
auth_store = {
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {
|
||||
"access_token": "acc-1",
|
||||
"refresh_token": "ref-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
"credential_pool": {
|
||||
"openai-codex": [{
|
||||
"id": "cx1",
|
||||
"label": "codex-auto",
|
||||
"auth_type": "oauth",
|
||||
"priority": 0,
|
||||
"source": "device_code",
|
||||
"access_token": "acc-1",
|
||||
"refresh_token": "ref-1",
|
||||
}]
|
||||
},
|
||||
}
|
||||
(hermes_home / "auth.json").write_text(json.dumps(auth_store))
|
||||
|
||||
from types import SimpleNamespace
|
||||
from hermes_cli.auth_commands import auth_remove_command
|
||||
|
||||
auth_remove_command(SimpleNamespace(provider="openai-codex", target="1"))
|
||||
|
||||
updated = json.loads((hermes_home / "auth.json").read_text())
|
||||
suppressed = updated.get("suppressed_sources", {})
|
||||
assert "openai-codex" in suppressed
|
||||
assert "device_code" in suppressed["openai-codex"]
|
||||
# Tokens in providers state should also be cleared
|
||||
assert "openai-codex" not in updated.get("providers", {})
|
||||
|
||||
|
||||
def test_auth_remove_codex_manual_source_suppresses_reseed(tmp_path, monkeypatch):
|
||||
"""Removing a manually-added (`manual:device_code`) openai-codex credential must also suppress."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_singletons",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
auth_store = {
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {
|
||||
"access_token": "acc-2",
|
||||
"refresh_token": "ref-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
"credential_pool": {
|
||||
"openai-codex": [{
|
||||
"id": "cx2",
|
||||
"label": "manual-codex",
|
||||
"auth_type": "oauth",
|
||||
"priority": 0,
|
||||
"source": "manual:device_code",
|
||||
"access_token": "acc-2",
|
||||
"refresh_token": "ref-2",
|
||||
}]
|
||||
},
|
||||
}
|
||||
(hermes_home / "auth.json").write_text(json.dumps(auth_store))
|
||||
|
||||
from types import SimpleNamespace
|
||||
from hermes_cli.auth_commands import auth_remove_command
|
||||
|
||||
auth_remove_command(SimpleNamespace(provider="openai-codex", target="1"))
|
||||
|
||||
updated = json.loads((hermes_home / "auth.json").read_text())
|
||||
suppressed = updated.get("suppressed_sources", {})
|
||||
# Critical: manual:device_code source must also trigger the suppression path
|
||||
assert "openai-codex" in suppressed
|
||||
assert "device_code" in suppressed["openai-codex"]
|
||||
assert "openai-codex" not in updated.get("providers", {})
|
||||
|
||||
|
||||
def test_auth_add_codex_clears_suppression_marker(tmp_path, monkeypatch):
|
||||
"""Re-linking codex via `hermes auth add openai-codex` must clear any suppression marker."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_singletons",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Pre-existing suppression (simulating a prior `hermes auth remove`)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {},
|
||||
"suppressed_sources": {"openai-codex": ["device_code"]},
|
||||
}))
|
||||
|
||||
token = _jwt_with_email("codex@example.com")
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._codex_device_code_login",
|
||||
lambda: {
|
||||
"tokens": {
|
||||
"access_token": token,
|
||||
"refresh_token": "refreshed",
|
||||
},
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"last_refresh": "2026-01-01T00:00:00Z",
|
||||
},
|
||||
)
|
||||
|
||||
from hermes_cli.auth_commands import auth_add_command
|
||||
|
||||
class _Args:
|
||||
provider = "openai-codex"
|
||||
auth_type = "oauth"
|
||||
api_key = None
|
||||
label = None
|
||||
|
||||
auth_add_command(_Args())
|
||||
|
||||
payload = json.loads((hermes_home / "auth.json").read_text())
|
||||
# Suppression marker must be cleared
|
||||
assert "openai-codex" not in payload.get("suppressed_sources", {})
|
||||
# New pool entry must be present
|
||||
entries = payload["credential_pool"]["openai-codex"]
|
||||
assert any(e["source"] == "manual:device_code" for e in entries)
|
||||
|
||||
|
||||
def test_seed_from_singletons_respects_codex_suppression(tmp_path, monkeypatch):
|
||||
"""_seed_from_singletons() for openai-codex must skip auto-import when suppressed."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Suppression marker in place
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {},
|
||||
"suppressed_sources": {"openai-codex": ["device_code"]},
|
||||
}))
|
||||
|
||||
# Make _import_codex_cli_tokens return tokens — these would normally trigger
|
||||
# a re-seed, but suppression must skip it.
|
||||
def _fake_import():
|
||||
return {
|
||||
"access_token": "would-be-reimported",
|
||||
"refresh_token": "would-be-reimported",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.auth._import_codex_cli_tokens", _fake_import)
|
||||
|
||||
from agent.credential_pool import _seed_from_singletons
|
||||
|
||||
entries = []
|
||||
changed, active_sources = _seed_from_singletons("openai-codex", entries)
|
||||
|
||||
# With suppression in place: nothing changes, no entries added, no sources
|
||||
assert changed is False
|
||||
assert entries == []
|
||||
assert active_sources == set()
|
||||
|
||||
# Verify the auth store was NOT modified (no auto-import happened)
|
||||
after = json.loads((hermes_home / "auth.json").read_text())
|
||||
assert "openai-codex" not in after.get("providers", {})
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
"""Unit tests for hermes_cli/dingtalk_auth.py (QR device-flow registration)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API layer — _api_post + error mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApiPost:
|
||||
|
||||
def test_raises_on_network_error(self):
|
||||
import requests
|
||||
from hermes_cli.dingtalk_auth import _api_post, RegistrationError
|
||||
|
||||
with patch("hermes_cli.dingtalk_auth.requests.post",
|
||||
side_effect=requests.ConnectionError("nope")):
|
||||
with pytest.raises(RegistrationError, match="Network error"):
|
||||
_api_post("/app/registration/init", {"source": "hermes"})
|
||||
|
||||
def test_raises_on_nonzero_errcode(self):
|
||||
from hermes_cli.dingtalk_auth import _api_post, RegistrationError
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = {"errcode": 42, "errmsg": "boom"}
|
||||
|
||||
with patch("hermes_cli.dingtalk_auth.requests.post", return_value=mock_resp):
|
||||
with pytest.raises(RegistrationError, match=r"boom \(errcode=42\)"):
|
||||
_api_post("/app/registration/init", {"source": "hermes"})
|
||||
|
||||
def test_returns_data_on_success(self):
|
||||
from hermes_cli.dingtalk_auth import _api_post
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = {"errcode": 0, "nonce": "abc"}
|
||||
|
||||
with patch("hermes_cli.dingtalk_auth.requests.post", return_value=mock_resp):
|
||||
result = _api_post("/app/registration/init", {"source": "hermes"})
|
||||
assert result["nonce"] == "abc"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# begin_registration — 2-step nonce → device_code chain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBeginRegistration:
|
||||
|
||||
def test_chains_init_then_begin(self):
|
||||
from hermes_cli.dingtalk_auth import begin_registration
|
||||
|
||||
responses = [
|
||||
{"errcode": 0, "nonce": "nonce123"},
|
||||
{
|
||||
"errcode": 0,
|
||||
"device_code": "dev-xyz",
|
||||
"verification_uri_complete": "https://open-dev.dingtalk.com/openapp/registration/openClaw?user_code=ABCD",
|
||||
"expires_in": 7200,
|
||||
"interval": 2,
|
||||
},
|
||||
]
|
||||
with patch("hermes_cli.dingtalk_auth._api_post", side_effect=responses):
|
||||
result = begin_registration()
|
||||
|
||||
assert result["device_code"] == "dev-xyz"
|
||||
assert "verification_uri_complete" in result
|
||||
assert result["interval"] == 2
|
||||
assert result["expires_in"] == 7200
|
||||
|
||||
def test_missing_nonce_raises(self):
|
||||
from hermes_cli.dingtalk_auth import begin_registration, RegistrationError
|
||||
|
||||
with patch("hermes_cli.dingtalk_auth._api_post",
|
||||
return_value={"errcode": 0, "nonce": ""}):
|
||||
with pytest.raises(RegistrationError, match="missing nonce"):
|
||||
begin_registration()
|
||||
|
||||
def test_missing_device_code_raises(self):
|
||||
from hermes_cli.dingtalk_auth import begin_registration, RegistrationError
|
||||
|
||||
responses = [
|
||||
{"errcode": 0, "nonce": "n1"},
|
||||
{"errcode": 0, "verification_uri_complete": "http://x"}, # no device_code
|
||||
]
|
||||
with patch("hermes_cli.dingtalk_auth._api_post", side_effect=responses):
|
||||
with pytest.raises(RegistrationError, match="missing device_code"):
|
||||
begin_registration()
|
||||
|
||||
def test_missing_verification_uri_raises(self):
|
||||
from hermes_cli.dingtalk_auth import begin_registration, RegistrationError
|
||||
|
||||
responses = [
|
||||
{"errcode": 0, "nonce": "n1"},
|
||||
{"errcode": 0, "device_code": "dev"}, # no verification_uri_complete
|
||||
]
|
||||
with patch("hermes_cli.dingtalk_auth._api_post", side_effect=responses):
|
||||
with pytest.raises(RegistrationError,
|
||||
match="missing verification_uri_complete"):
|
||||
begin_registration()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wait_for_registration_success — polling loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWaitForSuccess:
|
||||
|
||||
def test_returns_credentials_on_success(self):
|
||||
from hermes_cli.dingtalk_auth import wait_for_registration_success
|
||||
|
||||
responses = [
|
||||
{"status": "WAITING"},
|
||||
{"status": "WAITING"},
|
||||
{"status": "SUCCESS", "client_id": "cid-1", "client_secret": "sec-1"},
|
||||
]
|
||||
with patch("hermes_cli.dingtalk_auth.poll_registration", side_effect=responses), \
|
||||
patch("hermes_cli.dingtalk_auth.time.sleep"):
|
||||
cid, secret = wait_for_registration_success(
|
||||
device_code="dev", interval=0, expires_in=60
|
||||
)
|
||||
assert cid == "cid-1"
|
||||
assert secret == "sec-1"
|
||||
|
||||
def test_success_without_credentials_raises(self):
|
||||
from hermes_cli.dingtalk_auth import wait_for_registration_success, RegistrationError
|
||||
|
||||
with patch("hermes_cli.dingtalk_auth.poll_registration",
|
||||
return_value={"status": "SUCCESS", "client_id": "", "client_secret": ""}), \
|
||||
patch("hermes_cli.dingtalk_auth.time.sleep"):
|
||||
with pytest.raises(RegistrationError, match="credentials are missing"):
|
||||
wait_for_registration_success(
|
||||
device_code="dev", interval=0, expires_in=60
|
||||
)
|
||||
|
||||
def test_invokes_waiting_callback(self):
|
||||
from hermes_cli.dingtalk_auth import wait_for_registration_success
|
||||
|
||||
callback = MagicMock()
|
||||
responses = [
|
||||
{"status": "WAITING"},
|
||||
{"status": "WAITING"},
|
||||
{"status": "SUCCESS", "client_id": "cid", "client_secret": "sec"},
|
||||
]
|
||||
with patch("hermes_cli.dingtalk_auth.poll_registration", side_effect=responses), \
|
||||
patch("hermes_cli.dingtalk_auth.time.sleep"):
|
||||
wait_for_registration_success(
|
||||
device_code="dev", interval=0, expires_in=60, on_waiting=callback
|
||||
)
|
||||
assert callback.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QR rendering — terminal output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRenderQR:
|
||||
|
||||
def test_returns_false_when_qrcode_missing(self, monkeypatch):
|
||||
from hermes_cli import dingtalk_auth
|
||||
|
||||
# Simulate qrcode import failure
|
||||
monkeypatch.setitem(sys.modules, "qrcode", None)
|
||||
assert dingtalk_auth.render_qr_to_terminal("https://example.com") is False
|
||||
|
||||
def test_prints_when_qrcode_available(self, capsys):
|
||||
"""End-to-end: render a real QR and verify SOMETHING got printed."""
|
||||
try:
|
||||
import qrcode # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("qrcode library not available")
|
||||
|
||||
from hermes_cli.dingtalk_auth import render_qr_to_terminal
|
||||
result = render_qr_to_terminal("https://example.com/test")
|
||||
captured = capsys.readouterr()
|
||||
assert result is True
|
||||
assert len(captured.out) > 100 # rendered matrix is non-trivial
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration — env var overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigOverrides:
|
||||
|
||||
def test_base_url_default(self, monkeypatch):
|
||||
monkeypatch.delenv("DINGTALK_REGISTRATION_BASE_URL", raising=False)
|
||||
# Force module reload to pick up current env
|
||||
import importlib
|
||||
import hermes_cli.dingtalk_auth as mod
|
||||
importlib.reload(mod)
|
||||
assert mod.REGISTRATION_BASE_URL == "https://oapi.dingtalk.com"
|
||||
|
||||
def test_base_url_override_via_env(self, monkeypatch):
|
||||
monkeypatch.setenv("DINGTALK_REGISTRATION_BASE_URL",
|
||||
"https://test.example.com/")
|
||||
import importlib
|
||||
import hermes_cli.dingtalk_auth as mod
|
||||
importlib.reload(mod)
|
||||
# Trailing slash stripped
|
||||
assert mod.REGISTRATION_BASE_URL == "https://test.example.com"
|
||||
|
||||
def test_source_default(self, monkeypatch):
|
||||
monkeypatch.delenv("DINGTALK_REGISTRATION_SOURCE", raising=False)
|
||||
import importlib
|
||||
import hermes_cli.dingtalk_auth as mod
|
||||
importlib.reload(mod)
|
||||
assert mod.REGISTRATION_SOURCE == "openClaw"
|
||||
@@ -93,59 +93,6 @@ class TestCopilotDotPreservation:
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ── Copilot model-name normalization (issue #6879 regression) ──────────
|
||||
|
||||
class TestCopilotModelNormalization:
|
||||
"""Copilot requires bare dot-notation model IDs.
|
||||
|
||||
Regression coverage for issue #6879 and the broken Copilot branch
|
||||
that previously left vendor-prefixed Anthropic IDs (e.g.
|
||||
``anthropic/claude-sonnet-4.6``) and dash-notation Claude IDs (e.g.
|
||||
``claude-sonnet-4-6``) unchanged, causing the Copilot API to reject
|
||||
the request with HTTP 400 "model_not_supported".
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize("model,expected", [
|
||||
# Vendor-prefixed Anthropic IDs — prefix must be stripped.
|
||||
("anthropic/claude-opus-4.6", "claude-opus-4.6"),
|
||||
("anthropic/claude-sonnet-4.6", "claude-sonnet-4.6"),
|
||||
("anthropic/claude-sonnet-4.5", "claude-sonnet-4.5"),
|
||||
("anthropic/claude-haiku-4.5", "claude-haiku-4.5"),
|
||||
# Vendor-prefixed OpenAI IDs — prefix must be stripped.
|
||||
("openai/gpt-5.4", "gpt-5.4"),
|
||||
("openai/gpt-4o", "gpt-4o"),
|
||||
("openai/gpt-4o-mini", "gpt-4o-mini"),
|
||||
# Dash-notation Claude IDs — must be converted to dot-notation.
|
||||
("claude-opus-4-6", "claude-opus-4.6"),
|
||||
("claude-sonnet-4-6", "claude-sonnet-4.6"),
|
||||
("claude-sonnet-4-5", "claude-sonnet-4.5"),
|
||||
("claude-haiku-4-5", "claude-haiku-4.5"),
|
||||
# Combined: vendor-prefixed + dash-notation.
|
||||
("anthropic/claude-opus-4-6", "claude-opus-4.6"),
|
||||
("anthropic/claude-sonnet-4-6", "claude-sonnet-4.6"),
|
||||
# Already-canonical inputs pass through unchanged.
|
||||
("claude-sonnet-4.6", "claude-sonnet-4.6"),
|
||||
("gpt-5.4", "gpt-5.4"),
|
||||
("gpt-5-mini", "gpt-5-mini"),
|
||||
])
|
||||
def test_copilot_normalization(self, model, expected):
|
||||
assert normalize_model_for_provider(model, "copilot") == expected
|
||||
|
||||
@pytest.mark.parametrize("model,expected", [
|
||||
("anthropic/claude-sonnet-4.6", "claude-sonnet-4.6"),
|
||||
("claude-sonnet-4-6", "claude-sonnet-4.6"),
|
||||
("claude-opus-4-6", "claude-opus-4.6"),
|
||||
("openai/gpt-5.4", "gpt-5.4"),
|
||||
])
|
||||
def test_copilot_acp_normalization(self, model, expected):
|
||||
"""Copilot ACP shares the same API expectations as HTTP Copilot."""
|
||||
assert normalize_model_for_provider(model, "copilot-acp") == expected
|
||||
|
||||
def test_openai_codex_still_strips_openai_prefix(self):
|
||||
"""Regression: openai-codex must still strip the openai/ prefix."""
|
||||
assert normalize_model_for_provider("openai/gpt-5.4", "openai-codex") == "gpt-5.4"
|
||||
|
||||
|
||||
# ── Aggregator providers (regression) ──────────────────────────────────
|
||||
|
||||
class TestAggregatorProviders:
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
"""Tests for the prompt_toolkit /model picker scroll viewport.
|
||||
|
||||
Regression for: when a provider exposes many models (e.g. Ollama Cloud's
|
||||
36+), the picker rendered every choice into a Window with no max height,
|
||||
clipping the bottom border and any items past the terminal's last row.
|
||||
The viewport helper now caps visible items and slides the offset to keep
|
||||
the cursor on screen.
|
||||
"""
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
_compute = HermesCLI._compute_model_picker_viewport
|
||||
|
||||
|
||||
class TestPickerViewport:
|
||||
def test_short_list_no_scroll(self):
|
||||
offset, visible = _compute(selected=0, scroll_offset=0, n=5, term_rows=30)
|
||||
assert offset == 0
|
||||
assert visible == 5
|
||||
|
||||
def test_long_list_caps_visible_to_chrome_budget(self):
|
||||
# 30 rows minus reserved_below=6 minus panel_chrome=6 → max_visible=18.
|
||||
offset, visible = _compute(selected=0, scroll_offset=0, n=36, term_rows=30)
|
||||
assert visible == 18
|
||||
assert offset == 0
|
||||
|
||||
def test_cursor_past_window_scrolls_down(self):
|
||||
offset, visible = _compute(selected=22, scroll_offset=0, n=36, term_rows=30)
|
||||
assert visible == 18
|
||||
assert 22 in range(offset, offset + visible)
|
||||
|
||||
def test_cursor_above_window_scrolls_up(self):
|
||||
offset, visible = _compute(selected=3, scroll_offset=15, n=36, term_rows=30)
|
||||
assert offset == 3
|
||||
assert 3 in range(offset, offset + visible)
|
||||
|
||||
def test_offset_clamped_to_bottom(self):
|
||||
# Selected on the last item — offset must keep the visible window
|
||||
# full, not walk past the end of the list.
|
||||
offset, visible = _compute(selected=35, scroll_offset=0, n=36, term_rows=30)
|
||||
assert offset + visible == 36
|
||||
assert 35 in range(offset, offset + visible)
|
||||
|
||||
def test_tiny_terminal_uses_minimum_visible(self):
|
||||
# term_rows below the chrome budget falls back to the floor of 3 rows.
|
||||
_, visible = _compute(selected=0, scroll_offset=0, n=20, term_rows=10)
|
||||
assert visible == 3
|
||||
|
||||
def test_offset_recovers_after_stage_switch(self):
|
||||
# When the user backs out of the model stage and re-enters with
|
||||
# selected=0, a stale offset from the previous stage must collapse.
|
||||
offset, visible = _compute(selected=0, scroll_offset=25, n=36, term_rows=30)
|
||||
assert offset == 0
|
||||
assert 0 in range(offset, offset + visible)
|
||||
|
||||
def test_full_navigation_keeps_cursor_visible(self):
|
||||
offset = 0
|
||||
for cursor in list(range(36)) + list(range(35, -1, -1)):
|
||||
offset, visible = _compute(cursor, offset, n=36, term_rows=30)
|
||||
assert cursor in range(offset, offset + visible), (
|
||||
f"cursor={cursor} out of view: offset={offset} visible={visible}"
|
||||
)
|
||||
@@ -15,7 +15,7 @@ def test_opencode_go_appears_when_api_key_set():
|
||||
opencode_go = next((p for p in providers if p["slug"] == "opencode-go"), None)
|
||||
|
||||
assert opencode_go is not None, "opencode-go should appear when OPENCODE_GO_API_KEY is set"
|
||||
assert opencode_go["models"] == ["kimi-k2.5", "glm-5.1", "glm-5", "mimo-v2-pro", "mimo-v2-omni", "minimax-m2.7", "minimax-m2.5"]
|
||||
assert opencode_go["models"] == ["glm-5.1", "glm-5", "kimi-k2.5", "mimo-v2-pro", "mimo-v2-omni", "minimax-m2.7", "minimax-m2.5"]
|
||||
# opencode-go can appear as "built-in" (from PROVIDER_TO_MODELS_DEV when
|
||||
# models.dev is reachable) or "hermes" (from HERMES_OVERLAYS fallback when
|
||||
# the API is unavailable, e.g. in CI).
|
||||
|
||||
@@ -40,19 +40,6 @@ def test_get_platform_tools_preserves_explicit_empty_selection():
|
||||
assert enabled == set()
|
||||
|
||||
|
||||
def test_get_platform_tools_handles_null_platform_toolsets():
|
||||
"""YAML `platform_toolsets:` with no value parses as None — the old
|
||||
``config.get("platform_toolsets", {})`` pattern would then crash with
|
||||
``NoneType has no attribute 'get'`` on the next line. Guard against that.
|
||||
"""
|
||||
config = {"platform_toolsets": None}
|
||||
|
||||
enabled = _get_platform_tools(config, "cli")
|
||||
|
||||
# Falls through to defaults instead of raising
|
||||
assert enabled
|
||||
|
||||
|
||||
def test_platform_toolset_summary_uses_explicit_platform_list():
|
||||
config = {}
|
||||
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
"""Tests for Xiaomi MiMo provider support."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure dotenv doesn't interfere
|
||||
if "dotenv" not in sys.modules:
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
sys.modules["dotenv"] = fake_dotenv
|
||||
|
||||
from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY,
|
||||
resolve_provider,
|
||||
|
||||
@@ -31,31 +31,6 @@ def _isolate_env(tmp_path, monkeypatch):
|
||||
monkeypatch.delenv("RETAINDB_PROJECT", raising=False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cap_retaindb_sleeps(monkeypatch):
|
||||
"""Cap production-code sleeps so background-thread tests run fast.
|
||||
|
||||
The retaindb ``_WriteQueue._flush_row`` does ``time.sleep(2)`` after
|
||||
errors. Across multiple tests that trigger the retry path, that adds
|
||||
up. Cap the module's bound ``time.sleep`` to 0.05s — tests don't care
|
||||
about the exact retry delay, only that it happens. The test file's
|
||||
own ``time.sleep`` stays real since it uses a different reference.
|
||||
"""
|
||||
try:
|
||||
from plugins.memory import retaindb as _retaindb
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
real_sleep = _retaindb.time.sleep
|
||||
|
||||
def _capped_sleep(seconds):
|
||||
return real_sleep(min(float(seconds), 0.05))
|
||||
|
||||
import types as _types
|
||||
fake_time = _types.SimpleNamespace(sleep=_capped_sleep, time=_retaindb.time.time)
|
||||
monkeypatch.setattr(_retaindb, "time", fake_time)
|
||||
|
||||
|
||||
# We need the repo root on sys.path so the plugin can import agent.memory_provider
|
||||
import sys
|
||||
_repo_root = str(Path(__file__).resolve().parents[2])
|
||||
@@ -155,18 +130,16 @@ class TestWriteQueue:
|
||||
def test_enqueue_creates_row(self, tmp_path):
|
||||
q, client, db_path = self._make_queue(tmp_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
# shutdown() blocks until the writer thread drains the queue — no need
|
||||
# to pre-sleep (the old 1s sleep was a just-in-case wait, but shutdown
|
||||
# does the right thing).
|
||||
# Give the writer thread a moment to process
|
||||
time.sleep(1)
|
||||
q.shutdown()
|
||||
# If ingest succeeded, the row should be deleted
|
||||
client.ingest_session.assert_called_once()
|
||||
|
||||
def test_enqueue_persists_to_sqlite(self, tmp_path):
|
||||
client = MagicMock()
|
||||
# Make ingest slow so the row is still in SQLite when we peek.
|
||||
# 0.5s is plenty — the test just needs the flush to still be in-flight.
|
||||
client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(0.5))
|
||||
# Make ingest hang so the row stays in SQLite
|
||||
client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(5))
|
||||
db_path = tmp_path / "test_queue.db"
|
||||
q = _WriteQueue(client, db_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}])
|
||||
@@ -181,7 +154,8 @@ class TestWriteQueue:
|
||||
def test_flush_deletes_row_on_success(self, tmp_path):
|
||||
q, client, db_path = self._make_queue(tmp_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
q.shutdown() # blocks until drain
|
||||
time.sleep(1)
|
||||
q.shutdown()
|
||||
# Row should be gone
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
rows = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0]
|
||||
@@ -194,20 +168,14 @@ class TestWriteQueue:
|
||||
db_path = tmp_path / "test_queue.db"
|
||||
q = _WriteQueue(client, db_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
# Poll for the error to be recorded (max 2s), instead of a fixed 3s wait.
|
||||
deadline = time.time() + 2.0
|
||||
last_error = None
|
||||
while time.time() < deadline:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
row = conn.execute("SELECT last_error FROM pending").fetchone()
|
||||
conn.close()
|
||||
if row and row[0]:
|
||||
last_error = row[0]
|
||||
break
|
||||
time.sleep(0.05)
|
||||
time.sleep(3) # Allow retry + sleep(2) in _flush_row
|
||||
q.shutdown()
|
||||
assert last_error is not None
|
||||
assert "API down" in last_error
|
||||
# Row should still exist with error recorded
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
row = conn.execute("SELECT last_error FROM pending").fetchone()
|
||||
conn.close()
|
||||
assert row is not None
|
||||
assert "API down" in row[0]
|
||||
|
||||
def test_thread_local_connection_reuse(self, tmp_path):
|
||||
q, _, _ = self._make_queue(tmp_path)
|
||||
@@ -225,27 +193,14 @@ class TestWriteQueue:
|
||||
client1.ingest_session = MagicMock(side_effect=RuntimeError("fail"))
|
||||
q1 = _WriteQueue(client1, db_path)
|
||||
q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}])
|
||||
# Wait until the error is recorded (poll with short interval).
|
||||
deadline = time.time() + 2.0
|
||||
while time.time() < deadline:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
row = conn.execute("SELECT last_error FROM pending").fetchone()
|
||||
conn.close()
|
||||
if row and row[0]:
|
||||
break
|
||||
time.sleep(0.05)
|
||||
time.sleep(3)
|
||||
q1.shutdown()
|
||||
|
||||
# Now create a new queue — it should replay the pending rows
|
||||
client2 = MagicMock()
|
||||
client2.ingest_session = MagicMock(return_value={"status": "ok"})
|
||||
q2 = _WriteQueue(client2, db_path)
|
||||
# Poll for the replay to happen.
|
||||
deadline = time.time() + 2.0
|
||||
while time.time() < deadline:
|
||||
if client2.ingest_session.called:
|
||||
break
|
||||
time.sleep(0.05)
|
||||
time.sleep(2)
|
||||
q2.shutdown()
|
||||
|
||||
# The replayed row should have been ingested via client2
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
"""Fast-path fixtures shared across tests/run_agent/.
|
||||
|
||||
Many tests in this directory exercise the retry/backoff paths in the
|
||||
agent loop. Production code uses ``jittered_backoff(base_delay=5.0)``
|
||||
with a ``while time.time() < sleep_end`` loop — a single retry test
|
||||
spends 5+ seconds of real wall-clock time on backoff waits.
|
||||
|
||||
Mocking ``jittered_backoff`` to return 0.0 collapses the while-loop
|
||||
to a no-op (``time.time() < time.time() + 0`` is false immediately),
|
||||
which handles the most common case without touching ``time.sleep``.
|
||||
|
||||
We deliberately DO NOT mock ``time.sleep`` here — some tests
|
||||
(test_interrupt_propagation, test_primary_runtime_restore, etc.) use
|
||||
the real ``time.sleep`` for threading coordination or assert that it
|
||||
was called with specific values. Tests that want to additionally
|
||||
fast-path direct ``time.sleep(N)`` calls in production code should
|
||||
monkeypatch ``run_agent.time.sleep`` locally (see
|
||||
``test_anthropic_error_handling.py`` for the pattern).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fast_retry_backoff(monkeypatch):
|
||||
"""Short-circuit retry backoff for all tests in this directory."""
|
||||
try:
|
||||
import run_agent
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0)
|
||||
@@ -32,7 +32,6 @@ class TestGeneric400Heuristic:
|
||||
from run_agent import AIAgent
|
||||
a = AIAgent(
|
||||
api_key="test-key-12345",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
||||
@@ -19,24 +19,6 @@ import pytest
|
||||
|
||||
from agent.context_compressor import SUMMARY_PREFIX
|
||||
from run_agent import AIAgent
|
||||
import run_agent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fast backoff for compression retry tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_compression_sleep(monkeypatch):
|
||||
"""Short-circuit the 2s time.sleep between compression retries.
|
||||
|
||||
Production code has ``time.sleep(2)`` in multiple places after a 413/context
|
||||
compression, for rate-limit smoothing. Tests assert behavior, not timing.
|
||||
"""
|
||||
import time as _time
|
||||
monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None)
|
||||
monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -87,7 +69,6 @@ def agent():
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
||||
@@ -29,8 +29,6 @@ class TestFlushDeduplication:
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
session_db=session_db,
|
||||
@@ -273,8 +271,6 @@ class TestFlushIdxInit:
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -287,8 +283,6 @@ class TestFlushIdxInit:
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
||||
@@ -27,39 +27,6 @@ from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fast backoff for tests that exercise the retry loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_backoff_wait(monkeypatch):
|
||||
"""Short-circuit retry backoff so tests don't block on real wall-clock waits.
|
||||
|
||||
The production code uses jittered_backoff() with a 5s base delay plus a
|
||||
tight time.sleep(0.2) loop. Without this patch, each 429/500/529 retry
|
||||
test burns ~10s of real time on CI — across six tests that's ~60s for
|
||||
behavior we're not asserting against timing.
|
||||
|
||||
Tests assert retry counts and final results, never wait durations.
|
||||
"""
|
||||
import asyncio as _asyncio
|
||||
import time as _time
|
||||
|
||||
monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0)
|
||||
monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None)
|
||||
|
||||
# Also fast-path asyncio.sleep — the gateway's _run_agent path has
|
||||
# several await asyncio.sleep(...) calls that add real wall-clock time.
|
||||
_real_asyncio_sleep = _asyncio.sleep
|
||||
|
||||
async def _fast_sleep(delay=0, *args, **kwargs):
|
||||
# Yield to the event loop but skip the actual delay.
|
||||
await _real_asyncio_sleep(0)
|
||||
|
||||
monkeypatch.setattr(_asyncio, "sleep", _fast_sleep)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -37,8 +37,6 @@ class TestFlushAfterCompression:
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
session_db=session_db,
|
||||
|
||||
@@ -19,8 +19,6 @@ from run_agent import AIAgent
|
||||
def test_create_openai_client_does_not_mutate_input_kwargs(mock_openai):
|
||||
mock_openai.return_value = MagicMock()
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
||||
@@ -23,8 +23,6 @@ from run_agent import AIAgent
|
||||
|
||||
def _make_agent():
|
||||
return AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
||||
@@ -11,16 +11,6 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
import run_agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_fallback_wait(monkeypatch):
|
||||
"""Short-circuit time.sleep in fallback/recovery paths so tests don't
|
||||
block on the ``min(3 + retry_count, 8)`` wait before a primary retry."""
|
||||
import time as _time
|
||||
monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None)
|
||||
monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0)
|
||||
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
|
||||
@@ -46,7 +36,6 @@ def _make_agent(fallback_model=None):
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
||||
@@ -45,7 +45,6 @@ def test_plugin_engine_gets_context_length_on_init():
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
@@ -76,7 +75,6 @@ def test_plugin_engine_update_model_args():
|
||||
agent = AIAgent(
|
||||
model="openrouter/auto",
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
||||
@@ -19,7 +19,6 @@ def _make_agent(fallback_model=None):
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
||||
@@ -60,9 +60,6 @@ def _make_agent(monkeypatch, provider, api_mode="chat_completions", base_url="ht
|
||||
)
|
||||
if model:
|
||||
kwargs["model"] = model
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
return AIAgent(**kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -55,7 +55,6 @@ def agent():
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
@@ -77,7 +76,6 @@ def agent_with_memory_tool():
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-k...7890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
@@ -114,14 +112,12 @@ def test_aiagent_reuses_existing_errors_log_handler():
|
||||
):
|
||||
AIAgent(
|
||||
api_key="test-k...7890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
AIAgent(
|
||||
api_key="test-k...7890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
@@ -495,7 +491,6 @@ class TestInit:
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-4o",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -547,7 +542,6 @@ class TestInit:
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
@@ -563,7 +557,6 @@ class TestInit:
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
@@ -701,7 +694,6 @@ class TestBuildSystemPrompt:
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-k...7890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
@@ -734,7 +726,6 @@ class TestToolUseEnforcementConfig:
|
||||
a = AIAgent(
|
||||
model=model,
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
@@ -831,7 +822,6 @@ class TestToolUseEnforcementConfig:
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
@@ -3443,7 +3433,7 @@ class TestAnthropicBaseUrlPassthrough:
|
||||
):
|
||||
mock_build.return_value = MagicMock()
|
||||
a = AIAgent(
|
||||
api_key="sk-ant...7890",
|
||||
api_key="sk-ant-api03-test1234567890",
|
||||
api_mode="anthropic_messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -3467,7 +3457,6 @@ class TestAnthropicCredentialRefresh:
|
||||
mock_build.side_effect = [old_client, new_client]
|
||||
agent = AIAgent(
|
||||
api_key="sk-ant-oat01-stale-token",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_mode="anthropic_messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -3498,7 +3487,6 @@ class TestAnthropicCredentialRefresh:
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="sk-ant-oat01-same-token",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_mode="anthropic_messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -3526,7 +3514,6 @@ class TestAnthropicCredentialRefresh:
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="sk-ant-oat01-current-token",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_mode="anthropic_messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
||||
@@ -12,15 +12,6 @@ sys.modules.setdefault("fal_client", types.SimpleNamespace())
|
||||
import run_agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_codex_backoff(monkeypatch):
|
||||
"""Short-circuit retry backoff so Codex retry tests don't block on real
|
||||
wall-clock waits (5s jittered_backoff base delay + tight time.sleep loop)."""
|
||||
import time as _time
|
||||
monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0)
|
||||
monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None)
|
||||
|
||||
|
||||
def _patch_agent_bootstrap(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
run_agent,
|
||||
|
||||
@@ -80,8 +80,6 @@ class TestStreamingAccumulator:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -122,8 +120,6 @@ class TestStreamingAccumulator:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -171,8 +167,6 @@ class TestStreamingAccumulator:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -211,8 +205,6 @@ class TestStreamingAccumulator:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -253,8 +245,6 @@ class TestStreamingCallbacks:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -287,8 +277,6 @@ class TestStreamingCallbacks:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -320,8 +308,6 @@ class TestStreamingCallbacks:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -360,8 +346,6 @@ class TestStreamingCallbacks:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -397,8 +381,6 @@ class TestStreamingCallbacks:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -446,8 +428,6 @@ class TestStreamingFallback:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -475,8 +455,6 @@ class TestStreamingFallback:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -499,8 +477,6 @@ class TestStreamingFallback:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -524,8 +500,6 @@ class TestStreamingFallback:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -568,8 +542,6 @@ class TestStreamingFallback:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -605,8 +577,6 @@ class TestStreamingFallback:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -649,8 +619,6 @@ class TestReasoningStreaming:
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -678,8 +646,6 @@ class TestHasStreamConsumers:
|
||||
def test_no_consumers(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -690,8 +656,6 @@ class TestHasStreamConsumers:
|
||||
def test_delta_callback_set(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -703,8 +667,6 @@ class TestHasStreamConsumers:
|
||||
def test_stream_callback_set(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -726,8 +688,6 @@ class TestCodexStreamCallbacks:
|
||||
deltas = []
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -769,8 +729,6 @@ class TestCodexStreamCallbacks:
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -834,8 +792,6 @@ class TestCodexStreamCallbacks:
|
||||
)
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -854,8 +810,6 @@ class TestCodexStreamCallbacks:
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
@@ -907,8 +861,6 @@ class TestAnthropicStreamCallbacks:
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
|
||||
@@ -22,7 +22,6 @@ def _make_agent(session_db, *, platform: str):
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
||||
@@ -27,10 +27,3 @@ def test_matrix_extra_linux_only_in_all():
|
||||
if "matrix" in dep and "linux" in dep
|
||||
]
|
||||
assert linux_gated, "expected hermes-agent[matrix] with sys_platform=='linux' marker in [all]"
|
||||
|
||||
|
||||
def test_messaging_extra_includes_qrcode_for_weixin_setup():
|
||||
optional_dependencies = _load_optional_dependencies()
|
||||
|
||||
messaging_extra = optional_dependencies["messaging"]
|
||||
assert any(dep.startswith("qrcode") for dep in messaging_extra)
|
||||
|
||||
@@ -152,34 +152,6 @@ class TestIsSafeUrl:
|
||||
# 100.0.0.1 is a global IP, not in CGNAT range
|
||||
assert is_safe_url("http://legit-host.example/") is True
|
||||
|
||||
def test_benchmark_ip_blocked_for_non_allowlisted_host(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("198.18.0.23", 0)),
|
||||
]):
|
||||
assert is_safe_url("https://example.com/file.jpg") is False
|
||||
|
||||
def test_qq_multimedia_hostname_allowed_with_benchmark_ip(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("198.18.0.23", 0)),
|
||||
]):
|
||||
assert is_safe_url("https://multimedia.nt.qq.com.cn/download?id=123") is True
|
||||
|
||||
def test_qq_multimedia_hostname_exception_is_exact_match(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("198.18.0.23", 0)),
|
||||
]):
|
||||
assert is_safe_url("https://sub.multimedia.nt.qq.com.cn/download?id=123") is False
|
||||
|
||||
def test_qq_multimedia_hostname_exception_requires_https(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("198.18.0.23", 0)),
|
||||
]):
|
||||
assert is_safe_url("http://multimedia.nt.qq.com.cn/download?id=123") is False
|
||||
|
||||
def test_qq_multimedia_hostname_dns_failure_still_blocked(self):
|
||||
with patch("socket.getaddrinfo", side_effect=socket.gaierror("Name resolution failed")):
|
||||
assert is_safe_url("https://multimedia.nt.qq.com.cn/download?id=123") is False
|
||||
|
||||
|
||||
class TestIsBlockedIp:
|
||||
"""Direct tests for the _is_blocked_ip helper."""
|
||||
@@ -187,7 +159,7 @@ class TestIsBlockedIp:
|
||||
@pytest.mark.parametrize("ip_str", [
|
||||
"127.0.0.1", "10.0.0.1", "172.16.0.1", "192.168.1.1",
|
||||
"169.254.169.254", "0.0.0.0", "224.0.0.1", "255.255.255.255",
|
||||
"100.64.0.1", "100.100.100.100", "100.127.255.254", "198.18.0.23",
|
||||
"100.64.0.1", "100.100.100.100", "100.127.255.254",
|
||||
"::1", "fe80::1", "fc00::1", "fd12::1", "ff02::1",
|
||||
"::ffff:127.0.0.1", "::ffff:169.254.169.254",
|
||||
])
|
||||
|
||||
@@ -215,27 +215,7 @@ def _handle_send(args):
|
||||
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
# Weixin can be configured purely via .env; synthesize a pconfig so
|
||||
# send_message and cron delivery work without a gateway.yaml entry.
|
||||
if platform_name == "weixin":
|
||||
import os
|
||||
wx_token = os.getenv("WEIXIN_TOKEN", "").strip()
|
||||
wx_account = os.getenv("WEIXIN_ACCOUNT_ID", "").strip()
|
||||
if wx_token and wx_account:
|
||||
from gateway.config import PlatformConfig
|
||||
pconfig = PlatformConfig(
|
||||
enabled=True,
|
||||
token=wx_token,
|
||||
extra={
|
||||
"account_id": wx_account,
|
||||
"base_url": os.getenv("WEIXIN_BASE_URL", "").strip(),
|
||||
"cdn_base_url": os.getenv("WEIXIN_CDN_BASE_URL", "").strip(),
|
||||
},
|
||||
)
|
||||
else:
|
||||
return tool_error(f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/config.yaml or environment variables.")
|
||||
else:
|
||||
return tool_error(f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/config.yaml or environment variables.")
|
||||
return tool_error(f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/config.yaml or environment variables.")
|
||||
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
|
||||
@@ -245,12 +225,6 @@ def _handle_send(args):
|
||||
used_home_channel = False
|
||||
if not chat_id:
|
||||
home = config.get_home_channel(platform)
|
||||
if not home and platform_name == "weixin":
|
||||
import os
|
||||
wx_home = os.getenv("WEIXIN_HOME_CHANNEL", "").strip()
|
||||
if wx_home:
|
||||
from gateway.config import HomeChannel
|
||||
home = HomeChannel(platform=platform, chat_id=wx_home, name="Weixin Home")
|
||||
if home:
|
||||
chat_id = home.chat_id
|
||||
used_home_channel = True
|
||||
@@ -1300,7 +1274,7 @@ async def _send_qqbot(pconfig, chat_id, message):
|
||||
|
||||
# Step 2: Send message via REST
|
||||
headers = {
|
||||
"Authorization": f"QQBot {access_token}",
|
||||
"Authorization": f"QQBotAccessToken {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"https://api.sgroup.qq.com/channels/{chat_id}/messages"
|
||||
|
||||
+2
-23
@@ -29,13 +29,6 @@ _BLOCKED_HOSTNAMES = frozenset({
|
||||
"metadata.goog",
|
||||
})
|
||||
|
||||
# Exact HTTPS hostnames allowed to resolve to private/benchmark-space IPs.
|
||||
# This is intentionally narrow: QQ media downloads can legitimately resolve
|
||||
# to 198.18.0.0/15 behind local proxy/benchmark infrastructure.
|
||||
_TRUSTED_PRIVATE_IP_HOSTS = frozenset({
|
||||
"multimedia.nt.qq.com.cn",
|
||||
})
|
||||
|
||||
# 100.64.0.0/10 (CGNAT / Shared Address Space, RFC 6598) is NOT covered by
|
||||
# ipaddress.is_private — it returns False for both is_private and is_global.
|
||||
# Must be blocked explicitly. Used by carrier-grade NAT, Tailscale/WireGuard
|
||||
@@ -55,11 +48,6 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _allows_private_ip_resolution(hostname: str, scheme: str) -> bool:
|
||||
"""Return True when a trusted HTTPS hostname may bypass IP-class blocking."""
|
||||
return scheme == "https" and hostname in _TRUSTED_PRIVATE_IP_HOSTS
|
||||
|
||||
|
||||
def is_safe_url(url: str) -> bool:
|
||||
"""Return True if the URL target is not a private/internal address.
|
||||
|
||||
@@ -68,8 +56,7 @@ def is_safe_url(url: str) -> bool:
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = (parsed.hostname or "").strip().lower().rstrip(".")
|
||||
scheme = (parsed.scheme or "").strip().lower()
|
||||
hostname = (parsed.hostname or "").strip().lower()
|
||||
if not hostname:
|
||||
return False
|
||||
|
||||
@@ -78,8 +65,6 @@ def is_safe_url(url: str) -> bool:
|
||||
logger.warning("Blocked request to internal hostname: %s", hostname)
|
||||
return False
|
||||
|
||||
allow_private_ip = _allows_private_ip_resolution(hostname, scheme)
|
||||
|
||||
# Try to resolve and check IP
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||
@@ -96,19 +81,13 @@ def is_safe_url(url: str) -> bool:
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not allow_private_ip and _is_blocked_ip(ip):
|
||||
if _is_blocked_ip(ip):
|
||||
logger.warning(
|
||||
"Blocked request to private/internal address: %s -> %s",
|
||||
hostname, ip_str,
|
||||
)
|
||||
return False
|
||||
|
||||
if allow_private_ip:
|
||||
logger.debug(
|
||||
"Allowing trusted hostname despite private/internal resolution: %s",
|
||||
hostname,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
|
||||
@@ -196,10 +196,6 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI
|
||||
| `DISCORD_IGNORED_CHANNELS` | Comma-separated channel IDs where the bot never responds |
|
||||
| `DISCORD_NO_THREAD_CHANNELS` | Comma-separated channel IDs where bot responds without auto-threading |
|
||||
| `DISCORD_REPLY_TO_MODE` | Reply-reference behavior: `off`, `first` (default), or `all` |
|
||||
| `DISCORD_ALLOW_MENTION_EVERYONE` | Allow the bot to ping `@everyone`/`@here` (default: `false`). See [Mention Control](../user-guide/messaging/discord.md#mention-control). |
|
||||
| `DISCORD_ALLOW_MENTION_ROLES` | Allow the bot to ping `@role` mentions (default: `false`). |
|
||||
| `DISCORD_ALLOW_MENTION_USERS` | Allow the bot to ping individual `@user` mentions (default: `true`). |
|
||||
| `DISCORD_ALLOW_MENTION_REPLIED_USER` | Ping the author when replying to their message (default: `true`). |
|
||||
| `SLACK_BOT_TOKEN` | Slack bot token (`xoxb-...`) |
|
||||
| `SLACK_APP_TOKEN` | Slack app-level token (`xapp-...`, required for Socket Mode) |
|
||||
| `SLACK_ALLOWED_USERS` | Comma-separated Slack user IDs |
|
||||
|
||||
@@ -283,10 +283,6 @@ Discord behavior is controlled through two files: **`~/.hermes/.env`** for crede
|
||||
| `DISCORD_IGNORED_CHANNELS` | No | — | Comma-separated channel IDs where the bot **never** responds, even when `@mentioned`. Takes priority over all other channel settings. |
|
||||
| `DISCORD_NO_THREAD_CHANNELS` | No | — | Comma-separated channel IDs where the bot responds directly in the channel instead of creating a thread. Only relevant when `DISCORD_AUTO_THREAD` is `true`. |
|
||||
| `DISCORD_REPLY_TO_MODE` | No | `"first"` | Controls reply-reference behavior: `"off"` — never reply to the original message, `"first"` — reply-reference on the first message chunk only (default), `"all"` — reply-reference on every chunk. |
|
||||
| `DISCORD_ALLOW_MENTION_EVERYONE` | No | `false` | When `false` (default), the bot cannot ping `@everyone` or `@here` even if its response contains those tokens. Set to `true` to opt back in. See [Mention Control](#mention-control) below. |
|
||||
| `DISCORD_ALLOW_MENTION_ROLES` | No | `false` | When `false` (default), the bot cannot ping `@role` mentions. Set to `true` to allow. |
|
||||
| `DISCORD_ALLOW_MENTION_USERS` | No | `true` | When `true` (default), the bot can ping individual users by ID. |
|
||||
| `DISCORD_ALLOW_MENTION_REPLIED_USER` | No | `true` | When `true` (default), replying to a message pings the original author. |
|
||||
|
||||
### Config File (`config.yaml`)
|
||||
|
||||
@@ -302,11 +298,6 @@ discord:
|
||||
ignored_channels: [] # Channel IDs where bot never responds
|
||||
no_thread_channels: [] # Channel IDs where bot responds without threading
|
||||
channel_prompts: {} # Per-channel ephemeral system prompts
|
||||
allow_mentions: # What the bot is allowed to ping (safe defaults)
|
||||
everyone: false # @everyone / @here pings (default: false)
|
||||
roles: false # @role pings (default: false)
|
||||
users: true # @user pings (default: true)
|
||||
replied_user: true # reply-reference pings the author (default: true)
|
||||
|
||||
# Session isolation (applies to all gateway platforms, not just Discord)
|
||||
group_sessions_per_user: true # Isolate sessions per user in shared channels
|
||||
@@ -561,34 +552,6 @@ If you intentionally want a shared room conversation, leave it off — just expe
|
||||
Always set `DISCORD_ALLOWED_USERS` to restrict who can interact with the bot. Without it, the gateway denies all users by default as a safety measure. Only add User IDs of people you trust — authorized users have full access to the agent's capabilities, including tool use and system access.
|
||||
:::
|
||||
|
||||
### Mention Control
|
||||
|
||||
By default, Hermes blocks the bot from pinging `@everyone`, `@here`, and role mentions, even if its reply contains those tokens. This prevents a poorly-worded prompt or echoed user content from spamming a whole server. Individual `@user` pings and reply-reference pings (the little "replying to…" chip) stay enabled so normal conversation still works.
|
||||
|
||||
You can relax these defaults via either env vars or `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# ~/.hermes/config.yaml
|
||||
discord:
|
||||
allow_mentions:
|
||||
everyone: false # allow the bot to ping @everyone / @here
|
||||
roles: false # allow the bot to ping @role mentions
|
||||
users: true # allow the bot to ping individual @users
|
||||
replied_user: true # ping the author when replying to their message
|
||||
```
|
||||
|
||||
```bash
|
||||
# ~/.hermes/.env — env vars win over config.yaml
|
||||
DISCORD_ALLOW_MENTION_EVERYONE=false
|
||||
DISCORD_ALLOW_MENTION_ROLES=false
|
||||
DISCORD_ALLOW_MENTION_USERS=true
|
||||
DISCORD_ALLOW_MENTION_REPLIED_USER=true
|
||||
```
|
||||
|
||||
:::tip
|
||||
Leave `everyone` and `roles` at `false` unless you know exactly why you need them. It is very easy for an LLM to produce the string `@everyone` inside a normal-looking response; without this protection, that would notify every member of your server.
|
||||
:::
|
||||
|
||||
For more information on securing your Hermes Agent deployment, see the [Security Guide](../security.md).
|
||||
|
||||
|
||||
|
||||
@@ -16,14 +16,14 @@ This adapter is for **personal WeChat accounts** (微信). If you need enterpris
|
||||
|
||||
- A personal WeChat account
|
||||
- Python packages: `aiohttp` and `cryptography`
|
||||
- Terminal QR rendering is included when Hermes is installed with the `messaging` extra
|
||||
- The `qrcode` package is optional (for terminal QR rendering during setup)
|
||||
|
||||
Install the required dependencies:
|
||||
|
||||
```bash
|
||||
pip install aiohttp cryptography
|
||||
# Optional: for terminal QR code display
|
||||
pip install hermes-agent[messaging]
|
||||
pip install qrcode
|
||||
```
|
||||
|
||||
## Setup
|
||||
@@ -90,7 +90,7 @@ The adapter will restore saved credentials, connect to the iLink API, and begin
|
||||
- **Media support** — images, video, files, and voice messages
|
||||
- **AES-128-ECB encrypted CDN** — automatic encryption/decryption for all media transfers
|
||||
- **Context token persistence** — disk-backed reply continuity across restarts
|
||||
- **Markdown formatting** — preserves Markdown, including headers, tables, and code blocks, so WeChat clients that support Markdown can render it natively
|
||||
- **Markdown formatting** — headers, tables, and code blocks are reformatted for WeChat readability
|
||||
- **Smart message chunking** — messages stay as a single bubble when under the limit; only oversized payloads split at logical boundaries
|
||||
- **Typing indicators** — shows "typing…" status in the WeChat client while the agent processes
|
||||
- **SSRF protection** — outbound media URLs are validated before download
|
||||
@@ -206,12 +206,12 @@ This ensures reply continuity even after gateway restarts.
|
||||
|
||||
## Markdown Formatting
|
||||
|
||||
WeChat clients connected through the iLink Bot API can render Markdown directly, so the adapter preserves Markdown instead of rewriting it:
|
||||
WeChat's personal chat does not natively render full Markdown. The adapter reformats content for better readability:
|
||||
|
||||
- **Headers** stay as Markdown headings (`#`, `##`, ...)
|
||||
- **Tables** stay as Markdown tables
|
||||
- **Code fences** stay as fenced code blocks
|
||||
- **Excessive blank lines** are collapsed to double newlines outside fenced code blocks
|
||||
- **Headers** (`# Title`) → converted to `【Title】` (level 1) or `**Title**` (level 2+)
|
||||
- **Tables** → reformatted as labeled key-value lists (e.g., `- Column: Value`)
|
||||
- **Code fences** → preserved as-is (WeChat renders these adequately)
|
||||
- **Excessive blank lines** → collapsed to double newlines
|
||||
|
||||
## Message Chunking
|
||||
|
||||
@@ -296,4 +296,4 @@ Only one Weixin gateway instance can use a given token at a time. The adapter ac
|
||||
| Voice messages show as text | If WeChat provides a transcription, the adapter uses the text. This is expected behavior |
|
||||
| Messages appear duplicated | The adapter deduplicates by message ID. If you see duplicates, check if multiple gateway instances are running |
|
||||
| `iLink POST ... HTTP 4xx/5xx` | API error from the iLink service. Check your token validity and network connectivity |
|
||||
| Terminal QR code doesn't render | Reinstall with the messaging extra: `pip install hermes-agent[messaging]`. Alternatively, open the URL printed above the QR |
|
||||
| Terminal QR code doesn't render | Install `qrcode`: `pip install qrcode`. Alternatively, open the URL printed above the QR |
|
||||
|
||||
Reference in New Issue
Block a user