Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 415043315f | |||
| 98eb32f39a | |||
| 2df306e6cd | |||
| 79a5f03f92 | |||
| 527ca7d238 | |||
| b11e53e34f | |||
| 1e7a598bac | |||
| 3eddabf53b | |||
| 971542d254 | |||
| 4a95029e6c | |||
| 432614591a |
@@ -357,7 +357,7 @@ def _common_betas_for_base_url(base_url: str | None) -> list[str]:
|
||||
return _COMMON_BETAS
|
||||
|
||||
|
||||
def build_anthropic_client(api_key: str, base_url: str = None, timeout: float = None):
|
||||
def build_anthropic_client(api_key: str, base_url: str = None, timeout: Optional[float] = None):
|
||||
"""Create an Anthropic client, auto-detecting setup-tokens vs API keys.
|
||||
|
||||
If *timeout* is provided it overrides the default 900s read timeout. The
|
||||
|
||||
@@ -41,10 +41,13 @@ import threading
|
||||
import time
|
||||
from pathlib import Path # noqa: F401 — used by test mocks
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent.gemini_native_adapter import GeminiNativeClient
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
from hermes_cli.config import get_hermes_home
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
@@ -810,7 +813,11 @@ def _read_codex_access_token() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
# TODO(refactor): This function has messy types and duplicated logic (pool vs direct creds).
|
||||
# Ideal fix: (1) define an AuxiliaryClient Protocol both OpenAI/GeminiNativeClient satisfy,
|
||||
# (2) return a NamedTuple or dataclass instead of raw tuple, (3) extract the repeated
|
||||
# Gemini/Kimi/Copilot client-building into a helper.
|
||||
def _resolve_api_key_provider() -> Tuple[Optional[Union[OpenAI, "GeminiNativeClient"]], Optional[str]]:
|
||||
"""Try each API-key provider in PROVIDER_REGISTRY order.
|
||||
|
||||
Returns (client, model) for the first provider with usable runtime
|
||||
|
||||
@@ -29,6 +29,7 @@ from hermes_cli.auth import (
|
||||
_save_auth_store,
|
||||
_save_provider_state,
|
||||
read_credential_pool,
|
||||
read_provider_credentials,
|
||||
write_credential_pool,
|
||||
)
|
||||
|
||||
@@ -321,7 +322,7 @@ def get_custom_provider_pool_key(base_url: str) -> Optional[str]:
|
||||
|
||||
def list_custom_pool_providers() -> List[str]:
|
||||
"""Return all 'custom:*' pool keys that have entries in auth.json."""
|
||||
pool_data = read_credential_pool(None)
|
||||
pool_data = read_credential_pool()
|
||||
return sorted(
|
||||
key for key in pool_data
|
||||
if key.startswith(CUSTOM_POOL_PREFIX)
|
||||
@@ -875,6 +876,20 @@ class CredentialPool:
|
||||
self._current_id = None
|
||||
return removed
|
||||
|
||||
def remove_entry(self, entry_id: str) -> Optional[PooledCredential]:
|
||||
for idx, entry in enumerate(self._entries):
|
||||
if entry.id == entry_id:
|
||||
removed = self._entries.pop(idx)
|
||||
self._entries = [
|
||||
replace(e, priority=new_priority)
|
||||
for new_priority, e in enumerate(self._entries)
|
||||
]
|
||||
self._persist()
|
||||
if self._current_id == removed.id:
|
||||
self._current_id = None
|
||||
return removed
|
||||
return None
|
||||
|
||||
def resolve_target(self, target: Any) -> Tuple[Optional[int], Optional[PooledCredential], Optional[str]]:
|
||||
raw = str(target or "").strip()
|
||||
if not raw:
|
||||
@@ -1325,7 +1340,7 @@ def _seed_custom_pool(pool_key: str, entries: List[PooledCredential]) -> Tuple[b
|
||||
|
||||
def load_pool(provider: str) -> CredentialPool:
|
||||
provider = (provider or "").strip().lower()
|
||||
raw_entries = read_credential_pool(provider)
|
||||
raw_entries = read_provider_credentials(provider)
|
||||
entries = [PooledCredential.from_dict(provider, payload) for payload in raw_entries]
|
||||
|
||||
if provider.startswith(CUSTOM_POOL_PREFIX):
|
||||
|
||||
@@ -729,6 +729,7 @@ class KawaiiSpinner:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
frame = self.spinner_frames[self.frame_idx % len(self.spinner_frames)]
|
||||
assert self.start_time is not None # start() sets it before thread starts
|
||||
elapsed = time.time() - self.start_time
|
||||
if wings:
|
||||
left, right = wings[self.frame_idx % len(wings)]
|
||||
|
||||
@@ -455,7 +455,8 @@ def parse_qualified_name(name: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
if ":" not in name:
|
||||
return None, name
|
||||
return tuple(name.split(":", 1)) # type: ignore[return-value]
|
||||
ns, bare = name.split(":", 1)
|
||||
return ns, bare
|
||||
|
||||
|
||||
def is_valid_namespace(candidate: Optional[str]) -> bool:
|
||||
|
||||
@@ -30,7 +30,7 @@ from urllib.parse import unquote, urlparse
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Dict, Any, Optional, TypedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -84,6 +84,34 @@ _project_env = Path(__file__).parent / '.env'
|
||||
load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
|
||||
|
||||
class _ModelPickerState(TypedDict, total=False):
|
||||
stage: str
|
||||
providers: List[Dict[str, Any]]
|
||||
selected: int
|
||||
current_model: str
|
||||
current_provider: str
|
||||
user_provs: Optional[Dict[str, Any]]
|
||||
custom_provs: Optional[Dict[str, Any]]
|
||||
provider_data: Dict[str, Any]
|
||||
model_list: List[str]
|
||||
|
||||
|
||||
class _ApprovalState(TypedDict, total=False):
|
||||
command: str
|
||||
description: str
|
||||
choices: List[str]
|
||||
selected: int
|
||||
response_queue: "queue.Queue[str]"
|
||||
show_full: bool
|
||||
|
||||
|
||||
class _ClarifyState(TypedDict, total=False):
|
||||
question: str
|
||||
choices: List[str]
|
||||
selected: int
|
||||
response_queue: "queue.Queue[str]"
|
||||
|
||||
|
||||
_REASONING_TAGS = (
|
||||
"REASONING_SCRATCHPAD",
|
||||
"think",
|
||||
@@ -1728,7 +1756,7 @@ def _parse_skills_argument(skills: str | list[str] | tuple[str, ...] | None) ->
|
||||
return parsed
|
||||
|
||||
|
||||
def save_config_value(key_path: str, value: any) -> bool:
|
||||
def save_config_value(key_path: str, value: Any) -> bool:
|
||||
"""
|
||||
Save a value to the active config file at the specified key path.
|
||||
|
||||
@@ -2065,16 +2093,16 @@ class HermesCLI:
|
||||
self._interrupt_queue = queue.Queue()
|
||||
self._should_exit = False
|
||||
self._last_ctrl_c_time = 0
|
||||
self._clarify_state = None
|
||||
self._clarify_state: Optional[_ClarifyState] = None
|
||||
self._clarify_freetext = False
|
||||
self._clarify_deadline = 0
|
||||
self._sudo_state = None
|
||||
self._sudo_deadline = 0
|
||||
self._modal_input_snapshot = None
|
||||
self._approval_state = None
|
||||
self._approval_state: Optional[_ApprovalState] = None
|
||||
self._approval_deadline = 0
|
||||
self._approval_lock = threading.Lock()
|
||||
self._model_picker_state = None
|
||||
self._model_picker_state: Optional[_ModelPickerState] = None
|
||||
self._secret_state = None
|
||||
self._secret_deadline = 0
|
||||
self._spinner_text: str = "" # thinking spinner text for TUI
|
||||
@@ -7156,7 +7184,7 @@ class HermesCLI:
|
||||
logging.getLogger(noisy).setLevel(logging.WARNING)
|
||||
else:
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
for quiet_logger in ('tools', 'run_agent', 'trajectory_compressor', 'cron', 'hermes_cli'):
|
||||
for quiet_logger in ('tools', 'run_agent', 'scripts.trajectory_compressor', 'cron', 'hermes_cli'):
|
||||
logging.getLogger(quiet_logger).setLevel(logging.ERROR)
|
||||
|
||||
def _show_insights(self, command: str = "/insights"):
|
||||
|
||||
+3
-2
@@ -439,8 +439,9 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
if result and result.get("error"):
|
||||
msg = f"delivery error: {result['error']}"
|
||||
error = result.get("error") if result else None
|
||||
if error:
|
||||
msg = f"delivery error: {error}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
delivery_errors.append(msg)
|
||||
continue
|
||||
|
||||
@@ -29,7 +29,7 @@ echo "📝 Logging to: $LOG_FILE"
|
||||
# Point to the example dataset in this directory
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
|
||||
python batch_runner.py \
|
||||
python scripts/batch_runner.py \
|
||||
--dataset_file="$SCRIPT_DIR/example_browser_tasks.jsonl" \
|
||||
--batch_size=5 \
|
||||
--run_name="browser_tasks_example" \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# Generates tool-calling trajectories for multi-step web research tasks.
|
||||
#
|
||||
# Usage:
|
||||
# python batch_runner.py \
|
||||
# python scripts/batch_runner.py \
|
||||
# --config datagen-config-examples/web_research.yaml \
|
||||
# --run_name web_research_v1
|
||||
|
||||
|
||||
@@ -18,7 +18,10 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tools.budget_config import BudgetConfig
|
||||
|
||||
from model_tools import handle_function_call
|
||||
from tools.terminal_tool import get_active_env
|
||||
|
||||
@@ -32,14 +32,7 @@ import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from aiohttp import web
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
web = None # type: ignore[assignment]
|
||||
|
||||
from aiohttp import web
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
@@ -270,12 +263,6 @@ def _multimodal_validation_error(exc: ValueError, *, param: str) -> "web.Respons
|
||||
status=400,
|
||||
)
|
||||
|
||||
|
||||
def check_api_server_requirements() -> bool:
|
||||
"""Check if API server dependencies are available."""
|
||||
return AIOHTTP_AVAILABLE
|
||||
|
||||
|
||||
class ResponseStore:
|
||||
"""
|
||||
SQLite-backed LRU store for Responses API state.
|
||||
@@ -391,30 +378,26 @@ _CORS_HEADERS = {
|
||||
}
|
||||
|
||||
|
||||
if AIOHTTP_AVAILABLE:
|
||||
@web.middleware
|
||||
async def cors_middleware(request, handler):
|
||||
"""Add CORS headers for explicitly allowed origins; handle OPTIONS preflight."""
|
||||
adapter = request.app.get("api_server_adapter")
|
||||
origin = request.headers.get("Origin", "")
|
||||
cors_headers = None
|
||||
if adapter is not None:
|
||||
if not adapter._origin_allowed(origin):
|
||||
return web.Response(status=403)
|
||||
cors_headers = adapter._cors_headers_for_origin(origin)
|
||||
@web.middleware
|
||||
async def cors_middleware(request, handler):
|
||||
"""Add CORS headers for explicitly allowed origins; handle OPTIONS preflight."""
|
||||
adapter = request.app.get("api_server_adapter")
|
||||
origin = request.headers.get("Origin", "")
|
||||
cors_headers = None
|
||||
if adapter is not None:
|
||||
if not adapter._origin_allowed(origin):
|
||||
return web.Response(status=403)
|
||||
cors_headers = adapter._cors_headers_for_origin(origin)
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
if cors_headers is None:
|
||||
return web.Response(status=403)
|
||||
return web.Response(status=200, headers=cors_headers)
|
||||
|
||||
response = await handler(request)
|
||||
if cors_headers is not None:
|
||||
response.headers.update(cors_headers)
|
||||
return response
|
||||
else:
|
||||
cors_middleware = None # type: ignore[assignment]
|
||||
if request.method == "OPTIONS":
|
||||
if cors_headers is None:
|
||||
return web.Response(status=403)
|
||||
return web.Response(status=200, headers=cors_headers)
|
||||
|
||||
response = await handler(request)
|
||||
if cors_headers is not None:
|
||||
response.headers.update(cors_headers)
|
||||
return response
|
||||
|
||||
def _openai_error(message: str, err_type: str = "invalid_request_error", param: str = None, code: str = None) -> Dict[str, Any]:
|
||||
"""OpenAI-style error envelope."""
|
||||
@@ -428,21 +411,18 @@ def _openai_error(message: str, err_type: str = "invalid_request_error", param:
|
||||
}
|
||||
|
||||
|
||||
if AIOHTTP_AVAILABLE:
|
||||
@web.middleware
|
||||
async def body_limit_middleware(request, handler):
|
||||
"""Reject overly large request bodies early based on Content-Length."""
|
||||
if request.method in ("POST", "PUT", "PATCH"):
|
||||
cl = request.headers.get("Content-Length")
|
||||
if cl is not None:
|
||||
try:
|
||||
if int(cl) > MAX_REQUEST_BYTES:
|
||||
return web.json_response(_openai_error("Request body too large.", code="body_too_large"), status=413)
|
||||
except ValueError:
|
||||
return web.json_response(_openai_error("Invalid Content-Length header.", code="invalid_content_length"), status=400)
|
||||
return await handler(request)
|
||||
else:
|
||||
body_limit_middleware = None # type: ignore[assignment]
|
||||
@web.middleware
|
||||
async def body_limit_middleware(request, handler):
|
||||
"""Reject overly large request bodies early based on Content-Length."""
|
||||
if request.method in ("POST", "PUT", "PATCH"):
|
||||
cl = request.headers.get("Content-Length")
|
||||
if cl is not None:
|
||||
try:
|
||||
if int(cl) > MAX_REQUEST_BYTES:
|
||||
return web.json_response(_openai_error("Request body too large.", code="body_too_large"), status=413)
|
||||
except ValueError:
|
||||
return web.json_response(_openai_error("Invalid Content-Length header.", code="invalid_content_length"), status=400)
|
||||
return await handler(request)
|
||||
|
||||
_SECURITY_HEADERS = {
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
@@ -450,16 +430,13 @@ _SECURITY_HEADERS = {
|
||||
}
|
||||
|
||||
|
||||
if AIOHTTP_AVAILABLE:
|
||||
@web.middleware
|
||||
async def security_headers_middleware(request, handler):
|
||||
"""Add security headers to all responses (including errors)."""
|
||||
response = await handler(request)
|
||||
for k, v in _SECURITY_HEADERS.items():
|
||||
response.headers.setdefault(k, v)
|
||||
return response
|
||||
else:
|
||||
security_headers_middleware = None # type: ignore[assignment]
|
||||
@web.middleware
|
||||
async def security_headers_middleware(request, handler):
|
||||
"""Add security headers to all responses (including errors)."""
|
||||
response = await handler(request)
|
||||
for k, v in _SECURITY_HEADERS.items():
|
||||
response.headers.setdefault(k, v)
|
||||
return response
|
||||
|
||||
|
||||
class _IdempotencyCache:
|
||||
@@ -804,7 +781,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
],
|
||||
})
|
||||
|
||||
async def _handle_chat_completions(self, request: "web.Request") -> "web.Response":
|
||||
async def _handle_chat_completions(self, request: "web.Request") -> "web.StreamResponse":
|
||||
"""POST /v1/chat/completions — OpenAI Chat Completions format."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
@@ -1588,7 +1565,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
return response
|
||||
|
||||
async def _handle_responses(self, request: "web.Request") -> "web.Response":
|
||||
async def _handle_responses(self, request: "web.Request") -> "web.StreamResponse":
|
||||
"""POST /v1/responses — OpenAI Responses API format."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
@@ -2482,10 +2459,6 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Start the aiohttp web server."""
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
logger.warning("[%s] aiohttp not installed", self.name)
|
||||
return False
|
||||
|
||||
try:
|
||||
mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware) if mw is not None]
|
||||
self._app = web.Application(middlewares=mws)
|
||||
|
||||
+21
-20
@@ -187,16 +187,14 @@ def proxy_kwargs_for_bot(proxy_url: str | None) -> dict:
|
||||
if proxy_url.lower().startswith("socks"):
|
||||
try:
|
||||
from aiohttp_socks import ProxyConnector
|
||||
|
||||
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
||||
return {"connector": connector}
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
|
||||
"Run: pip install aiohttp-socks",
|
||||
proxy_url,
|
||||
)
|
||||
return {}
|
||||
raise ImportError(
|
||||
"aiohttp-socks is required for SOCKS proxy support. "
|
||||
"Install with: pip install hermes-agent[messaging]"
|
||||
) from None
|
||||
|
||||
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
||||
return {"connector": connector}
|
||||
return {"proxy": proxy_url}
|
||||
|
||||
|
||||
@@ -220,16 +218,14 @@ def proxy_kwargs_for_aiohttp(proxy_url: str | None) -> tuple[dict, dict]:
|
||||
if proxy_url.lower().startswith("socks"):
|
||||
try:
|
||||
from aiohttp_socks import ProxyConnector
|
||||
|
||||
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
||||
return {"connector": connector}, {}
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
|
||||
"Run: pip install aiohttp-socks",
|
||||
proxy_url,
|
||||
)
|
||||
return {}, {}
|
||||
raise ImportError(
|
||||
"aiohttp-socks is required for SOCKS proxy support. "
|
||||
"Install with: pip install hermes-agent[messaging]"
|
||||
) from None
|
||||
|
||||
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
||||
return {"connector": connector}, {}
|
||||
return {}, {"proxy": proxy_url}
|
||||
|
||||
|
||||
@@ -428,6 +424,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
|
||||
def cleanup_image_cache(max_age_hours: int = 24) -> int:
|
||||
@@ -542,6 +539,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1831,8 +1829,11 @@ class BasePlatformAdapter(ABC):
|
||||
try:
|
||||
await self._run_processing_hook("on_processing_start", event)
|
||||
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
handler = self._message_handler
|
||||
if handler is None:
|
||||
return
|
||||
|
||||
response = await handler(event)
|
||||
|
||||
# Send response if any. A None/empty response is normal when
|
||||
# streaming already delivered the text (already_sent=True) or
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -377,7 +377,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
||||
payload = {
|
||||
"addresses": [address],
|
||||
"message": message,
|
||||
"tempGuid": f"temp-{datetime.utcnow().timestamp()}",
|
||||
"tempGuid": f"temp-{datetime.now(timezone.utc).timestamp()}",
|
||||
}
|
||||
try:
|
||||
res = await self._api_post("/api/v1/chat/new", payload)
|
||||
@@ -417,7 +417,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
||||
)
|
||||
payload: Dict[str, Any] = {
|
||||
"chatGuid": guid,
|
||||
"tempGuid": f"temp-{datetime.utcnow().timestamp()}",
|
||||
"tempGuid": f"temp-{datetime.now(timezone.utc).timestamp()}",
|
||||
"message": chunk,
|
||||
}
|
||||
if reply_to and self._private_api_enabled and self._helper_connected:
|
||||
|
||||
@@ -1196,9 +1196,16 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
import base64
|
||||
|
||||
duration_secs = 5.0
|
||||
try:
|
||||
from mutagen.oggopus import OggOpus
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"mutagen is required for Discord voice messages. "
|
||||
"Install with: pip install hermes-agent[messaging]"
|
||||
) from None
|
||||
|
||||
duration_secs = 5.0
|
||||
try:
|
||||
info = OggOpus(audio_path)
|
||||
duration_secs = info.info.length
|
||||
except Exception:
|
||||
@@ -1891,7 +1898,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Fetch full member list (requires members intent)
|
||||
try:
|
||||
members = guild.members
|
||||
if len(members) < guild.member_count:
|
||||
if guild.member_count is not None and len(members) < guild.member_count:
|
||||
members = [m async for m in guild.fetch_members(limit=None)]
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch members for guild %s: %s", guild.name, e)
|
||||
@@ -2504,7 +2511,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if isinstance(skills, str):
|
||||
return [skills]
|
||||
if isinstance(skills, list) and skills:
|
||||
return list(dict.fromkeys(skills)) # dedup, preserve order
|
||||
return list(dict.fromkeys(skills)) # ty: ignore[invalid-return-type] # dedup, preserve order
|
||||
return None
|
||||
|
||||
def _resolve_channel_prompt(self, channel_id: str, parent_id: str | None = None) -> str | None:
|
||||
@@ -3040,7 +3047,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Skip the mention check if the message is in a thread where
|
||||
# the bot has previously participated (auto-created or replied in).
|
||||
in_bot_thread = is_thread and thread_id in self._threads
|
||||
in_bot_thread = is_thread and thread_id is not None and thread_id in self._threads
|
||||
|
||||
if require_mention and not is_free_channel and not in_bot_thread:
|
||||
if self._client.user not in message.mentions and not mention_prefix:
|
||||
@@ -3633,7 +3640,9 @@ if DISCORD_AVAILABLE:
|
||||
)
|
||||
return
|
||||
|
||||
provider_slug = interaction.data["values"][0]
|
||||
if interaction.data is None:
|
||||
return
|
||||
provider_slug = interaction.data["values"][0] # ty: ignore[invalid-key]
|
||||
self._selected_provider = provider_slug
|
||||
provider = next(
|
||||
(p for p in self.providers if p["slug"] == provider_slug), None
|
||||
@@ -3667,8 +3676,10 @@ if DISCORD_AVAILABLE:
|
||||
)
|
||||
return
|
||||
|
||||
if interaction.data is None:
|
||||
return
|
||||
self.resolved = True
|
||||
model_id = interaction.data["values"][0]
|
||||
model_id = interaction.data["values"][0] # ty: ignore[invalid-key]
|
||||
|
||||
try:
|
||||
result_text = await self.on_model_selected(
|
||||
|
||||
@@ -532,6 +532,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image URL as part of an email body."""
|
||||
text = caption or ""
|
||||
|
||||
@@ -2170,8 +2170,8 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
ul_match = re.match(r"^[\s]*[-*+]\s+(.+)$", line)
|
||||
if ul_match:
|
||||
items = []
|
||||
while i < len(lines) and re.match(r"^[\s]*[-*+]\s+(.+)$", lines[i]):
|
||||
items.append(re.match(r"^[\s]*[-*+]\s+(.+)$", lines[i]).group(1))
|
||||
while i < len(lines) and (m := re.match(r"^[\s]*[-*+]\s+(.+)$", lines[i])):
|
||||
items.append(m.group(1))
|
||||
i += 1
|
||||
li = "".join(f"<li>{item}</li>" for item in items)
|
||||
out_lines.append(f"<ul>{li}</ul>")
|
||||
@@ -2181,8 +2181,8 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
ol_match = re.match(r"^[\s]*\d+[.)]\s+(.+)$", line)
|
||||
if ol_match:
|
||||
items = []
|
||||
while i < len(lines) and re.match(r"^[\s]*\d+[.)]\s+(.+)$", lines[i]):
|
||||
items.append(re.match(r"^[\s]*\d+[.)]\s+(.+)$", lines[i]).group(1))
|
||||
while i < len(lines) and (m := re.match(r"^[\s]*\d+[.)]\s+(.+)$", lines[i])):
|
||||
items.append(m.group(1))
|
||||
i += 1
|
||||
li = "".join(f"<li>{item}</li>" for item in items)
|
||||
out_lines.append(f"<ol>{li}</ol>")
|
||||
|
||||
@@ -1842,6 +1842,7 @@ class QQAdapter(BasePlatformAdapter):
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
else:
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
# Maximum time (seconds) to wait for reconnection before giving up on send.
|
||||
_RECONNECT_WAIT_SECONDS = 15.0
|
||||
|
||||
@@ -1690,6 +1690,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str, team_id: str = "") -> bytes:
|
||||
"""Download a Slack file and return raw bytes, with retry."""
|
||||
@@ -1715,6 +1716,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
# ── Channel mention gating ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -25,7 +25,10 @@ import hmac
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
|
||||
@@ -2820,6 +2820,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
)
|
||||
|
||||
sticker = msg.sticker
|
||||
if sticker is None:
|
||||
return
|
||||
emoji = sticker.emoji or ""
|
||||
set_name = sticker.set_name or ""
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ def _resolve_system_dns() -> set[str]:
|
||||
"""Return the IPv4 addresses that the OS resolver gives for api.telegram.org."""
|
||||
try:
|
||||
results = socket.getaddrinfo(_TELEGRAM_API_HOST, 443, socket.AF_INET)
|
||||
return {addr[4][0] for addr in results}
|
||||
return {str(addr[4][0]) for addr in results}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
@@ -703,7 +703,8 @@ class WeComAdapter(BasePlatformAdapter):
|
||||
elif isinstance(appmsg.get("image"), dict):
|
||||
refs.append(("image", appmsg["image"]))
|
||||
|
||||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||||
raw_quote = body.get("quote")
|
||||
quote = raw_quote if isinstance(raw_quote, dict) else {}
|
||||
quote_type = str(quote.get("msgtype") or "").lower()
|
||||
if quote_type == "image" and isinstance(quote.get("image"), dict):
|
||||
refs.append(("image", quote["image"]))
|
||||
|
||||
@@ -25,7 +25,10 @@ import subprocess
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
from hermes_constants import get_hermes_dir
|
||||
|
||||
|
||||
+24
-23
@@ -2859,10 +2859,12 @@ class GatewayRunner:
|
||||
return MatrixAdapter(config)
|
||||
|
||||
elif platform == Platform.API_SERVER:
|
||||
from gateway.platforms.api_server import APIServerAdapter, check_api_server_requirements
|
||||
if not check_api_server_requirements():
|
||||
try:
|
||||
import aiohttp # noqa: F401
|
||||
except ImportError:
|
||||
logger.warning("API Server: aiohttp not installed")
|
||||
return None
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
return APIServerAdapter(config)
|
||||
|
||||
elif platform == Platform.WEBHOOK:
|
||||
@@ -4429,9 +4431,10 @@ class GatewayRunner:
|
||||
# is speaking, without needing a separate tool call.
|
||||
# -----------------------------------------------------------------
|
||||
if source.platform == Platform.DISCORD:
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
adapter = self.adapters.get(Platform.DISCORD)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if guild_id and adapter and hasattr(adapter, "get_voice_channel_context"):
|
||||
if guild_id and isinstance(adapter, DiscordAdapter):
|
||||
vc_context = adapter.get_voice_channel_context(guild_id)
|
||||
if vc_context:
|
||||
context_prompt += f"\n\n{vc_context}"
|
||||
@@ -5874,7 +5877,7 @@ class GatewayRunner:
|
||||
available = "`none`, " + ", ".join(f"`{n}`" for n in personalities)
|
||||
return f"Unknown personality: `{args}`\n\nAvailable: {available}"
|
||||
|
||||
async def _handle_retry_command(self, event: MessageEvent) -> str:
|
||||
async def _handle_retry_command(self, event: MessageEvent) -> Optional[str]:
|
||||
"""Handle /retry command - re-send the last user message."""
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
@@ -6024,9 +6027,10 @@ class GatewayRunner:
|
||||
"all": "TTS (voice reply to all messages)",
|
||||
}
|
||||
# Append voice channel info if connected
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if guild_id and hasattr(adapter, "get_voice_channel_info"):
|
||||
if guild_id and isinstance(adapter, DiscordAdapter):
|
||||
info = adapter.get_voice_channel_info(guild_id)
|
||||
if info:
|
||||
lines = [
|
||||
@@ -6057,8 +6061,9 @@ class GatewayRunner:
|
||||
|
||||
async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
|
||||
"""Join the user's current Discord voice channel."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
if not hasattr(adapter, "join_voice_channel"):
|
||||
if not isinstance(adapter, DiscordAdapter):
|
||||
return "Voice channels are not supported on this platform."
|
||||
|
||||
guild_id = self._get_guild_id(event)
|
||||
@@ -6073,10 +6078,8 @@ class GatewayRunner:
|
||||
|
||||
# Wire callbacks BEFORE join so voice input arriving immediately
|
||||
# after connection is not lost.
|
||||
if hasattr(adapter, "_voice_input_callback"):
|
||||
adapter._voice_input_callback = self._handle_voice_channel_input
|
||||
if hasattr(adapter, "_on_voice_disconnect"):
|
||||
adapter._on_voice_disconnect = self._handle_voice_timeout_cleanup
|
||||
adapter._voice_input_callback = self._handle_voice_channel_input
|
||||
adapter._on_voice_disconnect = self._handle_voice_timeout_cleanup
|
||||
|
||||
try:
|
||||
success = await adapter.join_voice_channel(voice_channel)
|
||||
@@ -6093,8 +6096,7 @@ class GatewayRunner:
|
||||
|
||||
if success:
|
||||
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
||||
if hasattr(adapter, "_voice_sources"):
|
||||
adapter._voice_sources[guild_id] = event.source.to_dict()
|
||||
adapter._voice_sources[guild_id] = event.source.to_dict()
|
||||
self._voice_mode[self._voice_key(event.source.platform, event.source.chat_id)] = "all"
|
||||
self._save_voice_modes()
|
||||
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
|
||||
@@ -6108,13 +6110,14 @@ class GatewayRunner:
|
||||
|
||||
async def _handle_voice_channel_leave(self, event: MessageEvent) -> str:
|
||||
"""Leave the Discord voice channel."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
guild_id = self._get_guild_id(event)
|
||||
|
||||
if not guild_id or not hasattr(adapter, "leave_voice_channel"):
|
||||
if not guild_id or not isinstance(adapter, DiscordAdapter):
|
||||
return "Not in a voice channel."
|
||||
|
||||
if not hasattr(adapter, "is_in_voice_channel") or not adapter.is_in_voice_channel(guild_id):
|
||||
if not adapter.is_in_voice_channel(guild_id):
|
||||
return "Not in a voice channel."
|
||||
|
||||
try:
|
||||
@@ -6125,8 +6128,7 @@ class GatewayRunner:
|
||||
self._voice_mode[self._voice_key(event.source.platform, event.source.chat_id)] = "off"
|
||||
self._save_voice_modes()
|
||||
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True)
|
||||
if hasattr(adapter, "_voice_input_callback"):
|
||||
adapter._voice_input_callback = None
|
||||
adapter._voice_input_callback = None
|
||||
return "Left voice channel."
|
||||
|
||||
def _handle_voice_timeout_cleanup(self, chat_id: str) -> None:
|
||||
@@ -6286,13 +6288,13 @@ class GatewayRunner:
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
|
||||
# If connected to a voice channel, play there instead of sending a file
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
guild_id = self._get_guild_id(event)
|
||||
if (guild_id
|
||||
and hasattr(adapter, "play_in_voice_channel")
|
||||
and hasattr(adapter, "is_in_voice_channel")
|
||||
and isinstance(adapter, DiscordAdapter)
|
||||
and adapter.is_in_voice_channel(guild_id)):
|
||||
await adapter.play_in_voice_channel(guild_id, actual_path)
|
||||
elif adapter and hasattr(adapter, "send_voice"):
|
||||
elif adapter:
|
||||
send_kwargs: Dict[str, Any] = {
|
||||
"chat_id": event.source.chat_id,
|
||||
"audio_path": actual_path,
|
||||
@@ -10488,6 +10490,7 @@ class GatewayRunner:
|
||||
if _timed_out_agent and hasattr(_timed_out_agent, "interrupt"):
|
||||
_timed_out_agent.interrupt(_INTERRUPT_REASON_TIMEOUT)
|
||||
|
||||
assert _agent_timeout is not None # narrowed by _idle_secs >= _agent_timeout above
|
||||
_timeout_mins = int(_agent_timeout // 60) or 1
|
||||
|
||||
# Construct a user-facing message with diagnostic context.
|
||||
@@ -10606,7 +10609,7 @@ class GatewayRunner:
|
||||
pending = None
|
||||
|
||||
if pending_event or pending:
|
||||
logger.debug("Processing pending message: '%s...'", pending[:40])
|
||||
logger.debug("Processing pending message: '%s...'", (pending or "")[:40])
|
||||
|
||||
# Clear the adapter's interrupt event so the next _run_agent call
|
||||
# doesn't immediately re-trigger the interrupt before the new agent
|
||||
@@ -10625,8 +10628,6 @@ class GatewayRunner:
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and pending_event:
|
||||
merge_pending_message_event(adapter._pending_messages, session_key, pending_event)
|
||||
elif adapter and hasattr(adapter, 'queue_message'):
|
||||
adapter.queue_message(session_key, pending)
|
||||
return result_holder[0] or {"final_response": response, "messages": history}
|
||||
|
||||
was_interrupted = result.get("interrupted")
|
||||
@@ -10708,7 +10709,7 @@ class GatewayRunner:
|
||||
history=updated_history,
|
||||
)
|
||||
if next_message is None:
|
||||
return result
|
||||
return result # ty: ignore[invalid-return-type]
|
||||
next_message_id = getattr(pending_event, "message_id", None)
|
||||
next_channel_prompt = getattr(pending_event, "channel_prompt", None)
|
||||
|
||||
|
||||
+10
-6
@@ -768,16 +768,20 @@ def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Di
|
||||
auth_store["active_provider"] = provider_id
|
||||
|
||||
|
||||
def read_credential_pool(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Return the persisted credential pool, or one provider slice."""
|
||||
def read_credential_pool() -> Dict[str, Any]:
|
||||
"""Return the entire persisted credential pool."""
|
||||
auth_store = _load_auth_store()
|
||||
pool = auth_store.get("credential_pool")
|
||||
if not isinstance(pool, dict):
|
||||
pool = {}
|
||||
if provider_id is None:
|
||||
return dict(pool)
|
||||
provider_entries = pool.get(provider_id)
|
||||
return list(provider_entries) if isinstance(provider_entries, list) else []
|
||||
return dict(pool)
|
||||
|
||||
|
||||
def read_provider_credentials(provider_id: str) -> List[Dict[str, Any]]:
|
||||
"""Return credential entries for a single provider."""
|
||||
pool = read_credential_pool()
|
||||
entries = pool.get(provider_id)
|
||||
return list(entries) if isinstance(entries, list) else []
|
||||
|
||||
|
||||
def write_credential_pool(provider_id: str, entries: List[Dict[str, Any]]) -> Path:
|
||||
|
||||
@@ -276,7 +276,7 @@ def _get_ps_exe() -> str | None:
|
||||
global _ps_exe
|
||||
if _ps_exe is False:
|
||||
_ps_exe = _find_powershell()
|
||||
return _ps_exe
|
||||
return _ps_exe if isinstance(_ps_exe, str) else None
|
||||
|
||||
|
||||
def _windows_has_image() -> bool:
|
||||
@@ -387,6 +387,8 @@ def _wayland_save(dest: Path) -> bool:
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.debug("wl-paste not installed — Wayland clipboard unavailable")
|
||||
except ImportError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug("wl-paste clipboard extraction failed: %s", e)
|
||||
dest.unlink(missing_ok=True)
|
||||
@@ -395,14 +397,17 @@ def _wayland_save(dest: Path) -> bool:
|
||||
|
||||
def _convert_to_png(path: Path) -> bool:
|
||||
"""Convert an image file to PNG in-place (requires Pillow or ImageMagick)."""
|
||||
# Try Pillow first (likely installed in the venv)
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pillow is required for clipboard image conversion. "
|
||||
"Install with: pip install hermes-agent[cli]"
|
||||
) from None
|
||||
try:
|
||||
img = Image.open(path)
|
||||
img.save(path, "PNG")
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug("Pillow BMP→PNG conversion failed: %s", e)
|
||||
|
||||
|
||||
@@ -1904,7 +1904,7 @@ def get_missing_config_fields() -> List[Dict[str, Any]]:
|
||||
config = load_config()
|
||||
missing = []
|
||||
|
||||
def _check(defaults: dict, current: dict, prefix: str = ""):
|
||||
def _check(defaults: Dict[str, Any], current: Dict[str, Any], prefix: str = ""):
|
||||
for key, default_value in defaults.items():
|
||||
if key.startswith('_'):
|
||||
continue
|
||||
@@ -2146,8 +2146,8 @@ def check_config_version() -> Tuple[int, int]:
|
||||
Returns (current_version, latest_version).
|
||||
"""
|
||||
config = load_config()
|
||||
current = config.get("_config_version", 0)
|
||||
latest = DEFAULT_CONFIG.get("_config_version", 1)
|
||||
current = int(config.get("_config_version", 0))
|
||||
latest = int(DEFAULT_CONFIG.get("_config_version", 1))
|
||||
return current, latest
|
||||
|
||||
|
||||
@@ -2867,7 +2867,7 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
||||
return results
|
||||
|
||||
|
||||
def _deep_merge(base: dict, override: dict) -> dict:
|
||||
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Recursively merge *override* into *base*, preserving nested defaults.
|
||||
|
||||
Keys in *override* take precedence. If both values are dicts the merge
|
||||
|
||||
@@ -18,7 +18,7 @@ import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
@@ -108,7 +108,7 @@ def wait_for_registration_success(
|
||||
device_code: str,
|
||||
interval: int = 3,
|
||||
expires_in: int = 7200,
|
||||
on_waiting: Optional[callable] = None,
|
||||
on_waiting: Optional[Callable[..., Any]] = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""Block until the registration succeeds or times out.
|
||||
|
||||
|
||||
@@ -3047,6 +3047,12 @@ def _setup_wecom():
|
||||
print_success("💬 WeCom configured!")
|
||||
|
||||
|
||||
def _setup_wecom_callback():
|
||||
"""Configure WeCom Callback (self-built app) via the standard platform setup."""
|
||||
wecom_platform = next(p for p in _PLATFORMS if p["key"] == "wecom_callback")
|
||||
_setup_standard_platform(wecom_platform)
|
||||
|
||||
|
||||
def _is_service_installed() -> bool:
|
||||
"""Check if the gateway is installed as a system service."""
|
||||
if supports_systemd_services():
|
||||
|
||||
@@ -13,7 +13,7 @@ import json as _json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypedDict
|
||||
|
||||
|
||||
from hermes_cli.config import (
|
||||
@@ -748,7 +748,7 @@ def _estimate_tool_tokens() -> Dict[str, int]:
|
||||
OpenAI-format tool schema. Triggers tool discovery on first call,
|
||||
then caches the result for the rest of the process.
|
||||
|
||||
Returns an empty dict when tiktoken or the registry is unavailable.
|
||||
Returns an empty dict when the registry is unavailable.
|
||||
"""
|
||||
global _tool_token_cache
|
||||
if _tool_token_cache is not None:
|
||||
@@ -756,11 +756,12 @@ def _estimate_tool_tokens() -> Dict[str, int]:
|
||||
|
||||
try:
|
||||
import tiktoken
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
except Exception:
|
||||
logger.debug("tiktoken unavailable; skipping tool token estimation")
|
||||
_tool_token_cache = {}
|
||||
return _tool_token_cache
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"tiktoken is required for tool token estimation. "
|
||||
"Install with: pip install hermes-agent[cli]"
|
||||
) from None
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
try:
|
||||
# Trigger full tool discovery (imports all tool modules).
|
||||
@@ -1098,13 +1099,19 @@ def _detect_active_provider_index(providers: list, config: dict) -> int:
|
||||
# right catalog at picker time.
|
||||
|
||||
|
||||
def _fal_model_catalog():
|
||||
class _ImagegenBackend(TypedDict):
|
||||
display: str
|
||||
config_key: str
|
||||
catalog_fn: Callable[[], Tuple[Dict[str, Dict[str, Any]], str]]
|
||||
|
||||
|
||||
def _fal_model_catalog() -> Tuple[Dict[str, Dict[str, Any]], str]:
|
||||
"""Lazy-load the FAL model catalog from the tool module."""
|
||||
from tools.image_generation_tool import FAL_MODELS, DEFAULT_MODEL
|
||||
return FAL_MODELS, DEFAULT_MODEL
|
||||
|
||||
|
||||
IMAGEGEN_BACKENDS = {
|
||||
IMAGEGEN_BACKENDS: Dict[str, _ImagegenBackend] = {
|
||||
"fal": {
|
||||
"display": "FAL.ai",
|
||||
"config_key": "image_gen",
|
||||
|
||||
+1
-1
@@ -142,7 +142,7 @@ class _ComponentFilter(logging.Filter):
|
||||
# Used by _ComponentFilter and exposed for ``hermes logs --component``.
|
||||
COMPONENT_PREFIXES = {
|
||||
"gateway": ("gateway",),
|
||||
"agent": ("agent", "run_agent", "model_tools", "batch_runner"),
|
||||
"agent": ("agent", "run_agent", "model_tools", "scripts.batch_runner"),
|
||||
"tools": ("tools",),
|
||||
"cli": ("hermes_cli", "cli"),
|
||||
"cron": ("cron",),
|
||||
|
||||
+7
-5
@@ -40,11 +40,11 @@ dependencies = [
|
||||
modal = ["modal>=1.0.0,<2"]
|
||||
daytona = ["daytona>=0.148.0,<1"]
|
||||
dev = ["debugpy>=1.8.0,<2", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "pytest-xdist>=3.0,<4", "mcp>=1.2.0,<2", "ty>=0.0.1a29,<0.0.22", "ruff"]
|
||||
messaging = ["python-telegram-bot[webhooks]>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4", "qrcode>=7.0,<8"]
|
||||
messaging = ["python-telegram-bot[webhooks]>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4", "qrcode>=7.0,<8", "mutagen>=1.45,<2", "aiohttp-socks>=0.9,<1"]
|
||||
cron = ["croniter>=6.0.0,<7"]
|
||||
slack = ["slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"]
|
||||
matrix = ["mautrix[encryption]>=0.20,<1", "Markdown>=3.6,<4", "aiosqlite>=0.20", "asyncpg>=0.29"]
|
||||
cli = ["simple-term-menu>=1.0,<2"]
|
||||
cli = ["simple-term-menu>=1.0,<2", "tiktoken>=0.7,<1", "Pillow>=10,<12"]
|
||||
tts-premium = ["elevenlabs>=1.0,<2"]
|
||||
voice = [
|
||||
# Local STT pulls in wheel-only transitive deps (ctranslate2, onnxruntime),
|
||||
@@ -58,7 +58,7 @@ pty = [
|
||||
"pywinpty>=2.0.0,<3; sys_platform == 'win32'",
|
||||
]
|
||||
honcho = ["honcho-ai>=2.0.1,<3"]
|
||||
mcp = ["mcp>=1.2.0,<2"]
|
||||
mcp = ["mcp>=1.2.0,<2", "psutil>=5.9,<7"]
|
||||
homeassistant = ["aiohttp>=3.9.0,<4"]
|
||||
sms = ["aiohttp>=3.9.0,<4"]
|
||||
acp = ["agent-client-protocol>=0.9.0,<1.0"]
|
||||
@@ -85,7 +85,9 @@ rl = [
|
||||
"fastapi>=0.104.0,<1",
|
||||
"uvicorn[standard]>=0.24.0,<1",
|
||||
"wandb>=0.15.0,<1",
|
||||
"datasets>=2.14,<3",
|
||||
]
|
||||
tts-local = ["neutts[all]", "soundfile>=0.12,<1"]
|
||||
yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git@bfb0c88062450f46341bd9a5298903fc2e952a5c ; python_version >= '3.12'"]
|
||||
all = [
|
||||
"hermes-agent[modal]",
|
||||
@@ -120,13 +122,13 @@ hermes-agent = "run_agent:main"
|
||||
hermes-acp = "acp_adapter.entry:main"
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli", "hermes_constants", "hermes_state", "hermes_time", "hermes_logging", "rl_cli", "utils"]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets", "toolset_distributions", "cli", "hermes_constants", "hermes_state", "hermes_time", "hermes_logging", "utils"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
hermes_cli = ["web_dist/**/*"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["agent", "agent.*", "tools", "tools.*", "hermes_cli", "gateway", "gateway.*", "tui_gateway", "tui_gateway.*", "cron", "acp_adapter", "plugins", "plugins.*"]
|
||||
include = ["agent", "agent.*", "tools", "tools.*", "hermes_cli", "gateway", "gateway.*", "tui_gateway", "tui_gateway.*", "cron", "acp_adapter", "plugins", "plugins.*", "scripts"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
+27
-21
@@ -37,7 +37,10 @@ import time
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Callable, List, Dict, Any, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent.rate_limit_tracker import RateLimitState
|
||||
from openai import OpenAI
|
||||
import fire
|
||||
from datetime import datetime
|
||||
@@ -722,17 +725,17 @@ class AIAgent:
|
||||
provider_require_parameters: bool = False,
|
||||
provider_data_collection: str = None,
|
||||
session_id: str = None,
|
||||
tool_progress_callback: callable = None,
|
||||
tool_start_callback: callable = None,
|
||||
tool_complete_callback: callable = None,
|
||||
thinking_callback: callable = None,
|
||||
reasoning_callback: callable = None,
|
||||
clarify_callback: callable = None,
|
||||
step_callback: callable = None,
|
||||
stream_delta_callback: callable = None,
|
||||
interim_assistant_callback: callable = None,
|
||||
tool_gen_callback: callable = None,
|
||||
status_callback: callable = None,
|
||||
tool_progress_callback: Callable[..., Any] = None,
|
||||
tool_start_callback: Callable[..., Any] = None,
|
||||
tool_complete_callback: Callable[..., Any] = None,
|
||||
thinking_callback: Callable[..., Any] = None,
|
||||
reasoning_callback: Callable[..., Any] = None,
|
||||
clarify_callback: Callable[..., Any] = None,
|
||||
step_callback: Callable[..., Any] = None,
|
||||
stream_delta_callback: Callable[..., Any] = None,
|
||||
interim_assistant_callback: Callable[..., Any] = None,
|
||||
tool_gen_callback: Callable[..., Any] = None,
|
||||
status_callback: Callable[..., Any] = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
service_tier: str = None,
|
||||
@@ -1048,7 +1051,7 @@ class AIAgent:
|
||||
for quiet_logger in [
|
||||
'tools', # all tools.* (terminal, browser, web, file, etc.)
|
||||
'run_agent', # agent runner internals
|
||||
'trajectory_compressor',
|
||||
'scripts.trajectory_compressor',
|
||||
'cron', # scheduler (only relevant in daemon mode)
|
||||
'hermes_cli', # CLI helpers
|
||||
]:
|
||||
@@ -4767,7 +4770,7 @@ class AIAgent:
|
||||
def _close_request_openai_client(self, client: Any, *, reason: str) -> None:
|
||||
self._close_openai_client(client, reason=reason, shared=False)
|
||||
|
||||
def _run_codex_stream(self, api_kwargs: dict, client: Any = None, on_first_delta: callable = None):
|
||||
def _run_codex_stream(self, api_kwargs: dict, client: Any = None, on_first_delta: Callable[..., Any] = None):
|
||||
"""Execute one streaming Responses API request and return the final response."""
|
||||
import httpx as _httpx
|
||||
|
||||
@@ -5466,7 +5469,7 @@ class AIAgent:
|
||||
)
|
||||
|
||||
def _interruptible_streaming_api_call(
|
||||
self, api_kwargs: dict, *, on_first_delta: callable = None
|
||||
self, api_kwargs: dict, *, on_first_delta: Callable[..., Any] = None
|
||||
):
|
||||
"""Streaming variant of _interruptible_api_call for real-time token delivery.
|
||||
|
||||
@@ -7405,12 +7408,15 @@ class AIAgent:
|
||||
_flush_temperature = _fixed_temp
|
||||
else:
|
||||
_flush_temperature = 0.3
|
||||
_flush_llm_kwargs: dict = {}
|
||||
if _flush_temperature is not None:
|
||||
_flush_llm_kwargs["temperature"] = _flush_temperature
|
||||
try:
|
||||
response = _call_llm(
|
||||
task="flush_memories",
|
||||
messages=api_messages,
|
||||
tools=[memory_tool_def],
|
||||
temperature=_flush_temperature,
|
||||
**_flush_llm_kwargs,
|
||||
max_tokens=5120,
|
||||
# timeout resolved from auxiliary.flush_memories.timeout config
|
||||
)
|
||||
@@ -8619,9 +8625,9 @@ class AIAgent:
|
||||
self,
|
||||
user_message: str,
|
||||
system_message: str = None,
|
||||
conversation_history: List[Dict[str, Any]] = None,
|
||||
conversation_history: List[Dict[str, Any]] | None = None,
|
||||
task_id: str = None,
|
||||
stream_callback: Optional[callable] = None,
|
||||
stream_callback: Optional[Callable[..., Any]] = None,
|
||||
persist_user_message: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -10225,7 +10231,7 @@ class AIAgent:
|
||||
auth_method = "Bearer (OAuth/setup-token)" if _is_oauth_token(key) else "x-api-key (API key)"
|
||||
print(f"{self.log_prefix}🔐 Anthropic 401 — authentication failed.")
|
||||
print(f"{self.log_prefix} Auth method: {auth_method}")
|
||||
print(f"{self.log_prefix} Token prefix: {key[:12]}..." if key and len(key) > 12 else f"{self.log_prefix} Token: (empty or short)")
|
||||
print(f"{self.log_prefix} Token prefix: {str(key)[:12]}..." if key and len(str(key)) > 12 else f"{self.log_prefix} Token: (empty or short)")
|
||||
print(f"{self.log_prefix} Troubleshooting:")
|
||||
from hermes_constants import display_hermes_home as _dhh_fn
|
||||
_dhh = _dhh_fn()
|
||||
@@ -11569,7 +11575,7 @@ class AIAgent:
|
||||
messages.append(assistant_msg)
|
||||
|
||||
if reasoning_text:
|
||||
reasoning_preview = reasoning_text[:500] + "..." if len(reasoning_text) > 500 else reasoning_text
|
||||
reasoning_preview = str(reasoning_text)[:500] + "..." if len(str(reasoning_text)) > 500 else reasoning_text
|
||||
logger.warning(
|
||||
"Reasoning-only response (no visible content) "
|
||||
"after exhausting retries and fallback. "
|
||||
@@ -11908,7 +11914,7 @@ class AIAgent:
|
||||
|
||||
return result
|
||||
|
||||
def chat(self, message: str, stream_callback: Optional[callable] = None) -> str:
|
||||
def chat(self, message: str, stream_callback: Optional[Callable[..., Any]] = None) -> str:
|
||||
"""
|
||||
Simple chat interface that returns just the final response.
|
||||
|
||||
|
||||
@@ -20,9 +20,13 @@ Usage:
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
@@ -1126,7 +1130,7 @@ def main(
|
||||
num_workers: int = 4,
|
||||
resume: bool = False,
|
||||
verbose: bool = False,
|
||||
list_distributions: bool = False,
|
||||
show_distributions: bool = False,
|
||||
ephemeral_system_prompt: str = None,
|
||||
log_prefix_chars: int = 100,
|
||||
providers_allowed: str = None,
|
||||
@@ -1154,7 +1158,7 @@ def main(
|
||||
num_workers (int): Number of parallel worker processes (default: 4)
|
||||
resume (bool): Resume from checkpoint if run was interrupted (default: False)
|
||||
verbose (bool): Enable verbose logging (default: False)
|
||||
list_distributions (bool): List available toolset distributions and exit
|
||||
show_distributions (bool): List available toolset distributions and exit
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
|
||||
providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
|
||||
@@ -1186,10 +1190,10 @@ def main(
|
||||
--prefill_messages_file=configs/prefill_opus.json
|
||||
|
||||
# List available distributions
|
||||
python batch_runner.py --list_distributions
|
||||
python batch_runner.py --show_distributions
|
||||
"""
|
||||
# Handle list distributions
|
||||
if list_distributions:
|
||||
if show_distributions:
|
||||
from toolset_distributions import print_distribution_info
|
||||
|
||||
print("📊 Available Toolset Distributions")
|
||||
@@ -26,10 +26,13 @@ Usage:
|
||||
python mini_swe_runner.py --prompts_file prompts.jsonl --output_file trajectories.jsonl --env docker
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
+2
-1
@@ -26,6 +26,7 @@ import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -685,7 +686,7 @@ def get_commits(since_tag=None):
|
||||
return commits
|
||||
|
||||
|
||||
def get_pr_number(subject: str) -> str:
|
||||
def get_pr_number(subject: str) -> Optional[str]:
|
||||
"""Extract PR number from commit subject if present."""
|
||||
match = re.search(r"#(\d+)", subject)
|
||||
if match:
|
||||
|
||||
@@ -19,18 +19,23 @@ Environment Variables:
|
||||
OPENROUTER_API_KEY: API key for OpenRouter (required for agent)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import yaml
|
||||
|
||||
from hermes_constants import get_hermes_home, OPENROUTER_BASE_URL
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
_hermes_home = get_hermes_home()
|
||||
_project_env = Path(__file__).parent / '.env'
|
||||
_project_env = Path(__file__).parent.parent / '.env'
|
||||
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
@@ -60,8 +65,6 @@ from tools.rl_training_tool import get_missing_keys
|
||||
# Config Loading
|
||||
# ============================================================================
|
||||
|
||||
from hermes_constants import get_hermes_home, OPENROUTER_BASE_URL
|
||||
|
||||
DEFAULT_MODEL = "anthropic/claude-opus-4.5"
|
||||
DEFAULT_BASE_URL = OPENROUTER_BASE_URL
|
||||
|
||||
@@ -267,7 +267,7 @@ def run_compression(input_dir: Path, output_dir: Path, config_path: str):
|
||||
# Import the compressor
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from trajectory_compressor import TrajectoryCompressor, CompressionConfig
|
||||
from scripts.trajectory_compressor import TrajectoryCompressor, CompressionConfig
|
||||
|
||||
print(f"\n🗜️ Running trajectory compression...")
|
||||
print(f" Input: {input_dir}")
|
||||
|
||||
@@ -30,14 +30,18 @@ Usage:
|
||||
python trajectory_compressor.py --input=data/my_run --sample_percent=10
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import json
|
||||
import time
|
||||
import yaml
|
||||
import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
@@ -52,7 +56,7 @@ from agent.retry_utils import jittered_backoff
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
_hermes_home = get_hermes_home()
|
||||
_project_env = Path(__file__).parent / ".env"
|
||||
_project_env = Path(__file__).parent.parent / ".env"
|
||||
load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
|
||||
|
||||
@@ -75,7 +79,7 @@ def _effective_temperature_for_model(
|
||||
if fixed_temperature is OMIT_TEMPERATURE:
|
||||
return None # caller must omit temperature
|
||||
if fixed_temperature is not None:
|
||||
return fixed_temperature
|
||||
return cast(float, fixed_temperature)
|
||||
return requested_temperature
|
||||
|
||||
|
||||
@@ -607,11 +611,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
|
||||
if getattr(self, '_use_call_llm', False):
|
||||
from agent.auxiliary_client import call_llm
|
||||
_call_llm_kwargs: dict = {}
|
||||
if summary_temperature is not None:
|
||||
_call_llm_kwargs["temperature"] = summary_temperature
|
||||
response = call_llm(
|
||||
provider=self._llm_provider,
|
||||
model=self.config.summarization_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=summary_temperature,
|
||||
**_call_llm_kwargs,
|
||||
max_tokens=self.config.summary_target_tokens * 2,
|
||||
)
|
||||
else:
|
||||
@@ -623,20 +630,21 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
if summary_temperature is not None:
|
||||
_create_kwargs["temperature"] = summary_temperature
|
||||
response = self.client.chat.completions.create(**_create_kwargs)
|
||||
|
||||
|
||||
summary = self._coerce_summary_content(response.choices[0].message.content)
|
||||
return self._ensure_summary_prefix(summary)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
metrics.summarization_errors += 1
|
||||
self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
|
||||
|
||||
|
||||
if attempt < self.config.max_retries - 1:
|
||||
time.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))
|
||||
else:
|
||||
# Fallback: create a basic summary
|
||||
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
|
||||
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
async def _generate_summary_async(self, content: str, metrics: TrajectoryMetrics) -> str:
|
||||
"""
|
||||
Generate a summary of the compressed turns using OpenRouter (async version).
|
||||
@@ -676,11 +684,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
|
||||
if getattr(self, '_use_call_llm', False):
|
||||
from agent.auxiliary_client import async_call_llm
|
||||
_async_llm_kwargs: dict = {}
|
||||
if summary_temperature is not None:
|
||||
_async_llm_kwargs["temperature"] = summary_temperature
|
||||
response = await async_call_llm(
|
||||
provider=self._llm_provider,
|
||||
model=self.config.summarization_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=summary_temperature,
|
||||
**_async_llm_kwargs,
|
||||
max_tokens=self.config.summary_target_tokens * 2,
|
||||
)
|
||||
else:
|
||||
@@ -692,20 +703,21 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
if summary_temperature is not None:
|
||||
_create_kwargs["temperature"] = summary_temperature
|
||||
response = await self._get_async_client().chat.completions.create(**_create_kwargs)
|
||||
|
||||
|
||||
summary = self._coerce_summary_content(response.choices[0].message.content)
|
||||
return self._ensure_summary_prefix(summary)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
metrics.summarization_errors += 1
|
||||
self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
|
||||
|
||||
|
||||
if attempt < self.config.max_retries - 1:
|
||||
await asyncio.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))
|
||||
else:
|
||||
# Fallback: create a basic summary
|
||||
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
|
||||
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
def compress_trajectory(
|
||||
self,
|
||||
trajectory: List[Dict[str, str]]
|
||||
@@ -1162,3 +1162,75 @@ def test_load_pool_does_not_seed_qwen_oauth_when_no_token(tmp_path, monkeypatch)
|
||||
|
||||
assert not pool.has_credentials()
|
||||
assert pool.entries() == []
|
||||
|
||||
|
||||
def _build_pool_with_entries(tmp_path, monkeypatch, provider="openrouter", entries=None):
|
||||
"""Helper: build a CredentialPool directly without seeding side-effects."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setattr("agent.credential_pool._seed_from_singletons", lambda p, e: (False, set()))
|
||||
monkeypatch.setattr("agent.credential_pool._seed_from_env", lambda p, e: (False, set()))
|
||||
if entries is None:
|
||||
entries = [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "tok-1",
|
||||
},
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "secondary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "tok-2",
|
||||
},
|
||||
]
|
||||
_write_auth_store(tmp_path, {"version": 1, "credential_pool": {provider: entries}})
|
||||
from agent.credential_pool import load_pool
|
||||
return load_pool(provider)
|
||||
|
||||
|
||||
def test_remove_entry_removes_by_id(tmp_path, monkeypatch):
|
||||
"""remove_entry should remove the entry with matching id and return it."""
|
||||
pool = _build_pool_with_entries(tmp_path, monkeypatch)
|
||||
|
||||
removed = pool.remove_entry("cred-1")
|
||||
|
||||
assert removed is not None
|
||||
assert removed.id == "cred-1"
|
||||
remaining_ids = [e.id for e in pool.entries()]
|
||||
assert "cred-1" not in remaining_ids
|
||||
assert "cred-2" in remaining_ids
|
||||
|
||||
|
||||
def test_remove_entry_returns_none_for_unknown_id(tmp_path, monkeypatch):
|
||||
"""remove_entry returns None when no entry matches the given id."""
|
||||
pool = _build_pool_with_entries(tmp_path, monkeypatch)
|
||||
|
||||
result = pool.remove_entry("nonexistent-id")
|
||||
|
||||
assert result is None
|
||||
# Pool should still have both original entries
|
||||
assert len(pool.entries()) == 2
|
||||
|
||||
|
||||
def test_remove_entry_renumbers_priorities(tmp_path, monkeypatch):
|
||||
"""After remove_entry, remaining entries receive sequential priorities 0, 1, ..."""
|
||||
pool = _build_pool_with_entries(
|
||||
tmp_path,
|
||||
monkeypatch,
|
||||
entries=[
|
||||
{"id": "cred-1", "label": "a", "auth_type": "api_key", "priority": 0, "source": "manual", "access_token": "tok-1"},
|
||||
{"id": "cred-2", "label": "b", "auth_type": "api_key", "priority": 1, "source": "manual", "access_token": "tok-2"},
|
||||
{"id": "cred-3", "label": "c", "auth_type": "api_key", "priority": 2, "source": "manual", "access_token": "tok-3"},
|
||||
],
|
||||
)
|
||||
|
||||
pool.remove_entry("cred-2")
|
||||
|
||||
remaining = sorted(pool.entries(), key=lambda e: e.priority)
|
||||
assert [e.priority for e in remaining] == [0, 1]
|
||||
assert [e.id for e in remaining] == ["cred-1", "cred-3"]
|
||||
|
||||
@@ -164,7 +164,7 @@ class TestArceeURLMapping:
|
||||
assert "arceeai" in _PROVIDER_PREFIXES
|
||||
|
||||
def test_trajectory_compressor_detects_arcee(self):
|
||||
import trajectory_compressor as tc
|
||||
import scripts.trajectory_compressor as tc
|
||||
comp = tc.TrajectoryCompressor.__new__(tc.TrajectoryCompressor)
|
||||
comp.config = types.SimpleNamespace(base_url="https://api.arcee.ai/api/v1")
|
||||
assert comp._detect_provider() == "arcee"
|
||||
|
||||
@@ -104,7 +104,7 @@ def main():
|
||||
test_file = create_test_dataset()
|
||||
|
||||
print(f"\n📝 To run the test manually:")
|
||||
print(f" python batch_runner.py \\")
|
||||
print(f" python scripts/batch_runner.py \\")
|
||||
print(f" --dataset_file={test_file} \\")
|
||||
print(f" --batch_size=2 \\")
|
||||
print(f" --run_name={run_name} \\")
|
||||
@@ -112,7 +112,7 @@ def main():
|
||||
print(f" --num_workers=2")
|
||||
|
||||
print(f"\n💡 Or test with different distributions:")
|
||||
print(f" python batch_runner.py --list_distributions")
|
||||
print(f" python scripts/batch_runner.py --list_distributions")
|
||||
|
||||
print(f"\n🔍 After running, you can verify output with:")
|
||||
print(f" python tests/test_batch_runner.py --verify")
|
||||
|
||||
@@ -30,7 +30,7 @@ from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
import traceback
|
||||
|
||||
# Add project root to path to import batch_runner
|
||||
# Add project root to path to import scripts.batch_runner
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ def test_current_implementation():
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
# Import here to avoid issues if module changes
|
||||
from batch_runner import BatchRunner
|
||||
from scripts.batch_runner import BatchRunner
|
||||
|
||||
checkpoint_file = output_dir / "checkpoint.json"
|
||||
|
||||
@@ -229,7 +229,7 @@ def test_interruption_and_resume():
|
||||
if output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
from batch_runner import BatchRunner
|
||||
from scripts.batch_runner import BatchRunner
|
||||
|
||||
checkpoint_file = output_dir / "checkpoint.json"
|
||||
|
||||
|
||||
@@ -8,11 +8,7 @@ from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# batch_runner uses relative imports, ensure project root is on path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from batch_runner import BatchRunner, _process_batch_worker
|
||||
from scripts.batch_runner import BatchRunner, _process_batch_worker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -173,7 +169,7 @@ class TestBatchWorkerResumeBehavior:
|
||||
"toolsets_used": [],
|
||||
}
|
||||
|
||||
monkeypatch.setattr("batch_runner._process_single_prompt", lambda *args, **kwargs: prompt_result)
|
||||
monkeypatch.setattr("scripts.batch_runner._process_single_prompt", lambda *args, **kwargs: prompt_result)
|
||||
|
||||
result = _process_batch_worker((
|
||||
1,
|
||||
|
||||
@@ -14,7 +14,7 @@ def test_run_task_kimi_omits_temperature():
|
||||
)
|
||||
mock_openai.return_value = client
|
||||
|
||||
from mini_swe_runner import MiniSWERunner
|
||||
from scripts.mini_swe_runner import MiniSWERunner
|
||||
|
||||
runner = MiniSWERunner(
|
||||
model="kimi-for-coding",
|
||||
@@ -42,7 +42,7 @@ def test_run_task_public_moonshot_kimi_k2_5_omits_temperature():
|
||||
)
|
||||
mock_openai.return_value = client
|
||||
|
||||
from mini_swe_runner import MiniSWERunner
|
||||
from scripts.mini_swe_runner import MiniSWERunner
|
||||
|
||||
runner = MiniSWERunner(
|
||||
model="kimi-k2.5",
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from trajectory_compressor import (
|
||||
from scripts.trajectory_compressor import (
|
||||
CompressionConfig,
|
||||
TrajectoryMetrics,
|
||||
AggregateMetrics,
|
||||
@@ -25,8 +25,8 @@ def test_import_loads_env_from_hermes_home(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
|
||||
sys.modules.pop("trajectory_compressor", None)
|
||||
importlib.import_module("trajectory_compressor")
|
||||
sys.modules.pop("scripts.trajectory_compressor", None)
|
||||
importlib.import_module("scripts.trajectory_compressor")
|
||||
|
||||
assert os.getenv("OPENROUTER_API_KEY") == "from-hermes-home"
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestAsyncClientLazyCreation:
|
||||
|
||||
def test_async_client_none_after_init(self):
|
||||
"""async_client should be None after __init__ (not eagerly created)."""
|
||||
from trajectory_compressor import TrajectoryCompressor
|
||||
from scripts.trajectory_compressor import TrajectoryCompressor
|
||||
|
||||
comp = TrajectoryCompressor.__new__(TrajectoryCompressor)
|
||||
comp.config = MagicMock()
|
||||
@@ -36,7 +36,7 @@ class TestAsyncClientLazyCreation:
|
||||
|
||||
def test_get_async_client_creates_new_client(self):
|
||||
"""_get_async_client() should create a fresh AsyncOpenAI instance."""
|
||||
from trajectory_compressor import TrajectoryCompressor
|
||||
from scripts.trajectory_compressor import TrajectoryCompressor
|
||||
|
||||
comp = TrajectoryCompressor.__new__(TrajectoryCompressor)
|
||||
comp.config = MagicMock()
|
||||
@@ -57,7 +57,7 @@ class TestAsyncClientLazyCreation:
|
||||
def test_get_async_client_creates_fresh_each_call(self):
|
||||
"""Each call to _get_async_client() creates a NEW client instance,
|
||||
so it binds to the current event loop."""
|
||||
from trajectory_compressor import TrajectoryCompressor
|
||||
from scripts.trajectory_compressor import TrajectoryCompressor
|
||||
|
||||
comp = TrajectoryCompressor.__new__(TrajectoryCompressor)
|
||||
comp.config = MagicMock()
|
||||
@@ -91,7 +91,7 @@ class TestSourceLineVerification:
|
||||
def _read_file() -> str:
|
||||
import os
|
||||
base = os.path.dirname(os.path.dirname(__file__))
|
||||
with open(os.path.join(base, "trajectory_compressor.py")) as f:
|
||||
with open(os.path.join(base, "scripts", "trajectory_compressor.py")) as f:
|
||||
return f.read()
|
||||
|
||||
def test_no_eager_async_openai_in_init(self):
|
||||
@@ -119,7 +119,7 @@ class TestSourceLineVerification:
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_summary_async_kimi_omits_temperature():
|
||||
"""Kimi models should have temperature omitted — server manages it."""
|
||||
from trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
|
||||
from scripts.trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
|
||||
|
||||
config = CompressionConfig(
|
||||
summarization_model="kimi-for-coding",
|
||||
@@ -147,7 +147,7 @@ async def test_generate_summary_async_kimi_omits_temperature():
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_summary_async_public_moonshot_kimi_k2_5_omits_temperature():
|
||||
"""kimi-k2.5 on the public Moonshot API should not get a forced temperature."""
|
||||
from trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
|
||||
from scripts.trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
|
||||
|
||||
config = CompressionConfig(
|
||||
summarization_model="kimi-k2.5",
|
||||
@@ -176,7 +176,7 @@ async def test_generate_summary_async_public_moonshot_kimi_k2_5_omits_temperatur
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_summary_async_public_moonshot_cn_kimi_k2_5_omits_temperature():
|
||||
"""kimi-k2.5 on api.moonshot.cn should not get a forced temperature."""
|
||||
from trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
|
||||
from scripts.trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
|
||||
|
||||
config = CompressionConfig(
|
||||
summarization_model="kimi-k2.5",
|
||||
|
||||
@@ -87,7 +87,7 @@ class TestTrajectoryCompressorNullGuard:
|
||||
|
||||
def test_null_base_url_does_not_crash(self):
|
||||
"""base_url=None should not crash _detect_provider()."""
|
||||
from trajectory_compressor import CompressionConfig, TrajectoryCompressor
|
||||
from scripts.trajectory_compressor import CompressionConfig, TrajectoryCompressor
|
||||
|
||||
config = CompressionConfig()
|
||||
config.base_url = None
|
||||
@@ -101,7 +101,7 @@ class TestTrajectoryCompressorNullGuard:
|
||||
|
||||
def test_config_loading_null_base_url_keeps_default(self):
|
||||
"""YAML ``summarization: {base_url: null}`` should keep default."""
|
||||
from trajectory_compressor import CompressionConfig
|
||||
from scripts.trajectory_compressor import CompressionConfig
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
config = CompressionConfig()
|
||||
|
||||
@@ -5,6 +5,8 @@ terminates processes, and handles edge cases on failure paths.
|
||||
Inspired by PR #715 (0xbyt4).
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import io
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -118,6 +120,29 @@ class TestStopTrainingRunProcesses:
|
||||
trainer.terminate.assert_not_called()
|
||||
|
||||
|
||||
class TestRunStateLogFileFields:
|
||||
|
||||
def test_log_file_fields_default_none(self):
|
||||
"""All three log_file fields should default to None."""
|
||||
state = _make_run_state()
|
||||
assert state.api_log_file is None
|
||||
assert state.trainer_log_file is None
|
||||
assert state.env_log_file is None
|
||||
|
||||
def test_accepts_file_handle_for_api_log(self):
|
||||
"""api_log_file should accept an open file-like object."""
|
||||
api_log = io.StringIO()
|
||||
state = _make_run_state(api_log_file=api_log)
|
||||
assert state.api_log_file is api_log
|
||||
|
||||
def test_log_file_fields_present_in_dataclass(self):
|
||||
"""All three field names must be declared on the RunState dataclass."""
|
||||
field_names = {f.name for f in dataclasses.fields(RunState)}
|
||||
assert "api_log_file" in field_names
|
||||
assert "trainer_log_file" in field_names
|
||||
assert "env_log_file" in field_names
|
||||
|
||||
|
||||
class TestStopTrainingRunStatus:
|
||||
"""Verify status transitions in _stop_training_run."""
|
||||
|
||||
|
||||
+3
-3
@@ -16,7 +16,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
import unicodedata
|
||||
from typing import Optional
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -228,10 +228,10 @@ class _ApprovalEntry:
|
||||
|
||||
|
||||
_gateway_queues: dict[str, list] = {} # session_key → [_ApprovalEntry, …]
|
||||
_gateway_notify_cbs: dict[str, object] = {} # session_key → callable(approval_data)
|
||||
_gateway_notify_cbs: Dict[str, Callable[[Dict[str, Any]], None]] = {}
|
||||
|
||||
|
||||
def register_gateway_notify(session_key: str, cb) -> None:
|
||||
def register_gateway_notify(session_key: str, cb: Callable[[Dict[str, Any]], None]) -> None:
|
||||
"""Register a per-session callback for sending approval requests to the user.
|
||||
|
||||
The callback signature is ``cb(approval_data: dict) -> None`` where
|
||||
|
||||
@@ -891,7 +891,7 @@ BROWSER_TOOL_SCHEMAS = [
|
||||
# Utility Functions
|
||||
# ============================================================================
|
||||
|
||||
def _create_local_session(task_id: str) -> Dict[str, str]:
|
||||
def _create_local_session(task_id: str) -> Dict[str, Any]:
|
||||
import uuid
|
||||
session_name = f"h_{uuid.uuid4().hex[:10]}"
|
||||
logger.info("Created local browser session %s for task %s",
|
||||
@@ -904,7 +904,7 @@ def _create_local_session(task_id: str) -> Dict[str, str]:
|
||||
}
|
||||
|
||||
|
||||
def _create_cdp_session(task_id: str, cdp_url: str) -> Dict[str, str]:
|
||||
def _create_cdp_session(task_id: str, cdp_url: str) -> Dict[str, Any]:
|
||||
"""Create a session that connects to a user-supplied CDP endpoint."""
|
||||
import uuid
|
||||
session_name = f"cdp_{uuid.uuid4().hex[:10]}"
|
||||
@@ -918,7 +918,7 @@ def _create_cdp_session(task_id: str, cdp_url: str) -> Dict[str, str]:
|
||||
}
|
||||
|
||||
|
||||
def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
|
||||
def _get_session_info(task_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get or create session info for the given task.
|
||||
|
||||
@@ -1687,7 +1687,7 @@ def browser_scroll(direction: str, task_id: Optional[str] = None) -> str:
|
||||
from tools.browser_camofox import camofox_scroll
|
||||
# Camofox REST API doesn't support pixel args; use repeated calls
|
||||
_SCROLL_REPEATS = 5
|
||||
result = None
|
||||
result: str = ""
|
||||
for _ in range(_SCROLL_REPEATS):
|
||||
result = camofox_scroll(direction, task_id)
|
||||
return result
|
||||
|
||||
@@ -68,7 +68,7 @@ def _scan_cron_prompt(prompt: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _origin_from_env() -> Optional[Dict[str, str]]:
|
||||
def _origin_from_env() -> Optional[Dict[str, Optional[str]]]:
|
||||
from gateway.session_context import get_session_env
|
||||
origin_platform = get_session_env("HERMES_SESSION_PLATFORM")
|
||||
origin_chat_id = get_session_env("HERMES_SESSION_CHAT_ID")
|
||||
|
||||
@@ -29,7 +29,7 @@ from concurrent.futures import (
|
||||
TimeoutError as FuturesTimeoutError,
|
||||
as_completed,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from toolsets import TOOLSETS
|
||||
from tools import file_state
|
||||
@@ -584,7 +584,7 @@ def _build_child_progress_callback(
|
||||
depth: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
toolsets: Optional[List[str]] = None,
|
||||
) -> Optional[callable]:
|
||||
) -> Optional[Callable[..., Any]]:
|
||||
"""Build a callback that relays child agent tool calls to the parent display.
|
||||
|
||||
Two display paths:
|
||||
@@ -1602,7 +1602,7 @@ def delegate_task(
|
||||
|
||||
n_tasks = len(task_list)
|
||||
# Track goal labels for progress display (truncated for readability)
|
||||
task_labels = [t["goal"][:40] for t in task_list]
|
||||
task_labels = [str(t["goal"] or "")[:40] for t in task_list]
|
||||
|
||||
# Save parent tool names BEFORE any child construction mutates the global.
|
||||
# _build_child_agent() calls AIAgent() which calls get_tool_definitions(),
|
||||
|
||||
@@ -245,7 +245,7 @@ class _ThreadedProcessHandle:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def wait(self, timeout: float | None = None) -> int:
|
||||
def wait(self, timeout: float | None = None) -> int | None:
|
||||
self._done.wait(timeout=timeout)
|
||||
return self._returncode
|
||||
|
||||
@@ -755,7 +755,7 @@ class BaseEnvironment(ABC):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _prepare_command(self, command: str) -> tuple[str, str | None]:
|
||||
def _prepare_command(self, command: str) -> tuple[str | None, str | None]:
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available."""
|
||||
from tools.terminal_tool import _transform_sudo_command
|
||||
|
||||
|
||||
@@ -26,10 +26,11 @@ import os
|
||||
import datetime
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import fal_client
|
||||
import httpx
|
||||
|
||||
from tools.debug_helpers import DebugSession
|
||||
from tools.managed_tool_gateway import resolve_managed_tool_gateway
|
||||
@@ -348,21 +349,27 @@ class _ManagedFalSyncClient:
|
||||
|
||||
self._queue_url_format = _normalize_fal_queue_url_format(queue_run_origin)
|
||||
self._sync_client = sync_client_class(key=key)
|
||||
self._http_client = getattr(self._sync_client, "_client", None)
|
||||
self._maybe_retry_request = getattr(client_module, "_maybe_retry_request", None)
|
||||
self._raise_for_status = getattr(client_module, "_raise_for_status", None)
|
||||
self._request_handle_class = getattr(client_module, "SyncRequestHandle", None)
|
||||
self._add_hint_header = getattr(client_module, "add_hint_header", None)
|
||||
self._add_priority_header = getattr(client_module, "add_priority_header", None)
|
||||
self._add_timeout_header = getattr(client_module, "add_timeout_header", None)
|
||||
|
||||
if self._http_client is None:
|
||||
http_client: Optional[httpx.Client] = getattr(self._sync_client, "_client", None)
|
||||
maybe_retry: Optional[Callable[..., httpx.Response]] = getattr(client_module, "_maybe_retry_request", None)
|
||||
raise_for_status: Optional[Callable[[httpx.Response], None]] = getattr(client_module, "_raise_for_status", None)
|
||||
request_handle_class: Optional[Type[Any]] = getattr(client_module, "SyncRequestHandle", None)
|
||||
|
||||
if http_client is None:
|
||||
raise RuntimeError("fal_client.SyncClient._client is required for managed FAL gateway mode")
|
||||
if self._maybe_retry_request is None or self._raise_for_status is None:
|
||||
if maybe_retry is None or raise_for_status is None:
|
||||
raise RuntimeError("fal_client.client request helpers are required for managed FAL gateway mode")
|
||||
if self._request_handle_class is None:
|
||||
if request_handle_class is None:
|
||||
raise RuntimeError("fal_client.client.SyncRequestHandle is required for managed FAL gateway mode")
|
||||
|
||||
self._http_client: httpx.Client = http_client
|
||||
self._maybe_retry_request: Callable[..., httpx.Response] = maybe_retry
|
||||
self._raise_for_status: Callable[[httpx.Response], None] = raise_for_status
|
||||
self._request_handle_class: Type[Any] = request_handle_class
|
||||
self._add_hint_header: Optional[Callable[..., Any]] = getattr(client_module, "add_hint_header", None)
|
||||
self._add_priority_header: Optional[Callable[..., Any]] = getattr(client_module, "add_priority_header", None)
|
||||
self._add_timeout_header: Optional[Callable[..., Any]] = getattr(client_module, "add_timeout_header", None)
|
||||
|
||||
def submit(
|
||||
self,
|
||||
application: str,
|
||||
|
||||
+8
-4
@@ -1506,11 +1506,15 @@ def _snapshot_child_pids() -> set:
|
||||
# Fallback: psutil
|
||||
try:
|
||||
import psutil
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"psutil is required for MCP child process tracking. "
|
||||
"Install with: pip install hermes-agent[mcp]"
|
||||
) from None
|
||||
try:
|
||||
return {c.pid for c in psutil.Process(my_pid).children()}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return set()
|
||||
except psutil.Error:
|
||||
return set()
|
||||
|
||||
|
||||
def _mcp_loop_exception_handler(loop, context):
|
||||
|
||||
@@ -174,6 +174,7 @@ async def _run_reference_model_safe(
|
||||
error_msg = f"{model} failed after {max_retries} attempts: {error_str}"
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
return model, error_msg, False
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
|
||||
async def _run_aggregator_model(
|
||||
|
||||
@@ -71,12 +71,13 @@ def main():
|
||||
|
||||
ref_text = ref_text_path.read_text(encoding="utf-8").strip()
|
||||
|
||||
# Import and run NeuTTS
|
||||
try:
|
||||
from neutts import NeuTTS
|
||||
except ImportError:
|
||||
print("Error: neutts not installed. Run: python -m pip install -U neutts[all]", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
raise ImportError(
|
||||
"neutts is required for local TTS synthesis. "
|
||||
"Install with: pip install hermes-agent[tts-local]"
|
||||
) from None
|
||||
|
||||
tts = NeuTTS(
|
||||
backbone_repo=args.model,
|
||||
@@ -93,9 +94,12 @@ def main():
|
||||
|
||||
try:
|
||||
import soundfile as sf
|
||||
sf.write(str(out_path), wav, 24000)
|
||||
except ImportError:
|
||||
_write_wav(str(out_path), wav, 24000)
|
||||
raise ImportError(
|
||||
"soundfile is required for audio output. "
|
||||
"Install with: pip install hermes-agent[tts-local]"
|
||||
) from None
|
||||
sf.write(str(out_path), wav, 24000)
|
||||
|
||||
print(f"OK: {out_path}", file=sys.stderr)
|
||||
|
||||
|
||||
@@ -31,7 +31,10 @@ Usage:
|
||||
import difflib
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Any
|
||||
from typing import List, Optional, Tuple, Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tools.file_operations import PatchResult
|
||||
from enum import Enum
|
||||
|
||||
|
||||
|
||||
@@ -335,12 +335,18 @@ class ProcessRegistry:
|
||||
)
|
||||
|
||||
if use_pty:
|
||||
# Try PTY mode for interactive CLI tools
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
from winpty import PtyProcess as _PtyProcessCls
|
||||
else:
|
||||
from ptyprocess import PtyProcess as _PtyProcessCls
|
||||
except ImportError:
|
||||
pkg = "winpty" if _IS_WINDOWS else "ptyprocess"
|
||||
raise ImportError(
|
||||
f"{pkg} is required for PTY mode. "
|
||||
"Install with: pip install hermes-agent[pty]"
|
||||
) from None
|
||||
try:
|
||||
user_shell = _find_shell()
|
||||
pty_env = _sanitize_subprocess_env(os.environ, env_vars)
|
||||
pty_env["PYTHONUNBUFFERED"] = "1"
|
||||
@@ -371,8 +377,6 @@ class ProcessRegistry:
|
||||
self._write_checkpoint()
|
||||
return session
|
||||
|
||||
except ImportError:
|
||||
logger.warning("ptyprocess not installed, falling back to pipe mode")
|
||||
except Exception as e:
|
||||
logger.warning("PTY spawn failed (%s), falling back to pipe mode", e)
|
||||
|
||||
|
||||
@@ -137,6 +137,10 @@ class RunState:
|
||||
api_process: Optional[subprocess.Popen] = None
|
||||
trainer_process: Optional[subprocess.Popen] = None
|
||||
env_process: Optional[subprocess.Popen] = None
|
||||
# Log file handles (kept open while subprocess runs; closed by _stop_training_run)
|
||||
api_log_file: Optional[Any] = None
|
||||
trainer_log_file: Optional[Any] = None
|
||||
env_log_file: Optional[Any] = None
|
||||
|
||||
|
||||
# Global state
|
||||
|
||||
@@ -443,7 +443,7 @@ def session_search(
|
||||
)
|
||||
|
||||
# Summarize all sessions in parallel
|
||||
async def _summarize_all() -> List[Union[str, Exception]]:
|
||||
async def _summarize_all() -> List[Union[Optional[str], BaseException]]:
|
||||
"""Summarize all sessions with bounded concurrency."""
|
||||
max_concurrency = min(_get_session_search_max_concurrency(), max(1, len(tasks)))
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
@@ -27,7 +27,7 @@ import hashlib
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
||||
|
||||
@@ -639,7 +639,7 @@ def scan_skill(skill_path: Path, source: str = "community") -> ScanResult:
|
||||
)
|
||||
|
||||
|
||||
def should_allow_install(result: ScanResult, force: bool = False) -> Tuple[bool, str]:
|
||||
def should_allow_install(result: ScanResult, force: bool = False) -> Tuple[Optional[bool], str]:
|
||||
"""
|
||||
Determine whether a skill should be installed based on scan result and trust.
|
||||
|
||||
|
||||
@@ -410,6 +410,7 @@ def _resolve_tirith_path(configured_path: str) -> str:
|
||||
|
||||
# Fast path: successfully resolved on a previous call.
|
||||
if _resolved_path is not None and _resolved_path is not _INSTALL_FAILED:
|
||||
assert isinstance(_resolved_path, str)
|
||||
return _resolved_path
|
||||
|
||||
expanded = os.path.expanduser(configured_path)
|
||||
|
||||
@@ -31,6 +31,7 @@ import os
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
@@ -53,7 +54,7 @@ def _safe_find_spec(module_name: str) -> bool:
|
||||
try:
|
||||
return _ilu.find_spec(module_name) is not None
|
||||
except (ImportError, ValueError):
|
||||
return module_name in globals() or module_name in os.sys.modules
|
||||
return module_name in globals() or module_name in sys.modules
|
||||
|
||||
|
||||
_HAS_FASTER_WHISPER = _safe_find_spec("faster_whisper")
|
||||
|
||||
@@ -318,15 +318,14 @@ def _resize_image_for_vision(image_path: Path, mime_type: Optional[str] = None,
|
||||
else:
|
||||
data_url = None # defer full encode; try Pillow resize first
|
||||
|
||||
# Attempt auto-resize with Pillow (soft dependency)
|
||||
try:
|
||||
from PIL import Image
|
||||
import io as _io
|
||||
except ImportError:
|
||||
logger.info("Pillow not installed — cannot auto-resize oversized image")
|
||||
if data_url is None:
|
||||
data_url = _image_to_base64_data_url(image_path, mime_type=mime_type)
|
||||
return data_url # caller will raise the size error
|
||||
raise ImportError(
|
||||
"Pillow is required for image resizing. "
|
||||
"Install with: pip install hermes-agent[cli]"
|
||||
) from None
|
||||
import io as _io
|
||||
|
||||
logger.info("Image file is %.1f MB (estimated base64 %.1f MB, limit %.1f MB), auto-resizing...",
|
||||
file_size / (1024 * 1024), estimated_b64 / (1024 * 1024),
|
||||
|
||||
+6
-6
@@ -1720,8 +1720,8 @@ async def web_crawl_tool(
|
||||
metadata = {}
|
||||
|
||||
# Extract data from the item
|
||||
if hasattr(item, 'model_dump'):
|
||||
# Pydantic model - use model_dump to get dict
|
||||
from pydantic import BaseModel
|
||||
if isinstance(item, BaseModel):
|
||||
item_dict = item.model_dump()
|
||||
content_markdown = item_dict.get('markdown')
|
||||
content_html = item_dict.get('html')
|
||||
@@ -1730,15 +1730,15 @@ async def web_crawl_tool(
|
||||
# Regular object with attributes
|
||||
content_markdown = getattr(item, 'markdown', None)
|
||||
content_html = getattr(item, 'html', None)
|
||||
|
||||
|
||||
# Handle metadata - convert to dict if it's an object
|
||||
metadata_obj = getattr(item, 'metadata', {})
|
||||
if hasattr(metadata_obj, 'model_dump'):
|
||||
if isinstance(metadata_obj, BaseModel):
|
||||
metadata = metadata_obj.model_dump()
|
||||
elif hasattr(metadata_obj, '__dict__'):
|
||||
metadata = metadata_obj.__dict__
|
||||
elif isinstance(metadata_obj, dict):
|
||||
metadata = metadata_obj
|
||||
elif hasattr(metadata_obj, '__dict__'):
|
||||
metadata = metadata_obj.__dict__
|
||||
else:
|
||||
metadata = {}
|
||||
elif isinstance(item, dict):
|
||||
|
||||
@@ -19,7 +19,7 @@ Usage:
|
||||
all_dists = list_distributions()
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
import random
|
||||
from toolsets import validate_toolset
|
||||
|
||||
@@ -220,7 +220,7 @@ DISTRIBUTIONS = {
|
||||
}
|
||||
|
||||
|
||||
def get_distribution(name: str) -> Optional[Dict[str, any]]:
|
||||
def get_distribution(name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get a toolset distribution by name.
|
||||
|
||||
|
||||
+7
-4
@@ -652,7 +652,7 @@ def create_custom_toolset(
|
||||
|
||||
|
||||
|
||||
def get_toolset_info(name: str) -> Dict[str, Any]:
|
||||
def get_toolset_info(name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get detailed information about a toolset including resolved tools.
|
||||
|
||||
@@ -689,6 +689,8 @@ if __name__ == "__main__":
|
||||
print("-" * 40)
|
||||
for name, toolset in get_all_toolsets().items():
|
||||
info = get_toolset_info(name)
|
||||
if not info:
|
||||
continue
|
||||
composite = "[composite]" if info["is_composite"] else "[leaf]"
|
||||
print(f" {composite} {name:20} - {toolset['description']}")
|
||||
print(f" Tools: {len(info['resolved_tools'])} total")
|
||||
@@ -715,6 +717,7 @@ if __name__ == "__main__":
|
||||
includes=["terminal", "vision"]
|
||||
)
|
||||
custom_info = get_toolset_info("my_custom")
|
||||
print(" Created 'my_custom' toolset:")
|
||||
print(f" Description: {custom_info['description']}")
|
||||
print(f" Resolved tools: {', '.join(custom_info['resolved_tools'])}")
|
||||
if custom_info:
|
||||
print(" Created 'my_custom' toolset:")
|
||||
print(f" Description: {custom_info['description']}")
|
||||
print(f" Resolved tools: {', '.join(custom_info['resolved_tools'])}")
|
||||
|
||||
@@ -12,6 +12,7 @@ import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
@@ -33,7 +34,7 @@ except Exception:
|
||||
from tui_gateway.render import make_stream_renderer, render_diff, render_message
|
||||
|
||||
_sessions: dict[str, dict] = {}
|
||||
_methods: dict[str, callable] = {}
|
||||
_methods: dict[str, Callable[..., Any]] = {}
|
||||
_pending: dict[str, tuple[str, threading.Event]] = {}
|
||||
_answers: dict[str, str] = {}
|
||||
_db = None
|
||||
@@ -237,10 +238,16 @@ def _estimate_image_tokens(width: int, height: int) -> int:
|
||||
|
||||
|
||||
def _image_meta(path: Path) -> dict:
|
||||
meta = {"name": path.name}
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pillow is required for image metadata extraction. "
|
||||
"Install with: pip install hermes-agent[cli]"
|
||||
) from None
|
||||
|
||||
meta = {"name": path.name}
|
||||
try:
|
||||
with Image.open(path) as img:
|
||||
width, height = img.size
|
||||
meta["width"] = int(width)
|
||||
|
||||
Reference in New Issue
Block a user