Compare commits
31 Commits
fix/messag
...
fix/defens
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
847ee20390 | ||
|
|
d2b10545db | ||
|
|
85993fbb5a | ||
|
|
fb20a9e120 | ||
|
|
21b823dd3b | ||
|
|
618ed2c65f | ||
|
|
9f81c11ba0 | ||
|
|
5301c01776 | ||
|
|
d81de2f3d8 | ||
|
|
1314b4b541 | ||
|
|
0878e5f4a8 | ||
|
|
72bcec0ce5 | ||
|
|
d604b9622c | ||
|
|
cf0dd777c8 | ||
|
|
ec272ca8be | ||
|
|
99a44d87dc | ||
|
|
16f38abd25 | ||
|
|
cac3c4d45f | ||
|
|
4167e2e294 | ||
|
|
6ddb9ee3e3 | ||
|
|
05aefeddc7 | ||
|
|
9db75fcfc2 | ||
|
|
1264275cc3 | ||
|
|
cd6dc4ef7e | ||
|
|
8cd4a96686 | ||
|
|
344f3771cb | ||
|
|
24282dceb1 | ||
|
|
1f0bb8742f | ||
|
|
0de75505f3 | ||
|
|
e5a244ad5d | ||
|
|
b111f2a779 |
@@ -1053,7 +1053,8 @@ def build_anthropic_kwargs(
|
||||
elif tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "none":
|
||||
pass # Don't send tool_choice — Anthropic will use tools if needed
|
||||
# Anthropic has no tool_choice "none" — omit tools entirely to prevent use
|
||||
kwargs.pop("tools", None)
|
||||
elif isinstance(tool_choice, str):
|
||||
# Specific tool name
|
||||
kwargs["tool_choice"] = {"type": "tool", "name": tool_choice}
|
||||
|
||||
@@ -706,6 +706,8 @@ def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[st
|
||||
|
||||
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
|
||||
global auxiliary_is_nous
|
||||
auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint,
|
||||
_try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
|
||||
@@ -313,7 +313,19 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
|
||||
if summary:
|
||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||
summary_role = "user" if last_head_role in ("assistant", "tool") else "assistant"
|
||||
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
|
||||
# Pick a role that avoids consecutive same-role with both neighbors.
|
||||
# Priority: avoid colliding with head (already committed), then tail.
|
||||
if last_head_role in ("assistant", "tool"):
|
||||
summary_role = "user"
|
||||
else:
|
||||
summary_role = "assistant"
|
||||
# If the chosen role collides with the tail AND flipping wouldn't
|
||||
# collide with the head, flip it.
|
||||
if summary_role == first_tail_role:
|
||||
flipped = "assistant" if summary_role == "user" else "user"
|
||||
if flipped != last_head_role:
|
||||
summary_role = flipped
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
else:
|
||||
if not self.quiet_mode:
|
||||
|
||||
@@ -94,10 +94,9 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"gpt-5": 128000,
|
||||
"gpt-5-codex": 128000,
|
||||
"gpt-5-nano": 128000,
|
||||
"claude-opus-4-6": 200000,
|
||||
# Bare model IDs without provider prefix (avoid duplicates with entries above)
|
||||
"claude-opus-4-5": 200000,
|
||||
"claude-opus-4-1": 200000,
|
||||
"claude-sonnet-4-6": 200000,
|
||||
"claude-sonnet-4-5": 200000,
|
||||
"claude-sonnet-4": 200000,
|
||||
"claude-haiku-4-5": 200000,
|
||||
@@ -108,11 +107,7 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"minimax-m2.5": 204800,
|
||||
"minimax-m2.5-free": 204800,
|
||||
"minimax-m2.1": 204800,
|
||||
"glm-5": 202752,
|
||||
"glm-4.7": 202752,
|
||||
"glm-4.6": 202752,
|
||||
"kimi-k2.5": 262144,
|
||||
"kimi-k2-thinking": 262144,
|
||||
"kimi-k2": 262144,
|
||||
"qwen3-coder": 32768,
|
||||
"big-pickle": 128000,
|
||||
@@ -266,8 +261,10 @@ def get_model_context_length(model: str, base_url: str = "") -> int:
|
||||
if model in metadata:
|
||||
return metadata[model].get("context_length", 128000)
|
||||
|
||||
# 3. Hardcoded defaults (fuzzy match)
|
||||
for default_model, length in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
# 3. Hardcoded defaults (fuzzy match — longest key first for specificity)
|
||||
for default_model, length in sorted(
|
||||
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
|
||||
):
|
||||
if default_model in model or model in default_model:
|
||||
return length
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ Jobs are stored in ~/.hermes/cron/jobs.json
|
||||
Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
@@ -167,6 +168,10 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
|
||||
try:
|
||||
# Parse and validate
|
||||
dt = datetime.fromisoformat(schedule.replace('Z', '+00:00'))
|
||||
# Make naive timestamps timezone-aware at parse time so the stored
|
||||
# value doesn't depend on the system timezone matching at check time.
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.astimezone() # Interpret as local timezone
|
||||
return {
|
||||
"kind": "once",
|
||||
"run_at": dt.isoformat(),
|
||||
@@ -539,8 +544,8 @@ def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
immediately. This prevents a burst of missed jobs on gateway restart.
|
||||
"""
|
||||
now = _hermes_now()
|
||||
jobs = [_apply_skill_fields(j) for j in load_jobs()]
|
||||
raw_jobs = load_jobs() # For saving updates
|
||||
raw_jobs = load_jobs()
|
||||
jobs = [_apply_skill_fields(j) for j in copy.deepcopy(raw_jobs)]
|
||||
due = []
|
||||
needs_save = False
|
||||
|
||||
|
||||
@@ -8,8 +8,9 @@ Hooks are discovered from ~/.hermes/hooks/ directories, each containing:
|
||||
|
||||
Events:
|
||||
- gateway:startup -- Gateway process starts
|
||||
- session:start -- New session created
|
||||
- session:reset -- User ran /new or /reset
|
||||
- session:start -- New session created (first message of a new session)
|
||||
- session:end -- Session ends (user ran /new or /reset)
|
||||
- session:reset -- Session reset completed (new session entry created)
|
||||
- agent:start -- Agent begins processing a message
|
||||
- agent:step -- Each turn in the tool-calling loop
|
||||
- agent:end -- Agent finishes processing
|
||||
|
||||
@@ -220,6 +220,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
# Start the sync loop.
|
||||
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||
self._mark_connected()
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
|
||||
@@ -222,6 +222,7 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
|
||||
# Start WebSocket in background.
|
||||
self._ws_task = asyncio.create_task(self._ws_loop())
|
||||
self._mark_connected()
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
|
||||
@@ -2178,7 +2178,14 @@ class GatewayRunner:
|
||||
|
||||
# Reset the session
|
||||
new_entry = self.session_store.reset_session(session_key)
|
||||
|
||||
|
||||
# Emit session:end hook (session is ending)
|
||||
await self.hooks.emit("session:end", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
"user_id": source.user_id,
|
||||
"session_key": session_key,
|
||||
})
|
||||
|
||||
# Emit session:reset hook
|
||||
await self.hooks.emit("session:reset", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
|
||||
@@ -16,7 +16,6 @@ import os
|
||||
import platform
|
||||
import re
|
||||
import stat
|
||||
import sys
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -379,6 +378,7 @@ ENV_VARS_BY_VERSION: Dict[int, List[str]] = {
|
||||
4: ["VOICE_TOOLS_OPENAI_KEY", "ELEVENLABS_API_KEY"],
|
||||
5: ["WHATSAPP_ENABLED", "WHATSAPP_MODE", "WHATSAPP_ALLOWED_USERS",
|
||||
"SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", "SLACK_ALLOWED_USERS"],
|
||||
10: ["TAVILY_API_KEY"],
|
||||
}
|
||||
|
||||
# Required environment variables with metadata for migration prompts.
|
||||
@@ -574,6 +574,14 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "tool",
|
||||
"advanced": True,
|
||||
},
|
||||
"TAVILY_API_KEY": {
|
||||
"description": "Tavily API key for AI-native web search, extract, and crawl",
|
||||
"prompt": "Tavily API key",
|
||||
"url": "https://app.tavily.com/home",
|
||||
"tools": ["web_search", "web_extract", "web_crawl"],
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"BROWSERBASE_API_KEY": {
|
||||
"description": "Browserbase API key for cloud browser (optional — local browser works without this)",
|
||||
"prompt": "Browserbase API key",
|
||||
@@ -1516,6 +1524,7 @@ def show_config():
|
||||
("VOICE_TOOLS_OPENAI_KEY", "OpenAI (STT/TTS)"),
|
||||
("PARALLEL_API_KEY", "Parallel"),
|
||||
("FIRECRAWL_API_KEY", "Firecrawl"),
|
||||
("TAVILY_API_KEY", "Tavily"),
|
||||
("BROWSERBASE_API_KEY", "Browserbase"),
|
||||
("BROWSER_USE_API_KEY", "Browser Use"),
|
||||
("FAL_KEY", "FAL"),
|
||||
@@ -1664,7 +1673,8 @@ def set_config_value(key: str, value: str):
|
||||
# Check if it's an API key (goes to .env)
|
||||
api_keys = [
|
||||
'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
|
||||
'PARALLEL_API_KEY', 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY',
|
||||
'PARALLEL_API_KEY', 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'TAVILY_API_KEY',
|
||||
'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY',
|
||||
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
|
||||
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
|
||||
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
|
||||
|
||||
@@ -1996,20 +1996,32 @@ def _update_via_zip(args):
|
||||
print(f"✗ ZIP update failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Reinstall Python dependencies
|
||||
# Reinstall Python dependencies (try .[all] first for optional extras,
|
||||
# fall back to . if extras fail — mirrors the install script behavior)
|
||||
print("→ Updating Python dependencies...")
|
||||
import subprocess
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin:
|
||||
subprocess.run(
|
||||
[uv_bin, "pip", "install", "-e", ".", "--quiet"],
|
||||
cwd=PROJECT_ROOT, check=True,
|
||||
env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}
|
||||
)
|
||||
uv_env = {**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}
|
||||
try:
|
||||
subprocess.run(
|
||||
[uv_bin, "pip", "install", "-e", ".[all]", "--quiet"],
|
||||
cwd=PROJECT_ROOT, check=True, env=uv_env,
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
print(" ⚠ Optional extras failed, installing base dependencies...")
|
||||
subprocess.run(
|
||||
[uv_bin, "pip", "install", "-e", ".", "--quiet"],
|
||||
cwd=PROJECT_ROOT, check=True, env=uv_env,
|
||||
)
|
||||
else:
|
||||
venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip"
|
||||
if venv_pip.exists():
|
||||
subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
pip_cmd = [str(venv_pip)] if venv_pip.exists() else ["pip"]
|
||||
try:
|
||||
subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
except subprocess.CalledProcessError:
|
||||
print(" ⚠ Optional extras failed, installing base dependencies...")
|
||||
subprocess.run(pip_cmd + ["install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
|
||||
# Sync skills
|
||||
try:
|
||||
@@ -2257,21 +2269,31 @@ def cmd_update(args):
|
||||
|
||||
_invalidate_update_cache()
|
||||
|
||||
# Reinstall Python dependencies (prefer uv for speed, fall back to pip)
|
||||
# Reinstall Python dependencies (try .[all] first for optional extras,
|
||||
# fall back to . if extras fail — mirrors the install script behavior)
|
||||
print("→ Updating Python dependencies...")
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin:
|
||||
subprocess.run(
|
||||
[uv_bin, "pip", "install", "-e", ".", "--quiet"],
|
||||
cwd=PROJECT_ROOT, check=True,
|
||||
env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}
|
||||
)
|
||||
uv_env = {**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}
|
||||
try:
|
||||
subprocess.run(
|
||||
[uv_bin, "pip", "install", "-e", ".[all]", "--quiet"],
|
||||
cwd=PROJECT_ROOT, check=True, env=uv_env,
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
print(" ⚠ Optional extras failed, installing base dependencies...")
|
||||
subprocess.run(
|
||||
[uv_bin, "pip", "install", "-e", ".", "--quiet"],
|
||||
cwd=PROJECT_ROOT, check=True, env=uv_env,
|
||||
)
|
||||
else:
|
||||
venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip"
|
||||
if venv_pip.exists():
|
||||
subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
else:
|
||||
subprocess.run(["pip", "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
pip_cmd = [str(venv_pip)] if venv_pip.exists() else ["pip"]
|
||||
try:
|
||||
subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
except subprocess.CalledProcessError:
|
||||
print(" ⚠ Optional extras failed, installing base dependencies...")
|
||||
subprocess.run(pip_cmd + ["install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
|
||||
# Check for Node.js deps
|
||||
if (PROJECT_ROOT / "package.json").exists():
|
||||
|
||||
@@ -444,11 +444,11 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
else:
|
||||
tool_status.append(("Mixture of Agents", False, "OPENROUTER_API_KEY"))
|
||||
|
||||
# Web tools (Parallel or Firecrawl)
|
||||
if get_env_value("PARALLEL_API_KEY") or get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL"):
|
||||
# Web tools (Parallel, Firecrawl, or Tavily)
|
||||
if get_env_value("PARALLEL_API_KEY") or get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL") or get_env_value("TAVILY_API_KEY"):
|
||||
tool_status.append(("Web Search & Extract", True, None))
|
||||
else:
|
||||
tool_status.append(("Web Search & Extract", False, "PARALLEL_API_KEY or FIRECRAWL_API_KEY"))
|
||||
tool_status.append(("Web Search & Extract", False, "PARALLEL_API_KEY, FIRECRAWL_API_KEY, or TAVILY_API_KEY"))
|
||||
|
||||
# Browser tools (local Chromium or Browserbase cloud)
|
||||
import shutil
|
||||
|
||||
@@ -120,6 +120,7 @@ def show_status(args):
|
||||
"MiniMax": "MINIMAX_API_KEY",
|
||||
"MiniMax-CN": "MINIMAX_CN_API_KEY",
|
||||
"Firecrawl": "FIRECRAWL_API_KEY",
|
||||
"Tavily": "TAVILY_API_KEY",
|
||||
"Browserbase": "BROWSERBASE_API_KEY", # Optional — local browser works without this
|
||||
"FAL": "FAL_KEY",
|
||||
"Tinker": "TINKER_API_KEY",
|
||||
|
||||
@@ -170,6 +170,14 @@ TOOL_CATEGORIES = {
|
||||
{"key": "PARALLEL_API_KEY", "prompt": "Parallel API key", "url": "https://parallel.ai"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Tavily",
|
||||
"tag": "AI-native search, extract, and crawl",
|
||||
"web_backend": "tavily",
|
||||
"env_vars": [
|
||||
{"key": "TAVILY_API_KEY", "prompt": "Tavily API key", "url": "https://app.tavily.com/home"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Firecrawl Self-Hosted",
|
||||
"tag": "Free - run your own instance",
|
||||
@@ -851,6 +859,11 @@ def _reconfigure_provider(provider: dict, config: dict):
|
||||
config.get("browser", {}).pop("cloud_provider", None)
|
||||
_print_success(f" Browser set to local mode")
|
||||
|
||||
# Set web search backend in config if applicable
|
||||
if provider.get("web_backend"):
|
||||
config.setdefault("web", {})["backend"] = provider["web_backend"]
|
||||
_print_success(f" Web backend set to: {provider['web_backend']}")
|
||||
|
||||
if not env_vars:
|
||||
_print_success(f" {provider['name']} - no configuration needed!")
|
||||
return
|
||||
|
||||
@@ -350,11 +350,12 @@ class SessionDB:
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\' ORDER BY started_at DESC LIMIT 2",
|
||||
(f"{escaped}%",),
|
||||
)
|
||||
matches = [row["id"] for row in cursor.fetchall()]
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\' ORDER BY started_at DESC LIMIT 2",
|
||||
(f"{escaped}%",),
|
||||
)
|
||||
matches = [row["id"] for row in cursor.fetchall()]
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
return None
|
||||
|
||||
@@ -101,7 +101,7 @@ def _discover_tools():
|
||||
try:
|
||||
importlib.import_module(mod_name)
|
||||
except Exception as e:
|
||||
logger.debug("Could not import %s: %s", mod_name, e)
|
||||
logger.warning("Could not import tool module %s: %s", mod_name, e)
|
||||
|
||||
|
||||
_discover_tools()
|
||||
|
||||
209
run_agent.py
209
run_agent.py
@@ -1957,7 +1957,124 @@ class AIAgent:
|
||||
prompt_parts.append(PLATFORM_HINTS[platform_key])
|
||||
|
||||
return "\n\n".join(prompt_parts)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Pre/post-call guardrails (inspired by PR #1321 — @alireza78a)
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_call_id_static(tc) -> str:
|
||||
"""Extract call ID from a tool_call entry (dict or object)."""
|
||||
if isinstance(tc, dict):
|
||||
return tc.get("id", "") or ""
|
||||
return getattr(tc, "id", "") or ""
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_api_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Fix orphaned tool_call / tool_result pairs before every LLM call.
|
||||
|
||||
Runs unconditionally — not gated on whether the context compressor
|
||||
is present — so orphans from session loading or manual message
|
||||
manipulation are always caught.
|
||||
"""
|
||||
surviving_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = AIAgent._get_tool_call_id_static(tc)
|
||||
if cid:
|
||||
surviving_call_ids.add(cid)
|
||||
|
||||
result_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "tool":
|
||||
cid = msg.get("tool_call_id")
|
||||
if cid:
|
||||
result_call_ids.add(cid)
|
||||
|
||||
# 1. Drop tool results with no matching assistant call
|
||||
orphaned_results = result_call_ids - surviving_call_ids
|
||||
if orphaned_results:
|
||||
messages = [
|
||||
m for m in messages
|
||||
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
|
||||
]
|
||||
logger.debug(
|
||||
"Pre-call sanitizer: removed %d orphaned tool result(s)",
|
||||
len(orphaned_results),
|
||||
)
|
||||
|
||||
# 2. Inject stub results for calls whose result was dropped
|
||||
missing_results = surviving_call_ids - result_call_ids
|
||||
if missing_results:
|
||||
patched: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
patched.append(msg)
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = AIAgent._get_tool_call_id_static(tc)
|
||||
if cid in missing_results:
|
||||
patched.append({
|
||||
"role": "tool",
|
||||
"content": "[Result unavailable — see context summary above]",
|
||||
"tool_call_id": cid,
|
||||
})
|
||||
messages = patched
|
||||
logger.debug(
|
||||
"Pre-call sanitizer: added %d stub tool result(s)",
|
||||
len(missing_results),
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _cap_delegate_task_calls(tool_calls: list) -> list:
|
||||
"""Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN.
|
||||
|
||||
The delegate_tool caps the task list inside a single call, but the
|
||||
model can emit multiple separate delegate_task tool_calls in one
|
||||
turn. This truncates the excess, preserving all non-delegate calls.
|
||||
|
||||
Returns the original list if no truncation was needed.
|
||||
"""
|
||||
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
||||
delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task")
|
||||
if delegate_count <= MAX_CONCURRENT_CHILDREN:
|
||||
return tool_calls
|
||||
kept_delegates = 0
|
||||
truncated = []
|
||||
for tc in tool_calls:
|
||||
if tc.function.name == "delegate_task":
|
||||
if kept_delegates < MAX_CONCURRENT_CHILDREN:
|
||||
truncated.append(tc)
|
||||
kept_delegates += 1
|
||||
else:
|
||||
truncated.append(tc)
|
||||
logger.warning(
|
||||
"Truncated %d excess delegate_task call(s) to enforce "
|
||||
"MAX_CONCURRENT_CHILDREN=%d limit",
|
||||
delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN,
|
||||
)
|
||||
return truncated
|
||||
|
||||
@staticmethod
|
||||
def _deduplicate_tool_calls(tool_calls: list) -> list:
|
||||
"""Remove duplicate (tool_name, arguments) pairs within a single turn.
|
||||
|
||||
Only the first occurrence of each unique pair is kept.
|
||||
Returns the original list if no duplicates were found.
|
||||
"""
|
||||
seen: set = set()
|
||||
unique: list = []
|
||||
for tc in tool_calls:
|
||||
key = (tc.function.name, tc.function.arguments)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique.append(tc)
|
||||
else:
|
||||
logger.warning("Removed duplicate tool call: %s", tc.function.name)
|
||||
return unique if len(unique) < len(tool_calls) else tool_calls
|
||||
|
||||
def _repair_tool_call(self, tool_name: str) -> str | None:
|
||||
"""Attempt to repair a mismatched tool name before aborting.
|
||||
|
||||
@@ -4884,6 +5001,7 @@ class AIAgent:
|
||||
codex_ack_continuations = 0
|
||||
length_continue_retries = 0
|
||||
truncated_response_prefix = ""
|
||||
compression_attempts = 0
|
||||
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
@@ -4991,11 +5109,10 @@ class AIAgent:
|
||||
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
|
||||
|
||||
# Safety net: strip orphaned tool results / add stubs for missing
|
||||
# results before sending to the API. The compressor handles this
|
||||
# during compression, but orphans can also sneak in from session
|
||||
# loading or manual message manipulation.
|
||||
if hasattr(self, 'context_compressor') and self.context_compressor:
|
||||
api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)
|
||||
# results before sending to the API. Runs unconditionally — not
|
||||
# gated on context_compressor — so orphans from session loading or
|
||||
# manual message manipulation are always caught.
|
||||
api_messages = self._sanitize_api_messages(api_messages)
|
||||
|
||||
# Calculate approximate request size for logging
|
||||
total_chars = sum(len(str(msg)) for msg in api_messages)
|
||||
@@ -5029,7 +5146,6 @@ class AIAgent:
|
||||
api_start_time = time.time()
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
compression_attempts = 0
|
||||
max_compression_attempts = 3
|
||||
codex_auth_retry_attempted = False
|
||||
anthropic_auth_retry_attempted = False
|
||||
@@ -5132,6 +5248,13 @@ class AIAgent:
|
||||
# This is often rate limiting or provider returning malformed response
|
||||
retry_count += 1
|
||||
|
||||
# Eager fallback: empty/malformed responses are a common
|
||||
# rate-limit symptom. Switch to fallback immediately
|
||||
# rather than retrying with extended backoff.
|
||||
if not self._fallback_activated and self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
|
||||
# Check for error field in response (some providers include this)
|
||||
error_msg = "Unknown"
|
||||
provider_name = "Unknown"
|
||||
@@ -5485,6 +5608,24 @@ class AIAgent:
|
||||
# A 413 is a payload-size error — the correct response is to
|
||||
# compress history and retry, not abort immediately.
|
||||
status_code = getattr(api_error, "status_code", None)
|
||||
|
||||
# Eager fallback for rate-limit errors (429 or quota exhaustion).
|
||||
# When a fallback model is configured, switch immediately instead
|
||||
# of burning through retries with exponential backoff -- the
|
||||
# primary provider won't recover within the retry window.
|
||||
is_rate_limited = (
|
||||
status_code == 429
|
||||
or "rate limit" in error_msg
|
||||
or "too many requests" in error_msg
|
||||
or "rate_limit" in error_msg
|
||||
or "usage limit" in error_msg
|
||||
or "quota" in error_msg
|
||||
)
|
||||
if is_rate_limited and not self._fallback_activated:
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
|
||||
is_payload_too_large = (
|
||||
status_code == 413
|
||||
or 'request entity too large' in error_msg
|
||||
@@ -5971,24 +6112,45 @@ class AIAgent:
|
||||
# Don't add anything to messages, just retry the API call
|
||||
continue
|
||||
else:
|
||||
# Instead of returning partial, inject a helpful message and let model recover
|
||||
self._vprint(f"{self.log_prefix}⚠️ Injecting recovery message for invalid JSON...")
|
||||
# Instead of returning partial, inject tool error results so the model can recover.
|
||||
# Using tool results (not user messages) preserves role alternation.
|
||||
self._vprint(f"{self.log_prefix}⚠️ Injecting recovery tool results for invalid JSON...")
|
||||
self._invalid_json_retries = 0 # Reset for next attempt
|
||||
|
||||
# Add a user message explaining the issue
|
||||
recovery_msg = (
|
||||
f"Your tool call to '{tool_name}' had invalid JSON arguments. "
|
||||
f"Error: {error_msg}. "
|
||||
f"For tools with no required parameters, use an empty object: {{}}. "
|
||||
f"Please either retry the tool call with valid JSON, or respond without using that tool."
|
||||
)
|
||||
recovery_dict = {"role": "user", "content": recovery_msg}
|
||||
messages.append(recovery_dict)
|
||||
# Append the assistant message with its (broken) tool_calls
|
||||
recovery_assistant = self._build_assistant_message(assistant_message, finish_reason)
|
||||
messages.append(recovery_assistant)
|
||||
|
||||
# Respond with tool error results for each tool call
|
||||
invalid_names = {name for name, _ in invalid_json_args}
|
||||
for tc in assistant_message.tool_calls:
|
||||
if tc.function.name in invalid_names:
|
||||
err = next(e for n, e in invalid_json_args if n == tc.function.name)
|
||||
tool_result = (
|
||||
f"Error: Invalid JSON arguments. {err}. "
|
||||
f"For tools with no required parameters, use an empty object: {{}}. "
|
||||
f"Please retry with valid JSON."
|
||||
)
|
||||
else:
|
||||
tool_result = "Skipped: other tool call in this response had invalid JSON."
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tc.id,
|
||||
"content": tool_result,
|
||||
})
|
||||
continue
|
||||
|
||||
# Reset retry counter on successful JSON validation
|
||||
self._invalid_json_retries = 0
|
||||
|
||||
|
||||
# ── Post-call guardrails ──────────────────────────
|
||||
assistant_message.tool_calls = self._cap_delegate_task_calls(
|
||||
assistant_message.tool_calls
|
||||
)
|
||||
assistant_message.tool_calls = self._deduplicate_tool_calls(
|
||||
assistant_message.tool_calls
|
||||
)
|
||||
|
||||
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
|
||||
# If this turn has both content AND tool_calls, capture the content
|
||||
@@ -6169,6 +6331,8 @@ class AIAgent:
|
||||
|
||||
if truncated_response_prefix:
|
||||
final_response = truncated_response_prefix + final_response
|
||||
truncated_response_prefix = ""
|
||||
length_continue_retries = 0
|
||||
|
||||
# Strip <think> blocks from user-facing response (keep raw in messages for trajectory)
|
||||
final_response = self._strip_think_blocks(final_response).strip()
|
||||
@@ -6220,10 +6384,11 @@ class AIAgent:
|
||||
|
||||
if not pending_handled:
|
||||
# Error happened before tool processing (e.g. response parsing).
|
||||
# Use a user-role message so the model can see what went wrong
|
||||
# without confusing the API with a fabricated assistant turn.
|
||||
# Choose role to avoid consecutive same-role messages.
|
||||
last_role = messages[-1].get("role") if messages else None
|
||||
err_role = "assistant" if last_role == "user" else "user"
|
||||
sys_err_msg = {
|
||||
"role": "user",
|
||||
"role": err_role,
|
||||
"content": f"[System error during processing: {error_msg}]",
|
||||
}
|
||||
messages.append(sys_err_msg)
|
||||
|
||||
@@ -316,6 +316,38 @@ class TestSanitizeEnvLines:
|
||||
assert fixes == 0
|
||||
|
||||
|
||||
class TestOptionalEnvVarsRegistry:
|
||||
"""Verify that key env vars are registered in OPTIONAL_ENV_VARS."""
|
||||
|
||||
def test_tavily_api_key_registered(self):
|
||||
"""TAVILY_API_KEY is listed in OPTIONAL_ENV_VARS."""
|
||||
from hermes_cli.config import OPTIONAL_ENV_VARS
|
||||
assert "TAVILY_API_KEY" in OPTIONAL_ENV_VARS
|
||||
|
||||
def test_tavily_api_key_is_tool_category(self):
|
||||
"""TAVILY_API_KEY is in the 'tool' category."""
|
||||
from hermes_cli.config import OPTIONAL_ENV_VARS
|
||||
assert OPTIONAL_ENV_VARS["TAVILY_API_KEY"]["category"] == "tool"
|
||||
|
||||
def test_tavily_api_key_is_password(self):
|
||||
"""TAVILY_API_KEY is marked as password."""
|
||||
from hermes_cli.config import OPTIONAL_ENV_VARS
|
||||
assert OPTIONAL_ENV_VARS["TAVILY_API_KEY"]["password"] is True
|
||||
|
||||
def test_tavily_api_key_has_url(self):
|
||||
"""TAVILY_API_KEY has a URL."""
|
||||
from hermes_cli.config import OPTIONAL_ENV_VARS
|
||||
assert OPTIONAL_ENV_VARS["TAVILY_API_KEY"]["url"] == "https://app.tavily.com/home"
|
||||
|
||||
def test_tavily_in_env_vars_by_version(self):
|
||||
"""TAVILY_API_KEY is listed in ENV_VARS_BY_VERSION."""
|
||||
from hermes_cli.config import ENV_VARS_BY_VERSION
|
||||
all_vars = []
|
||||
for vars_list in ENV_VARS_BY_VERSION.values():
|
||||
all_vars.extend(vars_list)
|
||||
assert "TAVILY_API_KEY" in all_vars
|
||||
|
||||
|
||||
class TestAnthropicTokenMigration:
|
||||
"""Test that config version 8→9 clears ANTHROPIC_TOKEN."""
|
||||
|
||||
|
||||
14
tests/hermes_cli/test_status.py
Normal file
14
tests/hermes_cli/test_status.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from hermes_cli.status import show_status
|
||||
|
||||
|
||||
def test_show_status_includes_tavily_key(monkeypatch, capsys, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("TAVILY_API_KEY", "tvly-1234567890abcdef")
|
||||
|
||||
show_status(SimpleNamespace(all=False, deep=False))
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "Tavily" in output
|
||||
assert "tvly...cdef" in output
|
||||
@@ -4,6 +4,7 @@ from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli import config as hermes_config
|
||||
from hermes_cli import main as hermes_main
|
||||
|
||||
|
||||
@@ -235,3 +236,82 @@ def test_stash_local_changes_if_needed_raises_when_stash_ref_missing(monkeypatch
|
||||
|
||||
with pytest.raises(CalledProcessError):
|
||||
hermes_main._stash_local_changes_if_needed(["git"], Path(tmp_path))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Update uses .[all] with fallback to .
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _setup_update_mocks(monkeypatch, tmp_path):
|
||||
"""Common setup for cmd_update tests."""
|
||||
(tmp_path / ".git").mkdir()
|
||||
monkeypatch.setattr(hermes_main, "PROJECT_ROOT", tmp_path)
|
||||
monkeypatch.setattr(hermes_main, "_stash_local_changes_if_needed", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(hermes_main, "_restore_stashed_changes", lambda *a, **kw: True)
|
||||
monkeypatch.setattr(hermes_config, "get_missing_env_vars", lambda required_only=True: [])
|
||||
monkeypatch.setattr(hermes_config, "get_missing_config_fields", lambda: [])
|
||||
monkeypatch.setattr(hermes_config, "check_config_version", lambda: (5, 5))
|
||||
monkeypatch.setattr(hermes_config, "migrate_config", lambda **kw: {"env_added": [], "config_added": []})
|
||||
|
||||
|
||||
def test_cmd_update_tries_extras_first_then_falls_back(monkeypatch, tmp_path):
|
||||
"""When .[all] fails, update should fall back to . instead of aborting."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
|
||||
|
||||
recorded = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
recorded.append(cmd)
|
||||
if cmd == ["git", "fetch", "origin"]:
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
if cmd == ["git", "rev-parse", "--abbrev-ref", "HEAD"]:
|
||||
return SimpleNamespace(stdout="main\n", stderr="", returncode=0)
|
||||
if cmd == ["git", "rev-list", "HEAD..origin/main", "--count"]:
|
||||
return SimpleNamespace(stdout="1\n", stderr="", returncode=0)
|
||||
if cmd == ["git", "pull", "origin", "main"]:
|
||||
return SimpleNamespace(stdout="Updating\n", stderr="", returncode=0)
|
||||
# .[all] fails
|
||||
if ".[all]" in cmd:
|
||||
raise CalledProcessError(returncode=1, cmd=cmd)
|
||||
# bare . succeeds
|
||||
if cmd == ["/usr/bin/uv", "pip", "install", "-e", ".", "--quiet"]:
|
||||
return SimpleNamespace(returncode=0)
|
||||
return SimpleNamespace(returncode=0)
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
install_cmds = [c for c in recorded if "pip" in c and "install" in c]
|
||||
assert len(install_cmds) == 2
|
||||
assert ".[all]" in install_cmds[0]
|
||||
assert "." in install_cmds[1] and ".[all]" not in install_cmds[1]
|
||||
|
||||
|
||||
def test_cmd_update_succeeds_with_extras(monkeypatch, tmp_path):
|
||||
"""When .[all] succeeds, no fallback should be attempted."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
|
||||
|
||||
recorded = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
recorded.append(cmd)
|
||||
if cmd == ["git", "fetch", "origin"]:
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
if cmd == ["git", "rev-parse", "--abbrev-ref", "HEAD"]:
|
||||
return SimpleNamespace(stdout="main\n", stderr="", returncode=0)
|
||||
if cmd == ["git", "rev-list", "HEAD..origin/main", "--count"]:
|
||||
return SimpleNamespace(stdout="1\n", stderr="", returncode=0)
|
||||
if cmd == ["git", "pull", "origin", "main"]:
|
||||
return SimpleNamespace(stdout="Updating\n", stderr="", returncode=0)
|
||||
return SimpleNamespace(returncode=0)
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
install_cmds = [c for c in recorded if "pip" in c and "install" in c]
|
||||
assert len(install_cmds) == 1
|
||||
assert ".[all]" in install_cmds[0]
|
||||
|
||||
263
tests/test_agent_guardrails.py
Normal file
263
tests/test_agent_guardrails.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""Unit tests for AIAgent pre/post-LLM-call guardrails.
|
||||
|
||||
Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a):
|
||||
- _sanitize_api_messages() — Phase 1: orphaned tool pair repair
|
||||
- _cap_delegate_task_calls() — Phase 2a: subagent concurrency limit
|
||||
- _deduplicate_tool_calls() — Phase 2b: identical call deduplication
|
||||
"""
|
||||
|
||||
import types
|
||||
|
||||
from run_agent import AIAgent
|
||||
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace:
|
||||
"""Create a minimal tool_call SimpleNamespace mirroring the OpenAI SDK object."""
|
||||
tc = types.SimpleNamespace()
|
||||
tc.function = types.SimpleNamespace(name=name, arguments=arguments)
|
||||
return tc
|
||||
|
||||
|
||||
def tool_result(call_id: str, content: str = "ok") -> dict:
|
||||
return {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
|
||||
|
||||
def assistant_dict_call(call_id: str, name: str = "terminal") -> dict:
|
||||
"""Dict-style tool_call (as stored in message history)."""
|
||||
return {"id": call_id, "function": {"name": name, "arguments": "{}"}}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 1 — _sanitize_api_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSanitizeApiMessages:
|
||||
|
||||
def test_orphaned_result_removed(self):
|
||||
msgs = [
|
||||
{"role": "assistant", "tool_calls": [assistant_dict_call("c1")]},
|
||||
tool_result("c1"),
|
||||
tool_result("c_ORPHAN"),
|
||||
]
|
||||
out = AIAgent._sanitize_api_messages(msgs)
|
||||
assert len(out) == 2
|
||||
assert all(m.get("tool_call_id") != "c_ORPHAN" for m in out)
|
||||
|
||||
def test_orphaned_call_gets_stub_result(self):
|
||||
msgs = [
|
||||
{"role": "assistant", "tool_calls": [assistant_dict_call("c2")]},
|
||||
]
|
||||
out = AIAgent._sanitize_api_messages(msgs)
|
||||
assert len(out) == 2
|
||||
stub = out[1]
|
||||
assert stub["role"] == "tool"
|
||||
assert stub["tool_call_id"] == "c2"
|
||||
assert stub["content"]
|
||||
|
||||
def test_clean_messages_pass_through(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "tool_calls": [assistant_dict_call("c3")]},
|
||||
tool_result("c3"),
|
||||
{"role": "assistant", "content": "done"},
|
||||
]
|
||||
out = AIAgent._sanitize_api_messages(msgs)
|
||||
assert out == msgs
|
||||
|
||||
def test_mixed_orphaned_result_and_orphaned_call(self):
|
||||
msgs = [
|
||||
{"role": "assistant", "tool_calls": [
|
||||
assistant_dict_call("c4"),
|
||||
assistant_dict_call("c5"),
|
||||
]},
|
||||
tool_result("c4"),
|
||||
tool_result("c_DANGLING"),
|
||||
]
|
||||
out = AIAgent._sanitize_api_messages(msgs)
|
||||
ids = [m.get("tool_call_id") for m in out if m.get("role") == "tool"]
|
||||
assert "c_DANGLING" not in ids
|
||||
assert "c4" in ids
|
||||
assert "c5" in ids
|
||||
|
||||
def test_empty_list_is_safe(self):
|
||||
assert AIAgent._sanitize_api_messages([]) == []
|
||||
|
||||
def test_no_tool_messages(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hello"},
|
||||
]
|
||||
out = AIAgent._sanitize_api_messages(msgs)
|
||||
assert out == msgs
|
||||
|
||||
def test_sdk_object_tool_calls(self):
|
||||
tc_obj = types.SimpleNamespace(id="c6", function=types.SimpleNamespace(
|
||||
name="terminal", arguments="{}"
|
||||
))
|
||||
msgs = [
|
||||
{"role": "assistant", "tool_calls": [tc_obj]},
|
||||
]
|
||||
out = AIAgent._sanitize_api_messages(msgs)
|
||||
assert len(out) == 2
|
||||
assert out[1]["tool_call_id"] == "c6"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2a — _cap_delegate_task_calls
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCapDelegateTaskCalls:
|
||||
|
||||
def test_excess_delegates_truncated(self):
|
||||
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
|
||||
out = AIAgent._cap_delegate_task_calls(tcs)
|
||||
delegate_count = sum(1 for tc in out if tc.function.name == "delegate_task")
|
||||
assert delegate_count == MAX_CONCURRENT_CHILDREN
|
||||
|
||||
def test_non_delegate_calls_preserved(self):
|
||||
tcs = (
|
||||
[make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)]
|
||||
+ [make_tc("terminal"), make_tc("web_search")]
|
||||
)
|
||||
out = AIAgent._cap_delegate_task_calls(tcs)
|
||||
names = [tc.function.name for tc in out]
|
||||
assert "terminal" in names
|
||||
assert "web_search" in names
|
||||
|
||||
def test_at_limit_passes_through(self):
|
||||
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)]
|
||||
out = AIAgent._cap_delegate_task_calls(tcs)
|
||||
assert out is tcs
|
||||
|
||||
def test_below_limit_passes_through(self):
|
||||
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)]
|
||||
out = AIAgent._cap_delegate_task_calls(tcs)
|
||||
assert out is tcs
|
||||
|
||||
def test_no_delegate_calls_unchanged(self):
|
||||
tcs = [make_tc("terminal"), make_tc("web_search")]
|
||||
out = AIAgent._cap_delegate_task_calls(tcs)
|
||||
assert out is tcs
|
||||
|
||||
def test_empty_list_safe(self):
|
||||
assert AIAgent._cap_delegate_task_calls([]) == []
|
||||
|
||||
def test_original_list_not_mutated(self):
|
||||
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
|
||||
original_len = len(tcs)
|
||||
AIAgent._cap_delegate_task_calls(tcs)
|
||||
assert len(tcs) == original_len
|
||||
|
||||
def test_interleaved_order_preserved(self):
|
||||
delegates = [make_tc("delegate_task", f'{{"task":"{i}"}}')
|
||||
for i in range(MAX_CONCURRENT_CHILDREN + 1)]
|
||||
t1 = make_tc("terminal", '{"cmd":"ls"}')
|
||||
w1 = make_tc("web_search", '{"q":"x"}')
|
||||
tcs = [delegates[0], t1, delegates[1], w1] + delegates[2:]
|
||||
out = AIAgent._cap_delegate_task_calls(tcs)
|
||||
expected = [delegates[0], t1, delegates[1], w1] + delegates[2:MAX_CONCURRENT_CHILDREN]
|
||||
assert len(out) == len(expected)
|
||||
for i, (actual, exp) in enumerate(zip(out, expected)):
|
||||
assert actual is exp, f"mismatch at index {i}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2b — _deduplicate_tool_calls
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeduplicateToolCalls:
|
||||
|
||||
def test_duplicate_pair_deduplicated(self):
|
||||
tcs = [
|
||||
make_tc("web_search", '{"query":"foo"}'),
|
||||
make_tc("web_search", '{"query":"foo"}'),
|
||||
]
|
||||
out = AIAgent._deduplicate_tool_calls(tcs)
|
||||
assert len(out) == 1
|
||||
|
||||
def test_multiple_duplicates(self):
|
||||
tcs = [
|
||||
make_tc("web_search", '{"q":"a"}'),
|
||||
make_tc("web_search", '{"q":"a"}'),
|
||||
make_tc("terminal", '{"cmd":"ls"}'),
|
||||
make_tc("terminal", '{"cmd":"ls"}'),
|
||||
make_tc("terminal", '{"cmd":"pwd"}'),
|
||||
]
|
||||
out = AIAgent._deduplicate_tool_calls(tcs)
|
||||
assert len(out) == 3
|
||||
|
||||
def test_same_tool_different_args_kept(self):
|
||||
tcs = [
|
||||
make_tc("terminal", '{"cmd":"ls"}'),
|
||||
make_tc("terminal", '{"cmd":"pwd"}'),
|
||||
]
|
||||
out = AIAgent._deduplicate_tool_calls(tcs)
|
||||
assert out is tcs
|
||||
|
||||
def test_different_tools_same_args_kept(self):
|
||||
tcs = [
|
||||
make_tc("tool_a", '{"x":1}'),
|
||||
make_tc("tool_b", '{"x":1}'),
|
||||
]
|
||||
out = AIAgent._deduplicate_tool_calls(tcs)
|
||||
assert out is tcs
|
||||
|
||||
def test_clean_list_unchanged(self):
|
||||
tcs = [
|
||||
make_tc("web_search", '{"q":"x"}'),
|
||||
make_tc("terminal", '{"cmd":"ls"}'),
|
||||
]
|
||||
out = AIAgent._deduplicate_tool_calls(tcs)
|
||||
assert out is tcs
|
||||
|
||||
def test_empty_list_safe(self):
|
||||
assert AIAgent._deduplicate_tool_calls([]) == []
|
||||
|
||||
def test_first_occurrence_kept(self):
|
||||
tc1 = make_tc("terminal", '{"cmd":"ls"}')
|
||||
tc2 = make_tc("terminal", '{"cmd":"ls"}')
|
||||
out = AIAgent._deduplicate_tool_calls([tc1, tc2])
|
||||
assert len(out) == 1
|
||||
assert out[0] is tc1
|
||||
|
||||
def test_original_list_not_mutated(self):
|
||||
tcs = [
|
||||
make_tc("web_search", '{"q":"dup"}'),
|
||||
make_tc("web_search", '{"q":"dup"}'),
|
||||
]
|
||||
original_len = len(tcs)
|
||||
AIAgent._deduplicate_tool_calls(tcs)
|
||||
assert len(tcs) == original_len
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_tool_call_id_static
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetToolCallIdStatic:
|
||||
|
||||
def test_dict_with_valid_id(self):
|
||||
assert AIAgent._get_tool_call_id_static({"id": "call_123"}) == "call_123"
|
||||
|
||||
def test_dict_with_none_id(self):
|
||||
assert AIAgent._get_tool_call_id_static({"id": None}) == ""
|
||||
|
||||
def test_dict_without_id_key(self):
|
||||
assert AIAgent._get_tool_call_id_static({"function": {}}) == ""
|
||||
|
||||
def test_object_with_valid_id(self):
|
||||
tc = types.SimpleNamespace(id="call_456")
|
||||
assert AIAgent._get_tool_call_id_static(tc) == "call_456"
|
||||
|
||||
def test_object_with_none_id(self):
|
||||
tc = types.SimpleNamespace(id=None)
|
||||
assert AIAgent._get_tool_call_id_static(tc) == ""
|
||||
|
||||
def test_object_without_id_attr(self):
|
||||
tc = types.SimpleNamespace()
|
||||
assert AIAgent._get_tool_call_id_static(tc) == ""
|
||||
@@ -130,7 +130,7 @@ class TestBackendSelection:
|
||||
setups.
|
||||
"""
|
||||
|
||||
_ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL")
|
||||
_ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", "TAVILY_API_KEY")
|
||||
|
||||
def setup_method(self):
|
||||
for key in self._ENV_KEYS:
|
||||
@@ -155,12 +155,31 @@ class TestBackendSelection:
|
||||
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
|
||||
assert _get_backend() == "firecrawl"
|
||||
|
||||
def test_config_tavily(self):
|
||||
"""web.backend=tavily in config → 'tavily' regardless of other keys."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}):
|
||||
assert _get_backend() == "tavily"
|
||||
|
||||
def test_config_tavily_overrides_env_keys(self):
|
||||
"""web.backend=tavily in config → 'tavily' even if Firecrawl key set."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}), \
|
||||
patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
|
||||
assert _get_backend() == "tavily"
|
||||
|
||||
def test_config_case_insensitive(self):
|
||||
"""web.backend=Parallel (mixed case) → 'parallel'."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "Parallel"}):
|
||||
assert _get_backend() == "parallel"
|
||||
|
||||
def test_config_tavily_case_insensitive(self):
|
||||
"""web.backend=Tavily (mixed case) → 'tavily'."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "Tavily"}):
|
||||
assert _get_backend() == "tavily"
|
||||
|
||||
# ── Fallback (no web.backend in config) ───────────────────────────
|
||||
|
||||
def test_fallback_parallel_only_key(self):
|
||||
@@ -170,6 +189,28 @@ class TestBackendSelection:
|
||||
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
|
||||
assert _get_backend() == "parallel"
|
||||
|
||||
def test_fallback_tavily_only_key(self):
|
||||
"""Only TAVILY_API_KEY set → 'tavily'."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={}), \
|
||||
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}):
|
||||
assert _get_backend() == "tavily"
|
||||
|
||||
def test_fallback_tavily_with_firecrawl_prefers_firecrawl(self):
|
||||
"""Tavily + Firecrawl keys, no config → 'firecrawl' (backward compat)."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={}), \
|
||||
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test", "FIRECRAWL_API_KEY": "fc-test"}):
|
||||
assert _get_backend() == "firecrawl"
|
||||
|
||||
def test_fallback_tavily_with_parallel_prefers_parallel(self):
|
||||
"""Tavily + Parallel keys, no config → 'parallel' (Parallel takes priority over Tavily)."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={}), \
|
||||
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test", "PARALLEL_API_KEY": "par-test"}):
|
||||
# Parallel + no Firecrawl → parallel
|
||||
assert _get_backend() == "parallel"
|
||||
|
||||
def test_fallback_both_keys_defaults_to_firecrawl(self):
|
||||
"""Both keys set, no config → 'firecrawl' (backward compat)."""
|
||||
from tools.web_tools import _get_backend
|
||||
@@ -193,7 +234,7 @@ class TestBackendSelection:
|
||||
def test_invalid_config_falls_through_to_fallback(self):
|
||||
"""web.backend=invalid → ignored, uses key-based fallback."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}), \
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "nonexistent"}), \
|
||||
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
|
||||
assert _get_backend() == "parallel"
|
||||
|
||||
@@ -238,7 +279,7 @@ class TestParallelClientConfig:
|
||||
class TestCheckWebApiKey:
|
||||
"""Test suite for check_web_api_key() unified availability check."""
|
||||
|
||||
_ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL")
|
||||
_ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", "TAVILY_API_KEY")
|
||||
|
||||
def setup_method(self):
|
||||
for key in self._ENV_KEYS:
|
||||
@@ -263,6 +304,11 @@ class TestCheckWebApiKey:
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is True
|
||||
|
||||
def test_tavily_key_only(self):
|
||||
with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}):
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is True
|
||||
|
||||
def test_no_keys_returns_false(self):
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is False
|
||||
@@ -274,3 +320,12 @@ class TestCheckWebApiKey:
|
||||
}):
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is True
|
||||
|
||||
def test_all_three_keys_returns_true(self):
|
||||
with patch.dict(os.environ, {
|
||||
"PARALLEL_API_KEY": "test-key",
|
||||
"FIRECRAWL_API_KEY": "fc-test",
|
||||
"TAVILY_API_KEY": "tvly-test",
|
||||
}):
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is True
|
||||
|
||||
255
tests/tools/test_web_tools_tavily.py
Normal file
255
tests/tools/test_web_tools_tavily.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Tests for Tavily web backend integration.
|
||||
|
||||
Coverage:
|
||||
_tavily_request() — API key handling, endpoint construction, error propagation.
|
||||
_normalize_tavily_search_results() — search response normalization.
|
||||
_normalize_tavily_documents() — extract/crawl response normalization, failed_results.
|
||||
web_search_tool / web_extract_tool / web_crawl_tool — Tavily dispatch paths.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
# ─── _tavily_request ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestTavilyRequest:
|
||||
"""Test suite for the _tavily_request helper."""
|
||||
|
||||
def test_raises_without_api_key(self):
|
||||
"""No TAVILY_API_KEY → ValueError with guidance."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("TAVILY_API_KEY", None)
|
||||
from tools.web_tools import _tavily_request
|
||||
with pytest.raises(ValueError, match="TAVILY_API_KEY"):
|
||||
_tavily_request("search", {"query": "test"})
|
||||
|
||||
def test_posts_with_api_key_in_body(self):
|
||||
"""api_key is injected into the JSON payload."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"results": []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test-key"}):
|
||||
with patch("tools.web_tools.httpx.post", return_value=mock_response) as mock_post:
|
||||
from tools.web_tools import _tavily_request
|
||||
result = _tavily_request("search", {"query": "hello"})
|
||||
|
||||
mock_post.assert_called_once()
|
||||
call_kwargs = mock_post.call_args
|
||||
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
|
||||
assert payload["api_key"] == "tvly-test-key"
|
||||
assert payload["query"] == "hello"
|
||||
assert "api.tavily.com/search" in call_kwargs.args[0]
|
||||
|
||||
def test_raises_on_http_error(self):
|
||||
"""Non-2xx responses propagate as httpx.HTTPStatusError."""
|
||||
import httpx as _httpx
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = _httpx.HTTPStatusError(
|
||||
"401 Unauthorized", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-bad-key"}):
|
||||
with patch("tools.web_tools.httpx.post", return_value=mock_response):
|
||||
from tools.web_tools import _tavily_request
|
||||
with pytest.raises(_httpx.HTTPStatusError):
|
||||
_tavily_request("search", {"query": "test"})
|
||||
|
||||
|
||||
# ─── _normalize_tavily_search_results ─────────────────────────────────────────
|
||||
|
||||
class TestNormalizeTavilySearchResults:
|
||||
"""Test search result normalization."""
|
||||
|
||||
def test_basic_normalization(self):
|
||||
from tools.web_tools import _normalize_tavily_search_results
|
||||
raw = {
|
||||
"results": [
|
||||
{"title": "Python Docs", "url": "https://docs.python.org", "content": "Official docs", "score": 0.9},
|
||||
{"title": "Tutorial", "url": "https://example.com", "content": "A tutorial", "score": 0.8},
|
||||
]
|
||||
}
|
||||
result = _normalize_tavily_search_results(raw)
|
||||
assert result["success"] is True
|
||||
web = result["data"]["web"]
|
||||
assert len(web) == 2
|
||||
assert web[0]["title"] == "Python Docs"
|
||||
assert web[0]["url"] == "https://docs.python.org"
|
||||
assert web[0]["description"] == "Official docs"
|
||||
assert web[0]["position"] == 1
|
||||
assert web[1]["position"] == 2
|
||||
|
||||
def test_empty_results(self):
|
||||
from tools.web_tools import _normalize_tavily_search_results
|
||||
result = _normalize_tavily_search_results({"results": []})
|
||||
assert result["success"] is True
|
||||
assert result["data"]["web"] == []
|
||||
|
||||
def test_missing_fields(self):
|
||||
from tools.web_tools import _normalize_tavily_search_results
|
||||
result = _normalize_tavily_search_results({"results": [{}]})
|
||||
web = result["data"]["web"]
|
||||
assert web[0]["title"] == ""
|
||||
assert web[0]["url"] == ""
|
||||
assert web[0]["description"] == ""
|
||||
|
||||
|
||||
# ─── _normalize_tavily_documents ──────────────────────────────────────────────
|
||||
|
||||
class TestNormalizeTavilyDocuments:
|
||||
"""Test extract/crawl document normalization."""
|
||||
|
||||
def test_basic_document(self):
|
||||
from tools.web_tools import _normalize_tavily_documents
|
||||
raw = {
|
||||
"results": [{
|
||||
"url": "https://example.com",
|
||||
"title": "Example",
|
||||
"raw_content": "Full page content here",
|
||||
}]
|
||||
}
|
||||
docs = _normalize_tavily_documents(raw)
|
||||
assert len(docs) == 1
|
||||
assert docs[0]["url"] == "https://example.com"
|
||||
assert docs[0]["title"] == "Example"
|
||||
assert docs[0]["content"] == "Full page content here"
|
||||
assert docs[0]["raw_content"] == "Full page content here"
|
||||
assert docs[0]["metadata"]["sourceURL"] == "https://example.com"
|
||||
|
||||
def test_falls_back_to_content_when_no_raw_content(self):
|
||||
from tools.web_tools import _normalize_tavily_documents
|
||||
raw = {"results": [{"url": "https://example.com", "content": "Snippet"}]}
|
||||
docs = _normalize_tavily_documents(raw)
|
||||
assert docs[0]["content"] == "Snippet"
|
||||
|
||||
def test_failed_results_included(self):
|
||||
from tools.web_tools import _normalize_tavily_documents
|
||||
raw = {
|
||||
"results": [],
|
||||
"failed_results": [
|
||||
{"url": "https://fail.com", "error": "timeout"},
|
||||
],
|
||||
}
|
||||
docs = _normalize_tavily_documents(raw)
|
||||
assert len(docs) == 1
|
||||
assert docs[0]["url"] == "https://fail.com"
|
||||
assert docs[0]["error"] == "timeout"
|
||||
assert docs[0]["content"] == ""
|
||||
|
||||
def test_failed_urls_included(self):
|
||||
from tools.web_tools import _normalize_tavily_documents
|
||||
raw = {
|
||||
"results": [],
|
||||
"failed_urls": ["https://bad.com"],
|
||||
}
|
||||
docs = _normalize_tavily_documents(raw)
|
||||
assert len(docs) == 1
|
||||
assert docs[0]["url"] == "https://bad.com"
|
||||
assert docs[0]["error"] == "extraction failed"
|
||||
|
||||
def test_fallback_url(self):
|
||||
from tools.web_tools import _normalize_tavily_documents
|
||||
raw = {"results": [{"content": "data"}]}
|
||||
docs = _normalize_tavily_documents(raw, fallback_url="https://fallback.com")
|
||||
assert docs[0]["url"] == "https://fallback.com"
|
||||
|
||||
|
||||
# ─── web_search_tool (Tavily dispatch) ────────────────────────────────────────
|
||||
|
||||
class TestWebSearchTavily:
|
||||
"""Test web_search_tool dispatch to Tavily."""
|
||||
|
||||
def test_search_dispatches_to_tavily(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"results": [{"title": "Result", "url": "https://r.com", "content": "desc", "score": 0.9}]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("tools.web_tools._get_backend", return_value="tavily"), \
|
||||
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
|
||||
patch("tools.web_tools.httpx.post", return_value=mock_response), \
|
||||
patch("tools.interrupt.is_interrupted", return_value=False):
|
||||
from tools.web_tools import web_search_tool
|
||||
result = json.loads(web_search_tool("test query", limit=3))
|
||||
assert result["success"] is True
|
||||
assert len(result["data"]["web"]) == 1
|
||||
assert result["data"]["web"][0]["title"] == "Result"
|
||||
|
||||
|
||||
# ─── web_extract_tool (Tavily dispatch) ───────────────────────────────────────
|
||||
|
||||
class TestWebExtractTavily:
|
||||
"""Test web_extract_tool dispatch to Tavily."""
|
||||
|
||||
def test_extract_dispatches_to_tavily(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"results": [{"url": "https://example.com", "raw_content": "Extracted content", "title": "Page"}]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("tools.web_tools._get_backend", return_value="tavily"), \
|
||||
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
|
||||
patch("tools.web_tools.httpx.post", return_value=mock_response), \
|
||||
patch("tools.web_tools.process_content_with_llm", return_value=None):
|
||||
from tools.web_tools import web_extract_tool
|
||||
result = json.loads(asyncio.get_event_loop().run_until_complete(
|
||||
web_extract_tool(["https://example.com"], use_llm_processing=False)
|
||||
))
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 1
|
||||
assert result["results"][0]["url"] == "https://example.com"
|
||||
|
||||
|
||||
# ─── web_crawl_tool (Tavily dispatch) ─────────────────────────────────────────
|
||||
|
||||
class TestWebCrawlTavily:
|
||||
"""Test web_crawl_tool dispatch to Tavily."""
|
||||
|
||||
def test_crawl_dispatches_to_tavily(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"results": [
|
||||
{"url": "https://example.com/page1", "raw_content": "Page 1 content", "title": "Page 1"},
|
||||
{"url": "https://example.com/page2", "raw_content": "Page 2 content", "title": "Page 2"},
|
||||
]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("tools.web_tools._get_backend", return_value="tavily"), \
|
||||
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
|
||||
patch("tools.web_tools.httpx.post", return_value=mock_response), \
|
||||
patch("tools.web_tools.check_website_access", return_value=None), \
|
||||
patch("tools.interrupt.is_interrupted", return_value=False):
|
||||
from tools.web_tools import web_crawl_tool
|
||||
result = json.loads(asyncio.get_event_loop().run_until_complete(
|
||||
web_crawl_tool("https://example.com", use_llm_processing=False)
|
||||
))
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 2
|
||||
assert result["results"][0]["title"] == "Page 1"
|
||||
|
||||
def test_crawl_sends_instructions(self):
|
||||
"""Instructions are included in the Tavily crawl payload."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"results": []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("tools.web_tools._get_backend", return_value="tavily"), \
|
||||
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
|
||||
patch("tools.web_tools.httpx.post", return_value=mock_response) as mock_post, \
|
||||
patch("tools.web_tools.check_website_access", return_value=None), \
|
||||
patch("tools.interrupt.is_interrupted", return_value=False):
|
||||
from tools.web_tools import web_crawl_tool
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
web_crawl_tool("https://example.com", instructions="Find docs", use_llm_processing=False)
|
||||
)
|
||||
call_kwargs = mock_post.call_args
|
||||
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
|
||||
assert payload["instructions"] == "Find docs"
|
||||
assert payload["url"] == "https://example.com"
|
||||
@@ -555,6 +555,11 @@ def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
|
||||
session_info = provider.create_session(task_id)
|
||||
|
||||
with _cleanup_lock:
|
||||
# Double-check: another thread may have created a session while we
|
||||
# were doing the network call. Use the existing one to avoid leaking
|
||||
# orphan cloud sessions.
|
||||
if task_id in _active_sessions:
|
||||
return _active_sessions[task_id]
|
||||
_active_sessions[task_id] = session_info
|
||||
|
||||
return session_info
|
||||
|
||||
@@ -23,11 +23,13 @@ Design:
|
||||
- Frozen snapshot pattern: system prompt is stable, tool responses show live state
|
||||
"""
|
||||
|
||||
import fcntl
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
@@ -120,14 +122,43 @@ class MemoryStore:
|
||||
"user": self._render_block("user", self.user_entries),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _file_lock(path: Path):
|
||||
"""Acquire an exclusive file lock for read-modify-write safety.
|
||||
|
||||
Uses a separate .lock file so the memory file itself can still be
|
||||
atomically replaced via os.replace().
|
||||
"""
|
||||
lock_path = path.with_suffix(path.suffix + ".lock")
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = open(lock_path, "w")
|
||||
try:
|
||||
fcntl.flock(fd, fcntl.LOCK_EX)
|
||||
yield
|
||||
finally:
|
||||
fcntl.flock(fd, fcntl.LOCK_UN)
|
||||
fd.close()
|
||||
|
||||
@staticmethod
|
||||
def _path_for(target: str) -> Path:
|
||||
if target == "user":
|
||||
return MEMORY_DIR / "USER.md"
|
||||
return MEMORY_DIR / "MEMORY.md"
|
||||
|
||||
def _reload_target(self, target: str):
|
||||
"""Re-read entries from disk into in-memory state.
|
||||
|
||||
Called under file lock to get the latest state before mutating.
|
||||
"""
|
||||
fresh = self._read_file(self._path_for(target))
|
||||
fresh = list(dict.fromkeys(fresh)) # deduplicate
|
||||
self._set_entries(target, fresh)
|
||||
|
||||
def save_to_disk(self, target: str):
|
||||
"""Persist entries to the appropriate file. Called after every mutation."""
|
||||
MEMORY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if target == "memory":
|
||||
self._write_file(MEMORY_DIR / "MEMORY.md", self.memory_entries)
|
||||
elif target == "user":
|
||||
self._write_file(MEMORY_DIR / "USER.md", self.user_entries)
|
||||
self._write_file(self._path_for(target), self._entries_for(target))
|
||||
|
||||
def _entries_for(self, target: str) -> List[str]:
|
||||
if target == "user":
|
||||
@@ -162,33 +193,37 @@ class MemoryStore:
|
||||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
entries = self._entries_for(target)
|
||||
limit = self._char_limit(target)
|
||||
with self._file_lock(self._path_for(target)):
|
||||
# Re-read from disk under lock to pick up writes from other sessions
|
||||
self._reload_target(target)
|
||||
|
||||
# Reject exact duplicates
|
||||
if content in entries:
|
||||
return self._success_response(target, "Entry already exists (no duplicate added).")
|
||||
entries = self._entries_for(target)
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Calculate what the new total would be
|
||||
new_entries = entries + [content]
|
||||
new_total = len(ENTRY_DELIMITER.join(new_entries))
|
||||
# Reject exact duplicates
|
||||
if content in entries:
|
||||
return self._success_response(target, "Entry already exists (no duplicate added).")
|
||||
|
||||
if new_total > limit:
|
||||
current = self._char_count(target)
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Memory at {current:,}/{limit:,} chars. "
|
||||
f"Adding this entry ({len(content)} chars) would exceed the limit. "
|
||||
f"Replace or remove existing entries first."
|
||||
),
|
||||
"current_entries": entries,
|
||||
"usage": f"{current:,}/{limit:,}",
|
||||
}
|
||||
# Calculate what the new total would be
|
||||
new_entries = entries + [content]
|
||||
new_total = len(ENTRY_DELIMITER.join(new_entries))
|
||||
|
||||
entries.append(content)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
if new_total > limit:
|
||||
current = self._char_count(target)
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Memory at {current:,}/{limit:,} chars. "
|
||||
f"Adding this entry ({len(content)} chars) would exceed the limit. "
|
||||
f"Replace or remove existing entries first."
|
||||
),
|
||||
"current_entries": entries,
|
||||
"usage": f"{current:,}/{limit:,}",
|
||||
}
|
||||
|
||||
entries.append(content)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry added.")
|
||||
|
||||
@@ -206,44 +241,47 @@ class MemoryStore:
|
||||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), operate on the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), operate on the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to replace just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Check that replacement doesn't blow the budget
|
||||
test_entries = entries.copy()
|
||||
test_entries[idx] = new_content
|
||||
new_total = len(ENTRY_DELIMITER.join(test_entries))
|
||||
|
||||
if new_total > limit:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
"error": (
|
||||
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
|
||||
f"Shorten the new content or remove other entries first."
|
||||
),
|
||||
}
|
||||
# All identical -- safe to replace just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Check that replacement doesn't blow the budget
|
||||
test_entries = entries.copy()
|
||||
test_entries[idx] = new_content
|
||||
new_total = len(ENTRY_DELIMITER.join(test_entries))
|
||||
|
||||
if new_total > limit:
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
|
||||
f"Shorten the new content or remove other entries first."
|
||||
),
|
||||
}
|
||||
|
||||
entries[idx] = new_content
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
entries[idx] = new_content
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry replaced.")
|
||||
|
||||
@@ -253,28 +291,31 @@ class MemoryStore:
|
||||
if not old_text:
|
||||
return {"success": False, "error": "old_text cannot be empty."}
|
||||
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), remove the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to remove just the first
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
idx = matches[0][0]
|
||||
entries.pop(idx)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), remove the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to remove just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
entries.pop(idx)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry removed.")
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ import os
|
||||
import re
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
import httpx
|
||||
from firecrawl import Firecrawl
|
||||
from agent.auxiliary_client import async_call_llm
|
||||
from tools.debug_helpers import DebugSession
|
||||
@@ -73,11 +74,14 @@ def _get_backend() -> str:
|
||||
keys manually without running setup.
|
||||
"""
|
||||
configured = _load_web_config().get("backend", "").lower().strip()
|
||||
if configured in ("parallel", "firecrawl"):
|
||||
if configured in ("parallel", "firecrawl", "tavily"):
|
||||
return configured
|
||||
# Fallback for manual / legacy config — use whichever key is present.
|
||||
has_firecrawl = bool(os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL"))
|
||||
has_parallel = bool(os.getenv("PARALLEL_API_KEY"))
|
||||
has_tavily = bool(os.getenv("TAVILY_API_KEY"))
|
||||
if has_tavily and not has_firecrawl and not has_parallel:
|
||||
return "tavily"
|
||||
if has_parallel and not has_firecrawl:
|
||||
return "parallel"
|
||||
# Default to firecrawl (backward compat, or when both are set)
|
||||
@@ -155,6 +159,88 @@ def _get_async_parallel_client():
|
||||
_async_parallel_client = AsyncParallel(api_key=api_key)
|
||||
return _async_parallel_client
|
||||
|
||||
# ─── Tavily Client ───────────────────────────────────────────────────────────
|
||||
|
||||
_TAVILY_BASE_URL = "https://api.tavily.com"
|
||||
|
||||
|
||||
def _tavily_request(endpoint: str, payload: dict) -> dict:
|
||||
"""Send a POST request to the Tavily API.
|
||||
|
||||
Auth is provided via ``api_key`` in the JSON body (no header-based auth).
|
||||
Raises ``ValueError`` if ``TAVILY_API_KEY`` is not set.
|
||||
"""
|
||||
api_key = os.getenv("TAVILY_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"TAVILY_API_KEY environment variable not set. "
|
||||
"Get your API key at https://app.tavily.com/home"
|
||||
)
|
||||
payload["api_key"] = api_key
|
||||
url = f"{_TAVILY_BASE_URL}/{endpoint.lstrip('/')}"
|
||||
logger.info("Tavily %s request to %s", endpoint, url)
|
||||
response = httpx.post(url, json=payload, timeout=60)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _normalize_tavily_search_results(response: dict) -> dict:
|
||||
"""Normalize Tavily /search response to the standard web search format.
|
||||
|
||||
Tavily returns ``{results: [{title, url, content, score, ...}]}``.
|
||||
We map to ``{success, data: {web: [{title, url, description, position}]}}``.
|
||||
"""
|
||||
web_results = []
|
||||
for i, result in enumerate(response.get("results", [])):
|
||||
web_results.append({
|
||||
"title": result.get("title", ""),
|
||||
"url": result.get("url", ""),
|
||||
"description": result.get("content", ""),
|
||||
"position": i + 1,
|
||||
})
|
||||
return {"success": True, "data": {"web": web_results}}
|
||||
|
||||
|
||||
def _normalize_tavily_documents(response: dict, fallback_url: str = "") -> List[Dict[str, Any]]:
|
||||
"""Normalize Tavily /extract or /crawl response to the standard document format.
|
||||
|
||||
Maps results to ``{url, title, content, raw_content, metadata}`` and
|
||||
includes any ``failed_results`` / ``failed_urls`` as error entries.
|
||||
"""
|
||||
documents: List[Dict[str, Any]] = []
|
||||
for result in response.get("results", []):
|
||||
url = result.get("url", fallback_url)
|
||||
raw = result.get("raw_content", "") or result.get("content", "")
|
||||
documents.append({
|
||||
"url": url,
|
||||
"title": result.get("title", ""),
|
||||
"content": raw,
|
||||
"raw_content": raw,
|
||||
"metadata": {"sourceURL": url, "title": result.get("title", "")},
|
||||
})
|
||||
# Handle failed results
|
||||
for fail in response.get("failed_results", []):
|
||||
documents.append({
|
||||
"url": fail.get("url", fallback_url),
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": fail.get("error", "extraction failed"),
|
||||
"metadata": {"sourceURL": fail.get("url", fallback_url)},
|
||||
})
|
||||
for fail_url in response.get("failed_urls", []):
|
||||
url_str = fail_url if isinstance(fail_url, str) else str(fail_url)
|
||||
documents.append({
|
||||
"url": url_str,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": "extraction failed",
|
||||
"metadata": {"sourceURL": url_str},
|
||||
})
|
||||
return documents
|
||||
|
||||
|
||||
DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000
|
||||
|
||||
# Allow per-task override via env var
|
||||
@@ -639,6 +725,22 @@ def web_search_tool(query: str, limit: int = 5) -> str:
|
||||
_debug.save()
|
||||
return result_json
|
||||
|
||||
if backend == "tavily":
|
||||
logger.info("Tavily search: '%s' (limit: %d)", query, limit)
|
||||
raw = _tavily_request("search", {
|
||||
"query": query,
|
||||
"max_results": min(limit, 20),
|
||||
"include_raw_content": False,
|
||||
"include_images": False,
|
||||
})
|
||||
response_data = _normalize_tavily_search_results(raw)
|
||||
debug_call_data["results_count"] = len(response_data.get("data", {}).get("web", []))
|
||||
result_json = json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
debug_call_data["final_response_size"] = len(result_json)
|
||||
_debug.log_call("web_search_tool", debug_call_data)
|
||||
_debug.save()
|
||||
return result_json
|
||||
|
||||
logger.info("Searching the web for: '%s' (limit: %d)", query, limit)
|
||||
|
||||
response = _get_firecrawl_client().search(
|
||||
@@ -763,6 +865,13 @@ async def web_extract_tool(
|
||||
|
||||
if backend == "parallel":
|
||||
results = await _parallel_extract(urls)
|
||||
elif backend == "tavily":
|
||||
logger.info("Tavily extract: %d URL(s)", len(urls))
|
||||
raw = _tavily_request("extract", {
|
||||
"urls": urls,
|
||||
"include_images": False,
|
||||
})
|
||||
results = _normalize_tavily_documents(raw, fallback_url=urls[0] if urls else "")
|
||||
else:
|
||||
# ── Firecrawl extraction ──
|
||||
# Determine requested formats for Firecrawl v2
|
||||
@@ -1055,6 +1164,83 @@ async def web_crawl_tool(
|
||||
}
|
||||
|
||||
try:
|
||||
backend = _get_backend()
|
||||
|
||||
# Tavily supports crawl via its /crawl endpoint
|
||||
if backend == "tavily":
|
||||
# Ensure URL has protocol
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
url = f'https://{url}'
|
||||
|
||||
# Website policy check
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
logger.info("Blocked web_crawl for %s by rule %s", blocked["host"], blocked["rule"])
|
||||
return json.dumps({"results": [{"url": url, "title": "", "content": "", "error": blocked["message"],
|
||||
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}}]}, ensure_ascii=False)
|
||||
|
||||
from tools.interrupt import is_interrupted as _is_int
|
||||
if _is_int():
|
||||
return json.dumps({"error": "Interrupted", "success": False})
|
||||
|
||||
logger.info("Tavily crawl: %s", url)
|
||||
payload: Dict[str, Any] = {
|
||||
"url": url,
|
||||
"limit": 20,
|
||||
"extract_depth": depth,
|
||||
}
|
||||
if instructions:
|
||||
payload["instructions"] = instructions
|
||||
raw = _tavily_request("crawl", payload)
|
||||
results = _normalize_tavily_documents(raw, fallback_url=url)
|
||||
|
||||
response = {"results": results}
|
||||
# Fall through to the shared LLM processing and trimming below
|
||||
# (skip the Firecrawl-specific crawl logic)
|
||||
pages_crawled = len(response.get('results', []))
|
||||
logger.info("Crawled %d pages", pages_crawled)
|
||||
debug_call_data["pages_crawled"] = pages_crawled
|
||||
debug_call_data["original_response_size"] = len(json.dumps(response))
|
||||
|
||||
# Process each result with LLM if enabled
|
||||
if use_llm_processing:
|
||||
logger.info("Processing crawled content with LLM (parallel)...")
|
||||
debug_call_data["processing_applied"].append("llm_processing")
|
||||
|
||||
async def _process_tavily_crawl(result):
|
||||
page_url = result.get('url', 'Unknown URL')
|
||||
title = result.get('title', '')
|
||||
content = result.get('content', '')
|
||||
if not content:
|
||||
return result, None, "no_content"
|
||||
original_size = len(content)
|
||||
processed = await process_content_with_llm(content, page_url, title, model, min_length)
|
||||
if processed:
|
||||
result['raw_content'] = content
|
||||
result['content'] = processed
|
||||
metrics = {"url": page_url, "original_size": original_size, "processed_size": len(processed),
|
||||
"compression_ratio": len(processed) / original_size if original_size else 1.0, "model_used": model}
|
||||
return result, metrics, "processed"
|
||||
metrics = {"url": page_url, "original_size": original_size, "processed_size": original_size,
|
||||
"compression_ratio": 1.0, "model_used": None, "reason": "content_too_short"}
|
||||
return result, metrics, "too_short"
|
||||
|
||||
tasks = [_process_tavily_crawl(r) for r in response.get('results', [])]
|
||||
processed_results = await asyncio.gather(*tasks)
|
||||
for result, metrics, status in processed_results:
|
||||
if status == "processed":
|
||||
debug_call_data["compression_metrics"].append(metrics)
|
||||
debug_call_data["pages_processed_with_llm"] += 1
|
||||
|
||||
trimmed_results = [{"url": r.get("url", ""), "title": r.get("title", ""), "content": r.get("content", ""), "error": r.get("error"),
|
||||
**({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {})} for r in response.get("results", [])]
|
||||
result_json = json.dumps({"results": trimmed_results}, indent=2, ensure_ascii=False)
|
||||
cleaned_result = clean_base64_images(result_json)
|
||||
debug_call_data["final_response_size"] = len(cleaned_result)
|
||||
_debug.log_call("web_crawl_tool", debug_call_data)
|
||||
_debug.save()
|
||||
return cleaned_result
|
||||
|
||||
# web_crawl requires Firecrawl — Parallel has no crawl API
|
||||
if not (os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL")):
|
||||
return json.dumps({
|
||||
@@ -1335,11 +1521,12 @@ def check_firecrawl_api_key() -> bool:
|
||||
|
||||
|
||||
def check_web_api_key() -> bool:
|
||||
"""Check if any web backend API key is available (Parallel or Firecrawl)."""
|
||||
"""Check if any web backend API key is available (Parallel, Firecrawl, or Tavily)."""
|
||||
return bool(
|
||||
os.getenv("PARALLEL_API_KEY")
|
||||
or os.getenv("FIRECRAWL_API_KEY")
|
||||
or os.getenv("FIRECRAWL_API_URL")
|
||||
or os.getenv("TAVILY_API_KEY")
|
||||
)
|
||||
|
||||
|
||||
@@ -1377,11 +1564,13 @@ if __name__ == "__main__":
|
||||
print(f"✅ Web backend: {backend}")
|
||||
if backend == "parallel":
|
||||
print(" Using Parallel API (https://parallel.ai)")
|
||||
elif backend == "tavily":
|
||||
print(" Using Tavily API (https://tavily.com)")
|
||||
else:
|
||||
print(" Using Firecrawl API (https://firecrawl.dev)")
|
||||
else:
|
||||
print("❌ No web search backend configured")
|
||||
print("Set PARALLEL_API_KEY (https://parallel.ai) or FIRECRAWL_API_KEY (https://firecrawl.dev)")
|
||||
print("Set PARALLEL_API_KEY, TAVILY_API_KEY, or FIRECRAWL_API_KEY")
|
||||
|
||||
if not nous_available:
|
||||
print("❌ No auxiliary model available for LLM content processing")
|
||||
@@ -1491,7 +1680,7 @@ registry.register(
|
||||
schema=WEB_SEARCH_SCHEMA,
|
||||
handler=lambda args, **kw: web_search_tool(args.get("query", ""), limit=5),
|
||||
check_fn=check_web_api_key,
|
||||
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY"],
|
||||
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "TAVILY_API_KEY"],
|
||||
emoji="🔍",
|
||||
)
|
||||
registry.register(
|
||||
@@ -1501,7 +1690,7 @@ registry.register(
|
||||
handler=lambda args, **kw: web_extract_tool(
|
||||
args.get("urls", [])[:5] if isinstance(args.get("urls"), list) else [], "markdown"),
|
||||
check_fn=check_web_api_key,
|
||||
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY"],
|
||||
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "TAVILY_API_KEY"],
|
||||
is_async=True,
|
||||
emoji="📄",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user