Compare commits
190 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b54591ddda | |||
| 934fc9df22 | |||
| 5847c180c6 | |||
| 93a0c0cddd | |||
| 23e8fdd167 | |||
| 3268b98779 | |||
| 20f381cfb6 | |||
| 77bfa252b9 | |||
| f24c00a5bf | |||
| 463239ed85 | |||
| 60cce9ca6d | |||
| 2d57946ee9 | |||
| 5f32fd8b6d | |||
| 3ea039684e | |||
| 63f0ec96ec | |||
| 1cacaccca6 | |||
| 773f3c1137 | |||
| 0cc784068d | |||
| f1b4d0b280 | |||
| 5254d0bba1 | |||
| 21c20aeaa5 | |||
| dc095f8491 | |||
| 621fd80b1e | |||
| 2b8fd9a8e3 | |||
| fef710aca8 | |||
| 4ae1334287 | |||
| db3e3aa6c5 | |||
| 633488e0c0 | |||
| 0de200cf4d | |||
| f6fdb18fe6 | |||
| b177b4abad | |||
| 232ba441d7 | |||
| 34e120bcbb | |||
| 779f8df6a6 | |||
| 62abb453d3 | |||
| 735a6e7651 | |||
| e5ddca1c8b | |||
| 214827a594 | |||
| fd0e1aac72 | |||
| 678e0bd9cc | |||
| 8ccd14a0d4 | |||
| 6c611c852e | |||
| f882dabf19 | |||
| 973aa9b549 | |||
| 2316b8dc98 | |||
| 259208bfe4 | |||
| 47c5c97654 | |||
| df9020dfa3 | |||
| c6fb7f6463 | |||
| 672dc1666f | |||
| 5b11570517 | |||
| ff87a566c4 | |||
| 9e3752df36 | |||
| 15bf0b4af2 | |||
| 28b3764d1e | |||
| 62f1c2b622 | |||
| 71cff92eb7 | |||
| 1337c9efd8 | |||
| 747612fb3e | |||
| 84d99f7754 | |||
| 4524cddc72 | |||
| f4e8772de4 | |||
| 39fe9e8533 | |||
| d5b64ebdb3 | |||
| f8ceadbad0 | |||
| c36136084a | |||
| 4a93cfd889 | |||
| f46b35e3d1 | |||
| e6417cb7bc | |||
| 08081e5969 | |||
| 30120f05a6 | |||
| 6f85283553 | |||
| 9a177d6f4b | |||
| 6761021fb4 | |||
| 00c5e77724 | |||
| 69045711c1 | |||
| 9938d27e27 | |||
| d36b3d498d | |||
| 0c182211a1 | |||
| f4c012873c | |||
| 8ac5baf2d8 | |||
| c54db79edc | |||
| 2119b68799 | |||
| fd687d0967 | |||
| 12bc86d9c9 | |||
| 9e0f86cd3b | |||
| 883f6c81a2 | |||
| b89177668e | |||
| 9f51de7261 | |||
| a05a4afa53 | |||
| db9e512424 | |||
| 8ce66a01ee | |||
| f9a61a0d9e | |||
| ba9f82946d | |||
| 0614969f7b | |||
| f6ff6639e8 | |||
| 861869cb48 | |||
| 23bc642c82 | |||
| 9c322f7f59 | |||
| b14a07315b | |||
| 4f4e2671ac | |||
| ff3473a37c | |||
| cb7690b2b5 | |||
| 95939a1b51 | |||
| 85ef09e520 | |||
| 6b1adb7eb1 | |||
| db362dbd4c | |||
| 282df107a5 | |||
| 9f6bccd76a | |||
| 168a8e2e35 | |||
| a86b487349 | |||
| 53d1043a50 | |||
| 6c24d76533 | |||
| 30b73bdf34 | |||
| 31db8c28a4 | |||
| f549981293 | |||
| 2a6dbb25b2 | |||
| 0fd0eb93e8 | |||
| 88a48037d1 | |||
| dc11b86e4b | |||
| 26bedf973b | |||
| fc5443d854 | |||
| 799114ac8b | |||
| 70ea13eb40 | |||
| 0bc5aba5d0 | |||
| f8a3e37f54 | |||
| 3229e434b8 | |||
| 24f61d006a | |||
| c050c2d552 | |||
| 81cd367aec | |||
| e099117a3b | |||
| 2536ff328b | |||
| f3a074339d | |||
| ea053e8afd | |||
| e052c74727 | |||
| a6dc73fa07 | |||
| c3ea620796 | |||
| 7b140b31e6 | |||
| fa89b65230 | |||
| ed0c7194ed | |||
| dc44e183e6 | |||
| 79c81b2244 | |||
| df5c61b37c | |||
| b2bdaecf9b | |||
| 3fab72f1e1 | |||
| e1824ef8a6 | |||
| f3a38c90fc | |||
| a748257bf5 | |||
| 8fb618234f | |||
| 5a2fcaab39 | |||
| c207a6b302 | |||
| 7dc9281f05 | |||
| 2d18b077e1 | |||
| eb8226daab | |||
| 60710bc8f8 | |||
| 7f485f588e | |||
| f8e4233e67 | |||
| eff0d23dd9 | |||
| f10e26f731 | |||
| 1114841a2c | |||
| 5319bb6ac4 | |||
| 80a243efe6 | |||
| c1d1699a64 | |||
| 889c3e2877 | |||
| 895fe5a5d3 | |||
| 21ad98b74c | |||
| 3325e51e53 | |||
| 588d4c293c | |||
| 88951215d3 | |||
| 4422637e7a | |||
| 6d8286f396 | |||
| 94af51f621 | |||
| e5dc569daa | |||
| 14738e0872 | |||
| d2e2d6e2a2 | |||
| ee73b6bf27 | |||
| 429c44e377 | |||
| 1441525016 | |||
| 2054ffdaeb | |||
| 0d23ad7a15 | |||
| 9ec3a7a21b | |||
| 577b477a78 | |||
| fbdce27b9a | |||
| a50550fdb4 | |||
| fbd752b92b | |||
| 6d2cfc24e9 | |||
| e5186a0bad | |||
| f43c078f9e | |||
| 2eb778119d | |||
| a182d12778 |
@@ -0,0 +1,39 @@
|
||||
name: Docs Site Checks
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'website/**'
|
||||
- '.github/workflows/docs-site-checks.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
docs-site-checks:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
cache: npm
|
||||
cache-dependency-path: website/package-lock.json
|
||||
|
||||
- name: Install website dependencies
|
||||
run: npm ci
|
||||
working-directory: website
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install ascii-guard
|
||||
run: python -m pip install ascii-guard
|
||||
|
||||
- name: Lint docs diagrams
|
||||
run: npm run lint:diagrams
|
||||
working-directory: website
|
||||
|
||||
- name: Build Docusaurus
|
||||
run: npm run build
|
||||
working-directory: website
|
||||
@@ -42,19 +42,16 @@ def _setup_logging() -> None:
|
||||
|
||||
def _load_env() -> None:
|
||||
"""Load .env from HERMES_HOME (default ``~/.hermes``)."""
|
||||
from dotenv import load_dotenv
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
env_file = hermes_home / ".env"
|
||||
if env_file.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=env_file, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=env_file, encoding="latin-1")
|
||||
logging.getLogger(__name__).info("Loaded env from %s", env_file)
|
||||
loaded = load_hermes_dotenv(hermes_home=hermes_home)
|
||||
if loaded:
|
||||
for env_file in loaded:
|
||||
logging.getLogger(__name__).info("Loaded env from %s", env_file)
|
||||
else:
|
||||
logging.getLogger(__name__).info(
|
||||
"No .env found at %s, using system env", env_file
|
||||
"No .env found at %s, using system env", hermes_home / ".env"
|
||||
)
|
||||
|
||||
|
||||
|
||||
+230
-37
@@ -102,30 +102,15 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
|
||||
|
||||
def read_claude_code_credentials() -> Optional[Dict[str, Any]]:
|
||||
"""Read credentials from Claude Code's config files.
|
||||
"""Read refreshable Claude Code OAuth credentials from ~/.claude/.credentials.json.
|
||||
|
||||
Checks two locations (in order):
|
||||
1. ~/.claude.json — top-level primaryApiKey (native binary, v2.x)
|
||||
2. ~/.claude/.credentials.json — claudeAiOauth block (npm/legacy installs)
|
||||
This intentionally excludes ~/.claude.json primaryApiKey. Opencode's
|
||||
subscription flow is OAuth/setup-token based with refreshable credentials,
|
||||
and native direct Anthropic provider usage should follow that path rather
|
||||
than auto-detecting Claude's first-party managed key.
|
||||
|
||||
Returns dict with {accessToken, refreshToken?, expiresAt?} or None.
|
||||
"""
|
||||
# 1. Native binary (v2.x): ~/.claude.json with top-level primaryApiKey
|
||||
claude_json = Path.home() / ".claude.json"
|
||||
if claude_json.exists():
|
||||
try:
|
||||
data = json.loads(claude_json.read_text(encoding="utf-8"))
|
||||
primary_key = data.get("primaryApiKey", "")
|
||||
if primary_key:
|
||||
return {
|
||||
"accessToken": primary_key,
|
||||
"refreshToken": "",
|
||||
"expiresAt": 0, # Managed keys don't have a user-visible expiry
|
||||
}
|
||||
except (json.JSONDecodeError, OSError, IOError) as e:
|
||||
logger.debug("Failed to read ~/.claude.json: %s", e)
|
||||
|
||||
# 2. Legacy/npm installs: ~/.claude/.credentials.json
|
||||
cred_path = Path.home() / ".claude" / ".credentials.json"
|
||||
if cred_path.exists():
|
||||
try:
|
||||
@@ -138,6 +123,7 @@ def read_claude_code_credentials() -> Optional[Dict[str, Any]]:
|
||||
"accessToken": access_token,
|
||||
"refreshToken": oauth_data.get("refreshToken", ""),
|
||||
"expiresAt": oauth_data.get("expiresAt", 0),
|
||||
"source": "claude_code_credentials_file",
|
||||
}
|
||||
except (json.JSONDecodeError, OSError, IOError) as e:
|
||||
logger.debug("Failed to read ~/.claude/.credentials.json: %s", e)
|
||||
@@ -145,6 +131,20 @@ def read_claude_code_credentials() -> Optional[Dict[str, Any]]:
|
||||
return None
|
||||
|
||||
|
||||
def read_claude_managed_key() -> Optional[str]:
|
||||
"""Read Claude's native managed key from ~/.claude.json for diagnostics only."""
|
||||
claude_json = Path.home() / ".claude.json"
|
||||
if claude_json.exists():
|
||||
try:
|
||||
data = json.loads(claude_json.read_text(encoding="utf-8"))
|
||||
primary_key = data.get("primaryApiKey", "")
|
||||
if isinstance(primary_key, str) and primary_key.strip():
|
||||
return primary_key.strip()
|
||||
except (json.JSONDecodeError, OSError, IOError) as e:
|
||||
logger.debug("Failed to read ~/.claude.json: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def is_claude_code_token_valid(creds: Dict[str, Any]) -> bool:
|
||||
"""Check if Claude Code credentials have a non-expired access token."""
|
||||
import time
|
||||
@@ -236,6 +236,72 @@ def _write_claude_code_credentials(access_token: str, refresh_token: str, expire
|
||||
logger.debug("Failed to write refreshed credentials: %s", e)
|
||||
|
||||
|
||||
def _resolve_claude_code_token_from_credentials(creds: Optional[Dict[str, Any]] = None) -> Optional[str]:
|
||||
"""Resolve a token from Claude Code credential files, refreshing if needed."""
|
||||
creds = creds or read_claude_code_credentials()
|
||||
if creds and is_claude_code_token_valid(creds):
|
||||
logger.debug("Using Claude Code credentials (auto-detected)")
|
||||
return creds["accessToken"]
|
||||
if creds:
|
||||
logger.debug("Claude Code credentials expired — attempting refresh")
|
||||
refreshed = _refresh_oauth_token(creds)
|
||||
if refreshed:
|
||||
return refreshed
|
||||
logger.debug("Token refresh failed — re-run 'claude setup-token' to reauthenticate")
|
||||
return None
|
||||
|
||||
|
||||
def _prefer_refreshable_claude_code_token(env_token: str, creds: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
"""Prefer Claude Code creds when a persisted env OAuth token would shadow refresh.
|
||||
|
||||
Hermes historically persisted setup tokens into ANTHROPIC_TOKEN. That makes
|
||||
later refresh impossible because the static env token wins before we ever
|
||||
inspect Claude Code's refreshable credential file. If we have a refreshable
|
||||
Claude Code credential record, prefer it over the static env OAuth token.
|
||||
"""
|
||||
if not env_token or not _is_oauth_token(env_token) or not isinstance(creds, dict):
|
||||
return None
|
||||
if not creds.get("refreshToken"):
|
||||
return None
|
||||
|
||||
resolved = _resolve_claude_code_token_from_credentials(creds)
|
||||
if resolved and resolved != env_token:
|
||||
logger.debug(
|
||||
"Preferring Claude Code credential file over static env OAuth token so refresh can proceed"
|
||||
)
|
||||
return resolved
|
||||
return None
|
||||
|
||||
|
||||
def get_anthropic_token_source(token: Optional[str] = None) -> str:
|
||||
"""Best-effort source classification for an Anthropic credential token."""
|
||||
token = (token or "").strip()
|
||||
if not token:
|
||||
return "none"
|
||||
|
||||
env_token = os.getenv("ANTHROPIC_TOKEN", "").strip()
|
||||
if env_token and env_token == token:
|
||||
return "anthropic_token_env"
|
||||
|
||||
cc_env_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
|
||||
if cc_env_token and cc_env_token == token:
|
||||
return "claude_code_oauth_token_env"
|
||||
|
||||
creds = read_claude_code_credentials()
|
||||
if creds and creds.get("accessToken") == token:
|
||||
return str(creds.get("source") or "claude_code_credentials")
|
||||
|
||||
managed_key = read_claude_managed_key()
|
||||
if managed_key and managed_key == token:
|
||||
return "claude_json_primary_api_key"
|
||||
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY", "").strip()
|
||||
if api_key and api_key == token:
|
||||
return "anthropic_api_key_env"
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
def resolve_anthropic_token() -> Optional[str]:
|
||||
"""Resolve an Anthropic token from all available sources.
|
||||
|
||||
@@ -248,28 +314,28 @@ def resolve_anthropic_token() -> Optional[str]:
|
||||
|
||||
Returns the token string or None.
|
||||
"""
|
||||
creds = read_claude_code_credentials()
|
||||
|
||||
# 1. Hermes-managed OAuth/setup token env var
|
||||
token = os.getenv("ANTHROPIC_TOKEN", "").strip()
|
||||
if token:
|
||||
preferred = _prefer_refreshable_claude_code_token(token, creds)
|
||||
if preferred:
|
||||
return preferred
|
||||
return token
|
||||
|
||||
# 2. CLAUDE_CODE_OAUTH_TOKEN (used by Claude Code for setup-tokens)
|
||||
cc_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
|
||||
if cc_token:
|
||||
preferred = _prefer_refreshable_claude_code_token(cc_token, creds)
|
||||
if preferred:
|
||||
return preferred
|
||||
return cc_token
|
||||
|
||||
# 3. Claude Code credential file
|
||||
creds = read_claude_code_credentials()
|
||||
if creds and is_claude_code_token_valid(creds):
|
||||
logger.debug("Using Claude Code credentials (auto-detected)")
|
||||
return creds["accessToken"]
|
||||
elif creds:
|
||||
# Token expired — attempt to refresh
|
||||
logger.debug("Claude Code credentials expired — attempting refresh")
|
||||
refreshed = _refresh_oauth_token(creds)
|
||||
if refreshed:
|
||||
return refreshed
|
||||
logger.debug("Token refresh failed — re-run 'claude setup-token' to reauthenticate")
|
||||
resolved_claude_token = _resolve_claude_code_token_from_credentials(creds)
|
||||
if resolved_claude_token:
|
||||
return resolved_claude_token
|
||||
|
||||
# 4. Regular API key, or a legacy OAuth token saved in ANTHROPIC_API_KEY.
|
||||
# This remains as a compatibility fallback for pre-migration Hermes configs.
|
||||
@@ -354,6 +420,68 @@ def _sanitize_tool_id(tool_id: str) -> str:
|
||||
return sanitized or "tool_0"
|
||||
|
||||
|
||||
def _convert_openai_image_part_to_anthropic(part: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Convert an OpenAI-style image block to Anthropic's image source format."""
|
||||
image_data = part.get("image_url", {})
|
||||
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
|
||||
if not isinstance(url, str) or not url.strip():
|
||||
return None
|
||||
url = url.strip()
|
||||
|
||||
if url.startswith("data:"):
|
||||
header, sep, data = url.partition(",")
|
||||
if sep and ";base64" in header:
|
||||
media_type = header[5:].split(";", 1)[0] or "image/png"
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
|
||||
if url.startswith("http://") or url.startswith("https://"):
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": url,
|
||||
},
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _convert_user_content_part_to_anthropic(part: Any) -> Optional[Dict[str, Any]]:
|
||||
if isinstance(part, dict):
|
||||
ptype = part.get("type")
|
||||
if ptype == "text":
|
||||
block = {"type": "text", "text": part.get("text", "")}
|
||||
if isinstance(part.get("cache_control"), dict):
|
||||
block["cache_control"] = dict(part["cache_control"])
|
||||
return block
|
||||
if ptype == "image_url":
|
||||
return _convert_openai_image_part_to_anthropic(part)
|
||||
if ptype == "image" and part.get("source"):
|
||||
return dict(part)
|
||||
if ptype == "image" and part.get("data"):
|
||||
media_type = part.get("mimeType") or part.get("media_type") or "image/png"
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": part.get("data", ""),
|
||||
},
|
||||
}
|
||||
if ptype == "tool_result":
|
||||
return dict(part)
|
||||
elif part is not None:
|
||||
return {"type": "text", "text": str(part)}
|
||||
return None
|
||||
|
||||
|
||||
def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
|
||||
"""Convert OpenAI tool definitions to Anthropic format."""
|
||||
if not tools:
|
||||
@@ -369,6 +497,66 @@ def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
|
||||
return result
|
||||
|
||||
|
||||
def _image_source_from_openai_url(url: str) -> Dict[str, str]:
|
||||
"""Convert an OpenAI-style image URL/data URL into Anthropic image source."""
|
||||
url = str(url or "").strip()
|
||||
if not url:
|
||||
return {"type": "url", "url": ""}
|
||||
|
||||
if url.startswith("data:"):
|
||||
header, _, data = url.partition(",")
|
||||
media_type = "image/jpeg"
|
||||
if header.startswith("data:"):
|
||||
mime_part = header[len("data:"):].split(";", 1)[0].strip()
|
||||
if mime_part.startswith("image/"):
|
||||
media_type = mime_part
|
||||
return {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
}
|
||||
|
||||
return {"type": "url", "url": url}
|
||||
|
||||
|
||||
def _convert_content_part_to_anthropic(part: Any) -> Optional[Dict[str, Any]]:
|
||||
"""Convert a single OpenAI-style content part to Anthropic format."""
|
||||
if part is None:
|
||||
return None
|
||||
if isinstance(part, str):
|
||||
return {"type": "text", "text": part}
|
||||
if not isinstance(part, dict):
|
||||
return {"type": "text", "text": str(part)}
|
||||
|
||||
ptype = part.get("type")
|
||||
|
||||
if ptype == "input_text":
|
||||
block: Dict[str, Any] = {"type": "text", "text": part.get("text", "")}
|
||||
elif ptype in {"image_url", "input_image"}:
|
||||
image_value = part.get("image_url", {})
|
||||
url = image_value.get("url", "") if isinstance(image_value, dict) else str(image_value or "")
|
||||
block = {"type": "image", "source": _image_source_from_openai_url(url)}
|
||||
else:
|
||||
block = dict(part)
|
||||
|
||||
if isinstance(part.get("cache_control"), dict) and "cache_control" not in block:
|
||||
block["cache_control"] = dict(part["cache_control"])
|
||||
return block
|
||||
|
||||
|
||||
def _convert_content_to_anthropic(content: Any) -> Any:
|
||||
"""Convert OpenAI-style multimodal content arrays to Anthropic blocks."""
|
||||
if not isinstance(content, list):
|
||||
return content
|
||||
|
||||
converted = []
|
||||
for part in content:
|
||||
block = _convert_content_part_to_anthropic(part)
|
||||
if block is not None:
|
||||
converted.append(block)
|
||||
return converted
|
||||
|
||||
|
||||
def convert_messages_to_anthropic(
|
||||
messages: List[Dict],
|
||||
) -> Tuple[Optional[Any], List[Dict]]:
|
||||
@@ -405,11 +593,9 @@ def convert_messages_to_anthropic(
|
||||
blocks = []
|
||||
if content:
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
blocks.append(dict(part))
|
||||
elif part is not None:
|
||||
blocks.append({"type": "text", "text": str(part)})
|
||||
converted_content = _convert_content_to_anthropic(content)
|
||||
if isinstance(converted_content, list):
|
||||
blocks.extend(converted_content)
|
||||
else:
|
||||
blocks.append({"type": "text", "text": str(content)})
|
||||
for tc in m.get("tool_calls", []):
|
||||
@@ -458,7 +644,14 @@ def convert_messages_to_anthropic(
|
||||
continue
|
||||
|
||||
# Regular user message
|
||||
result.append({"role": "user", "content": content})
|
||||
if isinstance(content, list):
|
||||
converted_blocks = _convert_content_to_anthropic(content)
|
||||
result.append({
|
||||
"role": "user",
|
||||
"content": converted_blocks or [{"type": "text", "text": ""}],
|
||||
})
|
||||
else:
|
||||
result.append({"role": "user", "content": content})
|
||||
|
||||
# Strip orphaned tool_use blocks (no matching tool_result follows)
|
||||
tool_result_ids = set()
|
||||
|
||||
+535
-126
@@ -1,4 +1,4 @@
|
||||
"""Shared auxiliary OpenAI client for cheap/fast side tasks.
|
||||
"""Shared auxiliary client router for side tasks.
|
||||
|
||||
Provides a single resolution chain so every consumer (context compression,
|
||||
session search, web extraction, vision analysis, browser vision) picks up
|
||||
@@ -10,26 +10,30 @@ Resolution order for text tasks (auto mode):
|
||||
3. Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY)
|
||||
4. Codex OAuth (Responses API via chatgpt.com with gpt-5.3-codex,
|
||||
wrapped to look like a chat.completions client)
|
||||
5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, MiniMax-CN)
|
||||
— checked via PROVIDER_REGISTRY entries with auth_type='api_key'
|
||||
6. None
|
||||
5. Native Anthropic
|
||||
6. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, MiniMax-CN)
|
||||
7. None
|
||||
|
||||
Resolution order for vision/multimodal tasks (auto mode):
|
||||
1. OpenRouter
|
||||
2. Nous Portal
|
||||
3. Codex OAuth (gpt-5.3-codex supports vision via Responses API)
|
||||
4. Custom endpoint (for local vision models: Qwen-VL, LLaVA, Pixtral, etc.)
|
||||
5. None (API-key providers like z.ai/Kimi/MiniMax are skipped —
|
||||
they may not support multimodal)
|
||||
1. Selected main provider, if it is one of the supported vision backends below
|
||||
2. OpenRouter
|
||||
3. Nous Portal
|
||||
4. Codex OAuth (gpt-5.3-codex supports vision via Responses API)
|
||||
5. Native Anthropic
|
||||
6. Custom endpoint (for local vision models: Qwen-VL, LLaVA, Pixtral, etc.)
|
||||
7. None
|
||||
|
||||
Per-task provider overrides (e.g. AUXILIARY_VISION_PROVIDER,
|
||||
CONTEXT_COMPRESSION_PROVIDER) can force a specific provider for each task:
|
||||
"openrouter", "nous", "codex", or "main" (= steps 3-5).
|
||||
CONTEXT_COMPRESSION_PROVIDER) can force a specific provider for each task.
|
||||
Default "auto" follows the chains above.
|
||||
|
||||
Per-task model overrides (e.g. AUXILIARY_VISION_MODEL,
|
||||
AUXILIARY_WEB_EXTRACT_MODEL) let callers use a different model slug
|
||||
than the provider's default.
|
||||
|
||||
Per-task direct endpoint overrides (e.g. AUXILIARY_VISION_BASE_URL,
|
||||
AUXILIARY_VISION_API_KEY) let callers route a specific auxiliary task to a
|
||||
custom OpenAI-compatible endpoint without touching the main model settings.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -74,11 +78,15 @@ auxiliary_is_nous: bool = False
|
||||
_OPENROUTER_MODEL = "google/gemini-3-flash-preview"
|
||||
_NOUS_MODEL = "gemini-3-flash"
|
||||
_NOUS_DEFAULT_BASE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
_ANTHROPIC_DEFAULT_BASE_URL = "https://api.anthropic.com"
|
||||
_AUTH_JSON_PATH = get_hermes_home() / "auth.json"
|
||||
|
||||
# Codex fallback: uses the Responses API (the only endpoint the Codex
|
||||
# OAuth token can access) with a fast model for auxiliary tasks.
|
||||
_CODEX_AUX_MODEL = "gpt-5.3-codex"
|
||||
# ChatGPT-backed Codex accounts currently reject gpt-5.3-codex for these
|
||||
# auxiliary flows, while gpt-5.2-codex remains broadly available and supports
|
||||
# vision via Responses.
|
||||
_CODEX_AUX_MODEL = "gpt-5.2-codex"
|
||||
_CODEX_AUX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
@@ -309,6 +317,114 @@ class AsyncCodexAuxiliaryClient:
|
||||
self.base_url = sync_wrapper.base_url
|
||||
|
||||
|
||||
class _AnthropicCompletionsAdapter:
|
||||
"""OpenAI-client-compatible adapter for Anthropic Messages API."""
|
||||
|
||||
def __init__(self, real_client: Any, model: str):
|
||||
self._client = real_client
|
||||
self._model = model
|
||||
|
||||
def create(self, **kwargs) -> Any:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs, normalize_anthropic_response
|
||||
|
||||
messages = kwargs.get("messages", [])
|
||||
model = kwargs.get("model", self._model)
|
||||
tools = kwargs.get("tools")
|
||||
tool_choice = kwargs.get("tool_choice")
|
||||
max_tokens = kwargs.get("max_tokens") or kwargs.get("max_completion_tokens") or 2000
|
||||
temperature = kwargs.get("temperature")
|
||||
|
||||
normalized_tool_choice = None
|
||||
if isinstance(tool_choice, str):
|
||||
normalized_tool_choice = tool_choice
|
||||
elif isinstance(tool_choice, dict):
|
||||
choice_type = str(tool_choice.get("type", "")).lower()
|
||||
if choice_type == "function":
|
||||
normalized_tool_choice = tool_choice.get("function", {}).get("name")
|
||||
elif choice_type in {"auto", "required", "none"}:
|
||||
normalized_tool_choice = choice_type
|
||||
|
||||
anthropic_kwargs = build_anthropic_kwargs(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
reasoning_config=None,
|
||||
tool_choice=normalized_tool_choice,
|
||||
)
|
||||
if temperature is not None:
|
||||
anthropic_kwargs["temperature"] = temperature
|
||||
|
||||
response = self._client.messages.create(**anthropic_kwargs)
|
||||
assistant_message, finish_reason = normalize_anthropic_response(response)
|
||||
|
||||
usage = None
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
prompt_tokens = getattr(response.usage, "input_tokens", 0) or 0
|
||||
completion_tokens = getattr(response.usage, "output_tokens", 0) or 0
|
||||
total_tokens = getattr(response.usage, "total_tokens", 0) or (prompt_tokens + completion_tokens)
|
||||
usage = SimpleNamespace(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
choice = SimpleNamespace(
|
||||
index=0,
|
||||
message=assistant_message,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
return SimpleNamespace(
|
||||
choices=[choice],
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
class _AnthropicChatShim:
|
||||
def __init__(self, adapter: _AnthropicCompletionsAdapter):
|
||||
self.completions = adapter
|
||||
|
||||
|
||||
class AnthropicAuxiliaryClient:
|
||||
"""OpenAI-client-compatible wrapper over a native Anthropic client."""
|
||||
|
||||
def __init__(self, real_client: Any, model: str, api_key: str, base_url: str):
|
||||
self._real_client = real_client
|
||||
adapter = _AnthropicCompletionsAdapter(real_client, model)
|
||||
self.chat = _AnthropicChatShim(adapter)
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
def close(self):
|
||||
close_fn = getattr(self._real_client, "close", None)
|
||||
if callable(close_fn):
|
||||
close_fn()
|
||||
|
||||
|
||||
class _AsyncAnthropicCompletionsAdapter:
|
||||
def __init__(self, sync_adapter: _AnthropicCompletionsAdapter):
|
||||
self._sync = sync_adapter
|
||||
|
||||
async def create(self, **kwargs) -> Any:
|
||||
import asyncio
|
||||
return await asyncio.to_thread(self._sync.create, **kwargs)
|
||||
|
||||
|
||||
class _AsyncAnthropicChatShim:
|
||||
def __init__(self, adapter: _AsyncAnthropicCompletionsAdapter):
|
||||
self.completions = adapter
|
||||
|
||||
|
||||
class AsyncAnthropicAuxiliaryClient:
|
||||
def __init__(self, sync_wrapper: "AnthropicAuxiliaryClient"):
|
||||
sync_adapter = sync_wrapper.chat.completions
|
||||
async_adapter = _AsyncAnthropicCompletionsAdapter(sync_adapter)
|
||||
self.chat = _AsyncAnthropicChatShim(async_adapter)
|
||||
self.api_key = sync_wrapper.api_key
|
||||
self.base_url = sync_wrapper.base_url
|
||||
|
||||
|
||||
def _read_nous_auth() -> Optional[dict]:
|
||||
"""Read and validate ~/.hermes/auth.json for an active Nous provider.
|
||||
|
||||
@@ -380,6 +496,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
break
|
||||
if not api_key:
|
||||
continue
|
||||
if provider_id == "anthropic":
|
||||
return _try_anthropic()
|
||||
|
||||
# Resolve base URL (with optional env-var override)
|
||||
# Kimi Code keys (sk-kimi-) need api.kimi.com/coding/v1
|
||||
env_url = ""
|
||||
@@ -418,6 +537,17 @@ def _get_auxiliary_provider(task: str = "") -> str:
|
||||
return "auto"
|
||||
|
||||
|
||||
def _get_auxiliary_env_override(task: str, suffix: str) -> Optional[str]:
|
||||
"""Read an auxiliary env override from AUXILIARY_* or CONTEXT_* prefixes."""
|
||||
if not task:
|
||||
return None
|
||||
for prefix in ("AUXILIARY_", "CONTEXT_"):
|
||||
val = os.getenv(f"{prefix}{task.upper()}_{suffix}", "").strip()
|
||||
if val:
|
||||
return val
|
||||
return None
|
||||
|
||||
|
||||
def _try_openrouter() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if not or_key:
|
||||
@@ -465,9 +595,44 @@ def _read_main_model() -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Resolve the active custom/main endpoint the same way the main CLI does.
|
||||
|
||||
This covers both env-driven OPENAI_BASE_URL setups and config-saved custom
|
||||
endpoints where the base URL lives in config.yaml instead of the live
|
||||
environment.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
|
||||
runtime = resolve_runtime_provider(requested="custom")
|
||||
except Exception as exc:
|
||||
logger.debug("Auxiliary client: custom runtime resolution failed: %s", exc)
|
||||
return None, None
|
||||
|
||||
custom_base = runtime.get("base_url")
|
||||
custom_key = runtime.get("api_key")
|
||||
if not isinstance(custom_base, str) or not custom_base.strip():
|
||||
return None, None
|
||||
if not isinstance(custom_key, str) or not custom_key.strip():
|
||||
return None, None
|
||||
|
||||
custom_base = custom_base.strip().rstrip("/")
|
||||
if "openrouter.ai" in custom_base.lower():
|
||||
# requested='custom' falls back to OpenRouter when no custom endpoint is
|
||||
# configured. Treat that as "no custom endpoint" for auxiliary routing.
|
||||
return None, None
|
||||
|
||||
return custom_base, custom_key.strip()
|
||||
|
||||
|
||||
def _current_custom_base_url() -> str:
|
||||
custom_base, _ = _resolve_custom_runtime()
|
||||
return custom_base or ""
|
||||
|
||||
|
||||
def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
custom_base = os.getenv("OPENAI_BASE_URL")
|
||||
custom_key = os.getenv("OPENAI_API_KEY")
|
||||
custom_base, custom_key = _resolve_custom_runtime()
|
||||
if not custom_base or not custom_key:
|
||||
return None, None
|
||||
model = _read_main_model() or "gpt-4o-mini"
|
||||
@@ -484,6 +649,22 @@ def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
|
||||
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
|
||||
|
||||
|
||||
def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
try:
|
||||
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
||||
except ImportError:
|
||||
return None, None
|
||||
|
||||
token = resolve_anthropic_token()
|
||||
if not token:
|
||||
return None, None
|
||||
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get("anthropic", "claude-haiku-4-5-20251001")
|
||||
logger.debug("Auxiliary client: Anthropic native (%s)", model)
|
||||
real_client = build_anthropic_client(token, _ANTHROPIC_DEFAULT_BASE_URL)
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, _ANTHROPIC_DEFAULT_BASE_URL), model
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
|
||||
if forced == "openrouter":
|
||||
@@ -546,6 +727,8 @@ def _to_async_client(sync_client, model: str):
|
||||
|
||||
if isinstance(sync_client, CodexAuxiliaryClient):
|
||||
return AsyncCodexAuxiliaryClient(sync_client), model
|
||||
if isinstance(sync_client, AnthropicAuxiliaryClient):
|
||||
return AsyncAnthropicAuxiliaryClient(sync_client), model
|
||||
|
||||
async_kwargs = {
|
||||
"api_key": sync_client.api_key,
|
||||
@@ -564,6 +747,8 @@ def resolve_provider_client(
|
||||
model: str = None,
|
||||
async_mode: bool = False,
|
||||
raw_codex: bool = False,
|
||||
explicit_base_url: str = None,
|
||||
explicit_api_key: str = None,
|
||||
) -> Tuple[Optional[Any], Optional[str]]:
|
||||
"""Central router: given a provider name and optional model, return a
|
||||
configured client with the correct auth, base URL, and API format.
|
||||
@@ -585,6 +770,8 @@ def resolve_provider_client(
|
||||
instead of wrapping in CodexAuxiliaryClient. Use this when
|
||||
the caller needs direct access to responses.stream() (e.g.,
|
||||
the main agent loop).
|
||||
explicit_base_url: Optional direct OpenAI-compatible endpoint.
|
||||
explicit_api_key: Optional API key paired with explicit_base_url.
|
||||
|
||||
Returns:
|
||||
(client, resolved_model) or (None, None) if auth is unavailable.
|
||||
@@ -661,6 +848,22 @@ def resolve_provider_client(
|
||||
|
||||
# ── Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY) ───────────
|
||||
if provider == "custom":
|
||||
if explicit_base_url:
|
||||
custom_base = explicit_base_url.strip()
|
||||
custom_key = (
|
||||
(explicit_api_key or "").strip()
|
||||
or os.getenv("OPENAI_API_KEY", "").strip()
|
||||
)
|
||||
if not custom_base or not custom_key:
|
||||
logger.warning(
|
||||
"resolve_provider_client: explicit custom endpoint requested "
|
||||
"but no API key was found (set explicit_api_key or OPENAI_API_KEY)"
|
||||
)
|
||||
return None, None
|
||||
final_model = model or _read_main_model() or "gpt-4o-mini"
|
||||
client = OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
return (_to_async_client(client, final_model) if async_mode
|
||||
else (client, final_model))
|
||||
# Try custom first, then codex, then API-key providers
|
||||
for try_fn in (_try_custom_endpoint, _try_codex,
|
||||
_resolve_api_key_provider):
|
||||
@@ -686,6 +889,14 @@ def resolve_provider_client(
|
||||
return None, None
|
||||
|
||||
if pconfig.auth_type == "api_key":
|
||||
if provider == "anthropic":
|
||||
client, default_model = _try_anthropic()
|
||||
if client is None:
|
||||
logger.warning("resolve_provider_client: anthropic requested but no Anthropic credentials found")
|
||||
return None, None
|
||||
final_model = model or default_model
|
||||
return (_to_async_client(client, final_model) if async_mode else (client, final_model))
|
||||
|
||||
# Find the first configured API key
|
||||
api_key = ""
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
@@ -749,10 +960,13 @@ def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optiona
|
||||
Callers may override the returned model with a per-task env var
|
||||
(e.g. CONTEXT_COMPRESSION_MODEL, AUXILIARY_WEB_EXTRACT_MODEL).
|
||||
"""
|
||||
forced = _get_auxiliary_provider(task)
|
||||
if forced != "auto":
|
||||
return resolve_provider_client(forced)
|
||||
return resolve_provider_client("auto")
|
||||
provider, model, base_url, api_key = _resolve_task_provider_model(task or None)
|
||||
return resolve_provider_client(
|
||||
provider,
|
||||
model=model,
|
||||
explicit_base_url=base_url,
|
||||
explicit_api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def get_async_text_auxiliary_client(task: str = ""):
|
||||
@@ -762,54 +976,154 @@ def get_async_text_auxiliary_client(task: str = ""):
|
||||
(AsyncCodexAuxiliaryClient, model) which wraps the Responses API.
|
||||
Returns (None, None) when no provider is available.
|
||||
"""
|
||||
forced = _get_auxiliary_provider(task)
|
||||
if forced != "auto":
|
||||
return resolve_provider_client(forced, async_mode=True)
|
||||
return resolve_provider_client("auto", async_mode=True)
|
||||
provider, model, base_url, api_key = _resolve_task_provider_model(task or None)
|
||||
return resolve_provider_client(
|
||||
provider,
|
||||
model=model,
|
||||
async_mode=True,
|
||||
explicit_base_url=base_url,
|
||||
explicit_api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks.
|
||||
_VISION_AUTO_PROVIDER_ORDER = (
|
||||
"openrouter",
|
||||
"nous",
|
||||
"openai-codex",
|
||||
"anthropic",
|
||||
"custom",
|
||||
)
|
||||
|
||||
Checks AUXILIARY_VISION_PROVIDER for a forced provider, otherwise
|
||||
auto-detects. Callers may override the returned model with
|
||||
AUXILIARY_VISION_MODEL.
|
||||
|
||||
In auto mode, only providers known to support multimodal are tried:
|
||||
OpenRouter, Nous Portal, and Codex OAuth (gpt-5.3-codex supports
|
||||
vision via the Responses API). Custom endpoints and API-key
|
||||
providers are skipped — they may not handle vision input. To use
|
||||
them, set AUXILIARY_VISION_PROVIDER explicitly.
|
||||
"""
|
||||
forced = _get_auxiliary_provider("vision")
|
||||
if forced != "auto":
|
||||
return resolve_provider_client(forced)
|
||||
# Auto: try providers known to support multimodal first, then fall
|
||||
# back to the user's custom endpoint. Many local models (Qwen-VL,
|
||||
# LLaVA, Pixtral, etc.) support vision — skipping them entirely
|
||||
# caused silent failures for local-only users.
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_codex,
|
||||
_try_custom_endpoint):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.debug("Auxiliary vision client: none available")
|
||||
def _normalize_vision_provider(provider: Optional[str]) -> str:
|
||||
provider = (provider or "auto").strip().lower()
|
||||
if provider == "codex":
|
||||
return "openai-codex"
|
||||
if provider == "main":
|
||||
return "custom"
|
||||
return provider
|
||||
|
||||
|
||||
def _resolve_strict_vision_backend(provider: str) -> Tuple[Optional[Any], Optional[str]]:
|
||||
provider = _normalize_vision_provider(provider)
|
||||
if provider == "openrouter":
|
||||
return _try_openrouter()
|
||||
if provider == "nous":
|
||||
return _try_nous()
|
||||
if provider == "openai-codex":
|
||||
return _try_codex()
|
||||
if provider == "anthropic":
|
||||
return _try_anthropic()
|
||||
if provider == "custom":
|
||||
return _try_custom_endpoint()
|
||||
return None, None
|
||||
|
||||
|
||||
def get_async_vision_auxiliary_client():
|
||||
"""Return (async_client, model_slug) for async vision consumers.
|
||||
def _strict_vision_backend_available(provider: str) -> bool:
|
||||
return _resolve_strict_vision_backend(provider)[0] is not None
|
||||
|
||||
Properly handles Codex routing — unlike manually constructing
|
||||
AsyncOpenAI from a sync client, this preserves the Responses API
|
||||
adapter for Codex providers.
|
||||
|
||||
Returns (None, None) when no provider is available.
|
||||
def _preferred_main_vision_provider() -> Optional[str]:
|
||||
"""Return the selected main provider when it is also a supported vision backend."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
model_cfg = config.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
provider = _normalize_vision_provider(model_cfg.get("provider", ""))
|
||||
if provider in _VISION_AUTO_PROVIDER_ORDER:
|
||||
return provider
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_available_vision_backends() -> List[str]:
|
||||
"""Return the currently available vision backends in auto-selection order.
|
||||
|
||||
This is the single source of truth for setup, tool gating, and runtime
|
||||
auto-routing of vision tasks. The selected main provider is preferred when
|
||||
it is also a known-good vision backend; otherwise Hermes falls back through
|
||||
the standard conservative order.
|
||||
"""
|
||||
sync_client, model = get_vision_auxiliary_client()
|
||||
if sync_client is None:
|
||||
return None, None
|
||||
return _to_async_client(sync_client, model)
|
||||
ordered = list(_VISION_AUTO_PROVIDER_ORDER)
|
||||
preferred = _preferred_main_vision_provider()
|
||||
if preferred in ordered:
|
||||
ordered.remove(preferred)
|
||||
ordered.insert(0, preferred)
|
||||
return [provider for provider in ordered if _strict_vision_backend_available(provider)]
|
||||
|
||||
|
||||
def resolve_vision_provider_client(
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
*,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
async_mode: bool = False,
|
||||
) -> Tuple[Optional[str], Optional[Any], Optional[str]]:
|
||||
"""Resolve the client actually used for vision tasks.
|
||||
|
||||
Direct endpoint overrides take precedence over provider selection. Explicit
|
||||
provider overrides still use the generic provider router for non-standard
|
||||
backends, so users can intentionally force experimental providers. Auto mode
|
||||
stays conservative and only tries vision backends known to work today.
|
||||
"""
|
||||
requested, resolved_model, resolved_base_url, resolved_api_key = _resolve_task_provider_model(
|
||||
"vision", provider, model, base_url, api_key
|
||||
)
|
||||
requested = _normalize_vision_provider(requested)
|
||||
|
||||
def _finalize(resolved_provider: str, sync_client: Any, default_model: Optional[str]):
|
||||
if sync_client is None:
|
||||
return resolved_provider, None, None
|
||||
final_model = resolved_model or default_model
|
||||
if async_mode:
|
||||
async_client, async_model = _to_async_client(sync_client, final_model)
|
||||
return resolved_provider, async_client, async_model
|
||||
return resolved_provider, sync_client, final_model
|
||||
|
||||
if resolved_base_url:
|
||||
client, final_model = resolve_provider_client(
|
||||
"custom",
|
||||
model=resolved_model,
|
||||
async_mode=async_mode,
|
||||
explicit_base_url=resolved_base_url,
|
||||
explicit_api_key=resolved_api_key,
|
||||
)
|
||||
if client is None:
|
||||
return "custom", None, None
|
||||
return "custom", client, final_model
|
||||
|
||||
if requested == "auto":
|
||||
for candidate in get_available_vision_backends():
|
||||
sync_client, default_model = _resolve_strict_vision_backend(candidate)
|
||||
if sync_client is not None:
|
||||
return _finalize(candidate, sync_client, default_model)
|
||||
logger.debug("Auxiliary vision client: none available")
|
||||
return None, None, None
|
||||
|
||||
if requested in _VISION_AUTO_PROVIDER_ORDER:
|
||||
sync_client, default_model = _resolve_strict_vision_backend(requested)
|
||||
return _finalize(requested, sync_client, default_model)
|
||||
|
||||
client, final_model = _get_cached_client(requested, resolved_model, async_mode)
|
||||
if client is None:
|
||||
return requested, None, None
|
||||
return requested, client, final_model
|
||||
|
||||
|
||||
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks."""
|
||||
_, client, final_model = resolve_vision_provider_client(async_mode=False)
|
||||
return client, final_model
|
||||
|
||||
|
||||
def get_async_vision_auxiliary_client():
|
||||
"""Return (async_client, model_slug) for async vision consumers."""
|
||||
_, client, final_model = resolve_vision_provider_client(async_mode=True)
|
||||
return client, final_model
|
||||
|
||||
|
||||
def get_auxiliary_extra_body() -> dict:
|
||||
@@ -829,7 +1143,7 @@ def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
The Codex adapter translates max_tokens internally, so we use max_tokens
|
||||
for it as well.
|
||||
"""
|
||||
custom_base = os.getenv("OPENAI_BASE_URL", "")
|
||||
custom_base = _current_custom_base_url()
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
# Only use max_completion_tokens for direct OpenAI custom endpoints
|
||||
if (not or_key
|
||||
@@ -851,19 +1165,29 @@ def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
# Every auxiliary LLM consumer should use these instead of manually
|
||||
# constructing clients and calling .chat.completions.create().
|
||||
|
||||
# Client cache: (provider, async_mode) -> (client, default_model)
|
||||
# Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model)
|
||||
_client_cache: Dict[tuple, tuple] = {}
|
||||
|
||||
|
||||
def _get_cached_client(
|
||||
provider: str, model: str = None, async_mode: bool = False,
|
||||
provider: str,
|
||||
model: str = None,
|
||||
async_mode: bool = False,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> Tuple[Optional[Any], Optional[str]]:
|
||||
"""Get or create a cached client for the given provider."""
|
||||
cache_key = (provider, async_mode)
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "")
|
||||
if cache_key in _client_cache:
|
||||
cached_client, cached_default = _client_cache[cache_key]
|
||||
return cached_client, model or cached_default
|
||||
client, default_model = resolve_provider_client(provider, model, async_mode)
|
||||
client, default_model = resolve_provider_client(
|
||||
provider,
|
||||
model,
|
||||
async_mode,
|
||||
explicit_base_url=base_url,
|
||||
explicit_api_key=api_key,
|
||||
)
|
||||
if client is not None:
|
||||
_client_cache[cache_key] = (client, default_model)
|
||||
return client, model or default_model
|
||||
@@ -873,57 +1197,75 @@ def _resolve_task_provider_model(
|
||||
task: str = None,
|
||||
provider: str = None,
|
||||
model: str = None,
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> Tuple[str, Optional[str], Optional[str], Optional[str]]:
|
||||
"""Determine provider + model for a call.
|
||||
|
||||
Priority:
|
||||
1. Explicit provider/model args (always win)
|
||||
2. Env var overrides (AUXILIARY_{TASK}_PROVIDER, etc.)
|
||||
3. Config file (auxiliary.{task}.provider/model or compression.*)
|
||||
1. Explicit provider/model/base_url/api_key args (always win)
|
||||
2. Env var overrides (AUXILIARY_{TASK}_*, CONTEXT_{TASK}_*)
|
||||
3. Config file (auxiliary.{task}.* or compression.*)
|
||||
4. "auto" (full auto-detection chain)
|
||||
|
||||
Returns (provider, model) where model may be None (use provider default).
|
||||
Returns (provider, model, base_url, api_key) where model may be None
|
||||
(use provider default). When base_url is set, provider is forced to
|
||||
"custom" and the task uses that direct endpoint.
|
||||
"""
|
||||
if provider:
|
||||
return provider, model
|
||||
config = {}
|
||||
cfg_provider = None
|
||||
cfg_model = None
|
||||
cfg_base_url = None
|
||||
cfg_api_key = None
|
||||
|
||||
if task:
|
||||
# Check env var overrides first
|
||||
env_provider = _get_auxiliary_provider(task)
|
||||
if env_provider != "auto":
|
||||
# Check for env var model override too
|
||||
env_model = None
|
||||
for prefix in ("AUXILIARY_", "CONTEXT_"):
|
||||
val = os.getenv(f"{prefix}{task.upper()}_MODEL", "").strip()
|
||||
if val:
|
||||
env_model = val
|
||||
break
|
||||
return env_provider, model or env_model
|
||||
|
||||
# Read from config file
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
except ImportError:
|
||||
return "auto", model
|
||||
config = {}
|
||||
|
||||
# Check auxiliary.{task} section
|
||||
aux = config.get("auxiliary", {})
|
||||
task_config = aux.get(task, {})
|
||||
cfg_provider = task_config.get("provider", "").strip() or None
|
||||
cfg_model = task_config.get("model", "").strip() or None
|
||||
aux = config.get("auxiliary", {}) if isinstance(config, dict) else {}
|
||||
task_config = aux.get(task, {}) if isinstance(aux, dict) else {}
|
||||
if not isinstance(task_config, dict):
|
||||
task_config = {}
|
||||
cfg_provider = str(task_config.get("provider", "")).strip() or None
|
||||
cfg_model = str(task_config.get("model", "")).strip() or None
|
||||
cfg_base_url = str(task_config.get("base_url", "")).strip() or None
|
||||
cfg_api_key = str(task_config.get("api_key", "")).strip() or None
|
||||
|
||||
# Backwards compat: compression section has its own keys
|
||||
if task == "compression" and not cfg_provider:
|
||||
comp = config.get("compression", {})
|
||||
cfg_provider = comp.get("summary_provider", "").strip() or None
|
||||
cfg_model = cfg_model or comp.get("summary_model", "").strip() or None
|
||||
comp = config.get("compression", {}) if isinstance(config, dict) else {}
|
||||
if isinstance(comp, dict):
|
||||
cfg_provider = comp.get("summary_provider", "").strip() or None
|
||||
cfg_model = cfg_model or comp.get("summary_model", "").strip() or None
|
||||
|
||||
env_model = _get_auxiliary_env_override(task, "MODEL") if task else None
|
||||
resolved_model = model or env_model or cfg_model
|
||||
|
||||
if base_url:
|
||||
return "custom", resolved_model, base_url, api_key
|
||||
if provider:
|
||||
return provider, resolved_model, base_url, api_key
|
||||
|
||||
if task:
|
||||
env_base_url = _get_auxiliary_env_override(task, "BASE_URL")
|
||||
env_api_key = _get_auxiliary_env_override(task, "API_KEY")
|
||||
if env_base_url:
|
||||
return "custom", resolved_model, env_base_url, env_api_key or cfg_api_key
|
||||
|
||||
env_provider = _get_auxiliary_provider(task)
|
||||
if env_provider != "auto":
|
||||
return env_provider, resolved_model, None, None
|
||||
|
||||
if cfg_base_url:
|
||||
return "custom", resolved_model, cfg_base_url, cfg_api_key
|
||||
if cfg_provider and cfg_provider != "auto":
|
||||
return cfg_provider, model or cfg_model
|
||||
return "auto", model or cfg_model
|
||||
return cfg_provider, resolved_model, None, None
|
||||
return "auto", resolved_model, None, None
|
||||
|
||||
return "auto", model
|
||||
return "auto", resolved_model, None, None
|
||||
|
||||
|
||||
def _build_call_kwargs(
|
||||
@@ -935,6 +1277,7 @@ def _build_call_kwargs(
|
||||
tools: Optional[list] = None,
|
||||
timeout: float = 30.0,
|
||||
extra_body: Optional[dict] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Build kwargs for .chat.completions.create() with model/provider adjustments."""
|
||||
kwargs: Dict[str, Any] = {
|
||||
@@ -950,7 +1293,7 @@ def _build_call_kwargs(
|
||||
# Codex adapter handles max_tokens internally; OpenRouter/Nous use max_tokens.
|
||||
# Direct OpenAI api.openai.com with newer models needs max_completion_tokens.
|
||||
if provider == "custom":
|
||||
custom_base = os.getenv("OPENAI_BASE_URL", "")
|
||||
custom_base = base_url or _current_custom_base_url()
|
||||
if "api.openai.com" in custom_base.lower():
|
||||
kwargs["max_completion_tokens"] = max_tokens
|
||||
else:
|
||||
@@ -976,6 +1319,8 @@ def call_llm(
|
||||
*,
|
||||
provider: str = None,
|
||||
model: str = None,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
messages: list,
|
||||
temperature: float = None,
|
||||
max_tokens: int = None,
|
||||
@@ -1007,26 +1352,57 @@ def call_llm(
|
||||
Raises:
|
||||
RuntimeError: If no provider is configured.
|
||||
"""
|
||||
resolved_provider, resolved_model = _resolve_task_provider_model(
|
||||
task, provider, model)
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key = _resolve_task_provider_model(
|
||||
task, provider, model, base_url, api_key)
|
||||
|
||||
client, final_model = _get_cached_client(resolved_provider, resolved_model)
|
||||
if client is None:
|
||||
# Fallback: try openrouter
|
||||
if resolved_provider != "openrouter":
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
"openrouter", resolved_model or _OPENROUTER_MODEL)
|
||||
if client is None:
|
||||
raise RuntimeError(
|
||||
f"No LLM provider configured for task={task} provider={resolved_provider}. "
|
||||
f"Run: hermes setup")
|
||||
if task == "vision":
|
||||
effective_provider, client, final_model = resolve_vision_provider_client(
|
||||
provider=provider,
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
async_mode=False,
|
||||
)
|
||||
if client is None and resolved_provider != "auto" and not resolved_base_url:
|
||||
logger.warning(
|
||||
"Vision provider %s unavailable, falling back to auto vision backends",
|
||||
resolved_provider,
|
||||
)
|
||||
effective_provider, client, final_model = resolve_vision_provider_client(
|
||||
provider="auto",
|
||||
model=resolved_model,
|
||||
async_mode=False,
|
||||
)
|
||||
if client is None:
|
||||
raise RuntimeError(
|
||||
f"No LLM provider configured for task={task} provider={resolved_provider}. "
|
||||
f"Run: hermes setup"
|
||||
)
|
||||
resolved_provider = effective_provider or resolved_provider
|
||||
else:
|
||||
client, final_model = _get_cached_client(
|
||||
resolved_provider,
|
||||
resolved_model,
|
||||
base_url=resolved_base_url,
|
||||
api_key=resolved_api_key,
|
||||
)
|
||||
if client is None:
|
||||
# Fallback: try openrouter
|
||||
if resolved_provider != "openrouter" and not resolved_base_url:
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
"openrouter", resolved_model or _OPENROUTER_MODEL)
|
||||
if client is None:
|
||||
raise RuntimeError(
|
||||
f"No LLM provider configured for task={task} provider={resolved_provider}. "
|
||||
f"Run: hermes setup")
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
resolved_provider, final_model, messages,
|
||||
temperature=temperature, max_tokens=max_tokens,
|
||||
tools=tools, timeout=timeout, extra_body=extra_body)
|
||||
tools=tools, timeout=timeout, extra_body=extra_body,
|
||||
base_url=resolved_base_url)
|
||||
|
||||
# Handle max_tokens vs max_completion_tokens retry
|
||||
try:
|
||||
@@ -1045,6 +1421,8 @@ async def async_call_llm(
|
||||
*,
|
||||
provider: str = None,
|
||||
model: str = None,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
messages: list,
|
||||
temperature: float = None,
|
||||
max_tokens: int = None,
|
||||
@@ -1056,27 +1434,58 @@ async def async_call_llm(
|
||||
|
||||
Same as call_llm() but async. See call_llm() for full documentation.
|
||||
"""
|
||||
resolved_provider, resolved_model = _resolve_task_provider_model(
|
||||
task, provider, model)
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key = _resolve_task_provider_model(
|
||||
task, provider, model, base_url, api_key)
|
||||
|
||||
client, final_model = _get_cached_client(
|
||||
resolved_provider, resolved_model, async_mode=True)
|
||||
if client is None:
|
||||
if resolved_provider != "openrouter":
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
"openrouter", resolved_model or _OPENROUTER_MODEL,
|
||||
async_mode=True)
|
||||
if client is None:
|
||||
raise RuntimeError(
|
||||
f"No LLM provider configured for task={task} provider={resolved_provider}. "
|
||||
f"Run: hermes setup")
|
||||
if task == "vision":
|
||||
effective_provider, client, final_model = resolve_vision_provider_client(
|
||||
provider=provider,
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
async_mode=True,
|
||||
)
|
||||
if client is None and resolved_provider != "auto" and not resolved_base_url:
|
||||
logger.warning(
|
||||
"Vision provider %s unavailable, falling back to auto vision backends",
|
||||
resolved_provider,
|
||||
)
|
||||
effective_provider, client, final_model = resolve_vision_provider_client(
|
||||
provider="auto",
|
||||
model=resolved_model,
|
||||
async_mode=True,
|
||||
)
|
||||
if client is None:
|
||||
raise RuntimeError(
|
||||
f"No LLM provider configured for task={task} provider={resolved_provider}. "
|
||||
f"Run: hermes setup"
|
||||
)
|
||||
resolved_provider = effective_provider or resolved_provider
|
||||
else:
|
||||
client, final_model = _get_cached_client(
|
||||
resolved_provider,
|
||||
resolved_model,
|
||||
async_mode=True,
|
||||
base_url=resolved_base_url,
|
||||
api_key=resolved_api_key,
|
||||
)
|
||||
if client is None:
|
||||
if resolved_provider != "openrouter" and not resolved_base_url:
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
"openrouter", resolved_model or _OPENROUTER_MODEL,
|
||||
async_mode=True)
|
||||
if client is None:
|
||||
raise RuntimeError(
|
||||
f"No LLM provider configured for task={task} provider={resolved_provider}. "
|
||||
f"Run: hermes setup")
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
resolved_provider, final_model, messages,
|
||||
temperature=temperature, max_tokens=max_tokens,
|
||||
tools=tools, timeout=timeout, extra_body=extra_body)
|
||||
tools=tools, timeout=timeout, extra_body=extra_body,
|
||||
base_url=resolved_base_url)
|
||||
|
||||
try:
|
||||
return await client.chat.completions.create(**kwargs)
|
||||
|
||||
+10
-7
@@ -80,7 +80,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str | N
|
||||
"image_generate": "prompt", "text_to_speech": "text",
|
||||
"vision_analyze": "question", "mixture_of_agents": "user_prompt",
|
||||
"skill_view": "name", "skills_list": "category",
|
||||
"schedule_cronjob": "name",
|
||||
"cronjob": "action",
|
||||
"execute_code": "code", "delegate_task": "goal",
|
||||
"clarify": "question", "skill_manage": "name",
|
||||
}
|
||||
@@ -513,12 +513,15 @@ def get_cute_tool_message(
|
||||
return _wrap(f"┊ 🧠 reason {_trunc(args.get('user_prompt', ''), 30)} {dur}")
|
||||
if tool_name == "send_message":
|
||||
return _wrap(f"┊ 📨 send {args.get('target', '?')}: \"{_trunc(args.get('message', ''), 25)}\" {dur}")
|
||||
if tool_name == "schedule_cronjob":
|
||||
return _wrap(f"┊ ⏰ schedule {_trunc(args.get('name', args.get('prompt', 'task')), 30)} {dur}")
|
||||
if tool_name == "list_cronjobs":
|
||||
return _wrap(f"┊ ⏰ jobs listing {dur}")
|
||||
if tool_name == "remove_cronjob":
|
||||
return _wrap(f"┊ ⏰ remove job {args.get('job_id', '?')} {dur}")
|
||||
if tool_name == "cronjob":
|
||||
action = args.get("action", "?")
|
||||
if action == "create":
|
||||
skills = args.get("skills") or ([] if not args.get("skill") else [args.get("skill")])
|
||||
label = args.get("name") or (skills[0] if skills else None) or args.get("prompt", "task")
|
||||
return _wrap(f"┊ ⏰ cron create {_trunc(label, 24)} {dur}")
|
||||
if action == "list":
|
||||
return _wrap(f"┊ ⏰ cron listing {dur}")
|
||||
return _wrap(f"┊ ⏰ cron {action} {args.get('job_id', '')} {dur}")
|
||||
if tool_name.startswith("rl_"):
|
||||
rl = {
|
||||
"rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}",
|
||||
|
||||
+14
-5
@@ -71,15 +71,17 @@ DEFAULT_AGENT_IDENTITY = (
|
||||
)
|
||||
|
||||
MEMORY_GUIDANCE = (
|
||||
"You have persistent memory across sessions. Proactively save important things "
|
||||
"you learn (user preferences, environment details, useful approaches) and do "
|
||||
"(like a diary!) using the memory tool -- don't wait to be asked."
|
||||
"You have persistent memory across sessions. Save durable facts using the memory "
|
||||
"tool: user preferences, environment details, tool quirks, and stable conventions. "
|
||||
"Memory is injected into every turn, so keep it compact. Do NOT save task progress, "
|
||||
"session outcomes, or completed-work logs to memory; use session_search to recall "
|
||||
"those from past transcripts."
|
||||
)
|
||||
|
||||
SESSION_SEARCH_GUIDANCE = (
|
||||
"When the user references something from a past conversation or you suspect "
|
||||
"relevant prior context exists, use session_search to recall it before asking "
|
||||
"them to repeat themselves."
|
||||
"relevant cross-session context exists, use session_search to recall it before "
|
||||
"asking them to repeat themselves."
|
||||
)
|
||||
|
||||
SKILLS_GUIDANCE = (
|
||||
@@ -139,6 +141,13 @@ PLATFORM_HINTS = {
|
||||
"is preserved for threading. Do not include greetings or sign-offs unless "
|
||||
"contextually appropriate."
|
||||
),
|
||||
"cron": (
|
||||
"You are running as a scheduled cron job. Your final response is automatically "
|
||||
"delivered to the job's configured destination, so do not use send_message to "
|
||||
"send to that same target again. If you want the user to receive something in "
|
||||
"the scheduled destination, put it directly in your final response. Use "
|
||||
"send_message only for additional or different targets."
|
||||
),
|
||||
"cli": (
|
||||
"You are a CLI AI Agent. Try not to use markdown but simple text "
|
||||
"renderable inside a terminal."
|
||||
|
||||
+186
-67
@@ -1,17 +1,151 @@
|
||||
"""Skill slash commands — scan installed skills and build invocation messages.
|
||||
"""Shared slash command helpers for skills and built-in prompt-style modes.
|
||||
|
||||
Shared between CLI (cli.py) and gateway (gateway/run.py) so both surfaces
|
||||
can invoke skills via /skill-name commands.
|
||||
can invoke skills via /skill-name commands and prompt-only built-ins like
|
||||
/plan.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_skill_commands: Dict[str, Dict[str, Any]] = {}
|
||||
_PLAN_SLUG_RE = re.compile(r"[^a-z0-9]+")
|
||||
|
||||
|
||||
def build_plan_path(
|
||||
user_instruction: str = "",
|
||||
*,
|
||||
now: datetime | None = None,
|
||||
) -> Path:
|
||||
"""Return the default workspace-relative markdown path for a /plan invocation.
|
||||
|
||||
Relative paths are intentional: file tools are task/backend-aware and resolve
|
||||
them against the active working directory for local, docker, ssh, modal,
|
||||
daytona, and similar terminal backends. That keeps the plan with the active
|
||||
workspace instead of the Hermes host's global home directory.
|
||||
"""
|
||||
slug_source = (user_instruction or "").strip().splitlines()[0] if user_instruction else ""
|
||||
slug = _PLAN_SLUG_RE.sub("-", slug_source.lower()).strip("-")
|
||||
if slug:
|
||||
slug = "-".join(part for part in slug.split("-")[:8] if part)[:48].strip("-")
|
||||
slug = slug or "conversation-plan"
|
||||
timestamp = (now or datetime.now()).strftime("%Y-%m-%d_%H%M%S")
|
||||
return Path(".hermes") / "plans" / f"{timestamp}-{slug}.md"
|
||||
|
||||
|
||||
def _load_skill_payload(skill_identifier: str, task_id: str | None = None) -> tuple[dict[str, Any], Path | None, str] | None:
|
||||
"""Load a skill by name/path and return (loaded_payload, skill_dir, display_name)."""
|
||||
raw_identifier = (skill_identifier or "").strip()
|
||||
if not raw_identifier:
|
||||
return None
|
||||
|
||||
try:
|
||||
from tools.skills_tool import SKILLS_DIR, skill_view
|
||||
|
||||
identifier_path = Path(raw_identifier).expanduser()
|
||||
if identifier_path.is_absolute():
|
||||
try:
|
||||
normalized = str(identifier_path.resolve().relative_to(SKILLS_DIR.resolve()))
|
||||
except Exception:
|
||||
normalized = raw_identifier
|
||||
else:
|
||||
normalized = raw_identifier.lstrip("/")
|
||||
|
||||
loaded_skill = json.loads(skill_view(normalized, task_id=task_id))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not loaded_skill.get("success"):
|
||||
return None
|
||||
|
||||
skill_name = str(loaded_skill.get("name") or normalized)
|
||||
skill_path = str(loaded_skill.get("path") or "")
|
||||
skill_dir = None
|
||||
if skill_path:
|
||||
try:
|
||||
skill_dir = SKILLS_DIR / Path(skill_path).parent
|
||||
except Exception:
|
||||
skill_dir = None
|
||||
|
||||
return loaded_skill, skill_dir, skill_name
|
||||
|
||||
|
||||
def _build_skill_message(
|
||||
loaded_skill: dict[str, Any],
|
||||
skill_dir: Path | None,
|
||||
activation_note: str,
|
||||
user_instruction: str = "",
|
||||
runtime_note: str = "",
|
||||
) -> str:
|
||||
"""Format a loaded skill into a user/system message payload."""
|
||||
from tools.skills_tool import SKILLS_DIR
|
||||
|
||||
content = str(loaded_skill.get("content") or "")
|
||||
|
||||
parts = [activation_note, "", content.strip()]
|
||||
|
||||
if loaded_skill.get("setup_skipped"):
|
||||
parts.extend(
|
||||
[
|
||||
"",
|
||||
"[Skill setup note: Required environment setup was skipped. Continue loading the skill and explain any reduced functionality if it matters.]",
|
||||
]
|
||||
)
|
||||
elif loaded_skill.get("gateway_setup_hint"):
|
||||
parts.extend(
|
||||
[
|
||||
"",
|
||||
f"[Skill setup note: {loaded_skill['gateway_setup_hint']}]",
|
||||
]
|
||||
)
|
||||
elif loaded_skill.get("setup_needed") and loaded_skill.get("setup_note"):
|
||||
parts.extend(
|
||||
[
|
||||
"",
|
||||
f"[Skill setup note: {loaded_skill['setup_note']}]",
|
||||
]
|
||||
)
|
||||
|
||||
supporting = []
|
||||
linked_files = loaded_skill.get("linked_files") or {}
|
||||
for entries in linked_files.values():
|
||||
if isinstance(entries, list):
|
||||
supporting.extend(entries)
|
||||
|
||||
if not supporting and skill_dir:
|
||||
for subdir in ("references", "templates", "scripts", "assets"):
|
||||
subdir_path = skill_dir / subdir
|
||||
if subdir_path.exists():
|
||||
for f in sorted(subdir_path.rglob("*")):
|
||||
if f.is_file():
|
||||
rel = str(f.relative_to(skill_dir))
|
||||
supporting.append(rel)
|
||||
|
||||
if supporting and skill_dir:
|
||||
skill_view_target = str(skill_dir.relative_to(SKILLS_DIR))
|
||||
parts.append("")
|
||||
parts.append("[This skill has supporting files you can load with the skill_view tool:]")
|
||||
for sf in supporting:
|
||||
parts.append(f"- {sf}")
|
||||
parts.append(
|
||||
f'\nTo view any of these, use: skill_view(name="{skill_view_target}", file_path="<path>")'
|
||||
)
|
||||
|
||||
if user_instruction:
|
||||
parts.append("")
|
||||
parts.append(f"The user has provided the following instruction alongside the skill invocation: {user_instruction}")
|
||||
|
||||
if runtime_note:
|
||||
parts.append("")
|
||||
parts.append(f"[Runtime note: {runtime_note}]")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
@@ -68,6 +202,7 @@ def build_skill_invocation_message(
|
||||
cmd_key: str,
|
||||
user_instruction: str = "",
|
||||
task_id: str | None = None,
|
||||
runtime_note: str = "",
|
||||
) -> Optional[str]:
|
||||
"""Build the user message content for a skill slash command invocation.
|
||||
|
||||
@@ -83,77 +218,61 @@ def build_skill_invocation_message(
|
||||
if not skill_info:
|
||||
return None
|
||||
|
||||
skill_name = skill_info["name"]
|
||||
skill_path = skill_info["skill_dir"]
|
||||
loaded = _load_skill_payload(skill_info["skill_dir"], task_id=task_id)
|
||||
if not loaded:
|
||||
return f"[Failed to load skill: {skill_info['name']}]"
|
||||
|
||||
try:
|
||||
from tools.skills_tool import SKILLS_DIR, skill_view
|
||||
loaded_skill, skill_dir, skill_name = loaded
|
||||
activation_note = (
|
||||
f'[SYSTEM: The user has invoked the "{skill_name}" skill, indicating they want '
|
||||
"you to follow its instructions. The full skill content is loaded below.]"
|
||||
)
|
||||
return _build_skill_message(
|
||||
loaded_skill,
|
||||
skill_dir,
|
||||
activation_note,
|
||||
user_instruction=user_instruction,
|
||||
runtime_note=runtime_note,
|
||||
)
|
||||
|
||||
loaded_skill = json.loads(skill_view(skill_path, task_id=task_id))
|
||||
except Exception:
|
||||
return f"[Failed to load skill: {skill_name}]"
|
||||
|
||||
if not loaded_skill.get("success"):
|
||||
return f"[Failed to load skill: {skill_name}]"
|
||||
def build_preloaded_skills_prompt(
|
||||
skill_identifiers: list[str],
|
||||
task_id: str | None = None,
|
||||
) -> tuple[str, list[str], list[str]]:
|
||||
"""Load one or more skills for session-wide CLI preloading.
|
||||
|
||||
content = str(loaded_skill.get("content") or "")
|
||||
skill_dir = Path(skill_info["skill_dir"])
|
||||
Returns (prompt_text, loaded_skill_names, missing_identifiers).
|
||||
"""
|
||||
prompt_parts: list[str] = []
|
||||
loaded_names: list[str] = []
|
||||
missing: list[str] = []
|
||||
|
||||
parts = [
|
||||
f'[SYSTEM: The user has invoked the "{skill_name}" skill, indicating they want you to follow its instructions. The full skill content is loaded below.]',
|
||||
"",
|
||||
content.strip(),
|
||||
]
|
||||
seen: set[str] = set()
|
||||
for raw_identifier in skill_identifiers:
|
||||
identifier = (raw_identifier or "").strip()
|
||||
if not identifier or identifier in seen:
|
||||
continue
|
||||
seen.add(identifier)
|
||||
|
||||
if loaded_skill.get("setup_skipped"):
|
||||
parts.extend(
|
||||
[
|
||||
"",
|
||||
"[Skill setup note: Required environment setup was skipped. Continue loading the skill and explain any reduced functionality if it matters.]",
|
||||
]
|
||||
loaded = _load_skill_payload(identifier, task_id=task_id)
|
||||
if not loaded:
|
||||
missing.append(identifier)
|
||||
continue
|
||||
|
||||
loaded_skill, skill_dir, skill_name = loaded
|
||||
activation_note = (
|
||||
f'[SYSTEM: The user launched this CLI session with the "{skill_name}" skill '
|
||||
"preloaded. Treat its instructions as active guidance for the duration of this "
|
||||
"session unless the user overrides them.]"
|
||||
)
|
||||
elif loaded_skill.get("gateway_setup_hint"):
|
||||
parts.extend(
|
||||
[
|
||||
"",
|
||||
f"[Skill setup note: {loaded_skill['gateway_setup_hint']}]",
|
||||
]
|
||||
)
|
||||
elif loaded_skill.get("setup_needed") and loaded_skill.get("setup_note"):
|
||||
parts.extend(
|
||||
[
|
||||
"",
|
||||
f"[Skill setup note: {loaded_skill['setup_note']}]",
|
||||
]
|
||||
prompt_parts.append(
|
||||
_build_skill_message(
|
||||
loaded_skill,
|
||||
skill_dir,
|
||||
activation_note,
|
||||
)
|
||||
)
|
||||
loaded_names.append(skill_name)
|
||||
|
||||
supporting = []
|
||||
linked_files = loaded_skill.get("linked_files") or {}
|
||||
for entries in linked_files.values():
|
||||
if isinstance(entries, list):
|
||||
supporting.extend(entries)
|
||||
|
||||
if not supporting:
|
||||
for subdir in ("references", "templates", "scripts", "assets"):
|
||||
subdir_path = skill_dir / subdir
|
||||
if subdir_path.exists():
|
||||
for f in sorted(subdir_path.rglob("*")):
|
||||
if f.is_file():
|
||||
rel = str(f.relative_to(skill_dir))
|
||||
supporting.append(rel)
|
||||
|
||||
if supporting:
|
||||
skill_view_target = str(Path(skill_path).relative_to(SKILLS_DIR))
|
||||
parts.append("")
|
||||
parts.append("[This skill has supporting files you can load with the skill_view tool:]")
|
||||
for sf in supporting:
|
||||
parts.append(f"- {sf}")
|
||||
parts.append(
|
||||
f'\nTo view any of these, use: skill_view(name="{skill_view_target}", file_path="<path>")'
|
||||
)
|
||||
|
||||
if user_instruction:
|
||||
parts.append("")
|
||||
parts.append(f"The user has provided the following instruction alongside the skill invocation: {user_instruction}")
|
||||
|
||||
return "\n".join(parts)
|
||||
return "\n\n".join(prompt_parts), loaded_names, missing
|
||||
|
||||
@@ -107,6 +107,12 @@ terminal:
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# docker_image: "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
# # Optional: explicitly forward selected env vars into Docker.
|
||||
# # These values come from your current shell first, then ~/.hermes/.env.
|
||||
# # Warning: anything forwarded here is visible to commands run in the container.
|
||||
# docker_forward_env:
|
||||
# - "GITHUB_TOKEN"
|
||||
# - "NPM_TOKEN"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Singularity/Apptainer container
|
||||
@@ -456,7 +462,7 @@ platform_toolsets:
|
||||
# moa - mixture_of_agents (requires OPENROUTER_API_KEY)
|
||||
# todo - todo (in-memory task planning, no deps)
|
||||
# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI key)
|
||||
# cronjob - schedule_cronjob, list_cronjobs, remove_cronjob
|
||||
# cronjob - cronjob (create/list/update/pause/resume/run/remove scheduled tasks)
|
||||
# rl - rl_list_environments, rl_start_training, etc. (requires TINKER_API_KEY)
|
||||
#
|
||||
# PRESETS (curated bundles):
|
||||
|
||||
+8
-1
@@ -7,7 +7,8 @@ This module provides scheduled task execution, allowing the agent to:
|
||||
- Execute tasks in isolated sessions (no prior context)
|
||||
|
||||
Cron jobs are executed automatically by the gateway daemon:
|
||||
hermes gateway install # Install as system service (recommended)
|
||||
hermes gateway install # Install as a user service
|
||||
sudo hermes gateway install --system # Linux servers: boot-time system service
|
||||
hermes gateway # Or run in foreground
|
||||
|
||||
The gateway ticks the scheduler every 60 seconds. A file lock prevents
|
||||
@@ -20,6 +21,9 @@ from cron.jobs import (
|
||||
list_jobs,
|
||||
remove_job,
|
||||
update_job,
|
||||
pause_job,
|
||||
resume_job,
|
||||
trigger_job,
|
||||
JOBS_FILE,
|
||||
)
|
||||
from cron.scheduler import tick
|
||||
@@ -30,6 +34,9 @@ __all__ = [
|
||||
"list_jobs",
|
||||
"remove_job",
|
||||
"update_job",
|
||||
"pause_job",
|
||||
"resume_job",
|
||||
"trigger_job",
|
||||
"tick",
|
||||
"JOBS_FILE",
|
||||
]
|
||||
|
||||
+152
-25
@@ -32,6 +32,32 @@ JOBS_FILE = CRON_DIR / "jobs.json"
|
||||
OUTPUT_DIR = CRON_DIR / "output"
|
||||
|
||||
|
||||
def _normalize_skill_list(skill: Optional[str] = None, skills: Optional[Any] = None) -> List[str]:
|
||||
"""Normalize legacy/single-skill and multi-skill inputs into a unique ordered list."""
|
||||
if skills is None:
|
||||
raw_items = [skill] if skill else []
|
||||
elif isinstance(skills, str):
|
||||
raw_items = [skills]
|
||||
else:
|
||||
raw_items = list(skills)
|
||||
|
||||
normalized: List[str] = []
|
||||
for item in raw_items:
|
||||
text = str(item or "").strip()
|
||||
if text and text not in normalized:
|
||||
normalized.append(text)
|
||||
return normalized
|
||||
|
||||
|
||||
def _apply_skill_fields(job: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return a job dict with canonical `skills` and legacy `skill` fields aligned."""
|
||||
normalized = dict(job)
|
||||
skills = _normalize_skill_list(normalized.get("skill"), normalized.get("skills"))
|
||||
normalized["skills"] = skills
|
||||
normalized["skill"] = skills[0] if skills else None
|
||||
return normalized
|
||||
|
||||
|
||||
def _secure_dir(path: Path):
|
||||
"""Set directory to owner-only access (0700). No-op on Windows."""
|
||||
try:
|
||||
@@ -263,39 +289,63 @@ def create_job(
|
||||
name: Optional[str] = None,
|
||||
repeat: Optional[int] = None,
|
||||
deliver: Optional[str] = None,
|
||||
origin: Optional[Dict[str, Any]] = None
|
||||
origin: Optional[Dict[str, Any]] = None,
|
||||
skill: Optional[str] = None,
|
||||
skills: Optional[List[str]] = None,
|
||||
model: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a new cron job.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: The prompt to run (must be self-contained)
|
||||
prompt: The prompt to run (must be self-contained, or a task instruction when skill is set)
|
||||
schedule: Schedule string (see parse_schedule)
|
||||
name: Optional friendly name
|
||||
repeat: How many times to run (None = forever, 1 = once)
|
||||
deliver: Where to deliver output ("origin", "local", "telegram", etc.)
|
||||
origin: Source info where job was created (for "origin" delivery)
|
||||
|
||||
skill: Optional legacy single skill name to load before running the prompt
|
||||
skills: Optional ordered list of skills to load before running the prompt
|
||||
model: Optional per-job model override
|
||||
provider: Optional per-job provider override
|
||||
base_url: Optional per-job base URL override
|
||||
|
||||
Returns:
|
||||
The created job dict
|
||||
"""
|
||||
parsed_schedule = parse_schedule(schedule)
|
||||
|
||||
|
||||
# Auto-set repeat=1 for one-shot schedules if not specified
|
||||
if parsed_schedule["kind"] == "once" and repeat is None:
|
||||
repeat = 1
|
||||
|
||||
|
||||
# Default delivery to origin if available, otherwise local
|
||||
if deliver is None:
|
||||
deliver = "origin" if origin else "local"
|
||||
|
||||
|
||||
job_id = uuid.uuid4().hex[:12]
|
||||
now = _hermes_now().isoformat()
|
||||
|
||||
|
||||
normalized_skills = _normalize_skill_list(skill, skills)
|
||||
normalized_model = str(model).strip() if isinstance(model, str) else None
|
||||
normalized_provider = str(provider).strip() if isinstance(provider, str) else None
|
||||
normalized_base_url = str(base_url).strip().rstrip("/") if isinstance(base_url, str) else None
|
||||
normalized_model = normalized_model or None
|
||||
normalized_provider = normalized_provider or None
|
||||
normalized_base_url = normalized_base_url or None
|
||||
|
||||
label_source = (prompt or (normalized_skills[0] if normalized_skills else None)) or "cron job"
|
||||
job = {
|
||||
"id": job_id,
|
||||
"name": name or prompt[:50].strip(),
|
||||
"name": name or label_source[:50].strip(),
|
||||
"prompt": prompt,
|
||||
"skills": normalized_skills,
|
||||
"skill": normalized_skills[0] if normalized_skills else None,
|
||||
"model": normalized_model,
|
||||
"provider": normalized_provider,
|
||||
"base_url": normalized_base_url,
|
||||
"schedule": parsed_schedule,
|
||||
"schedule_display": parsed_schedule.get("display", schedule),
|
||||
"repeat": {
|
||||
@@ -303,6 +353,9 @@ def create_job(
|
||||
"completed": 0
|
||||
},
|
||||
"enabled": True,
|
||||
"state": "scheduled",
|
||||
"paused_at": None,
|
||||
"paused_reason": None,
|
||||
"created_at": now,
|
||||
"next_run_at": compute_next_run(parsed_schedule),
|
||||
"last_run_at": None,
|
||||
@@ -312,11 +365,11 @@ def create_job(
|
||||
"deliver": deliver,
|
||||
"origin": origin, # Tracks where job was created for "origin" delivery
|
||||
}
|
||||
|
||||
|
||||
jobs = load_jobs()
|
||||
jobs.append(job)
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
return job
|
||||
|
||||
|
||||
@@ -325,29 +378,100 @@ def get_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
if job["id"] == job_id:
|
||||
return job
|
||||
return _apply_skill_fields(job)
|
||||
return None
|
||||
|
||||
|
||||
def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]:
|
||||
"""List all jobs, optionally including disabled ones."""
|
||||
jobs = load_jobs()
|
||||
jobs = [_apply_skill_fields(j) for j in load_jobs()]
|
||||
if not include_disabled:
|
||||
jobs = [j for j in jobs if j.get("enabled", True)]
|
||||
return jobs
|
||||
|
||||
|
||||
def update_job(job_id: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Update a job by ID."""
|
||||
"""Update a job by ID, refreshing derived schedule fields when needed."""
|
||||
jobs = load_jobs()
|
||||
for i, job in enumerate(jobs):
|
||||
if job["id"] == job_id:
|
||||
jobs[i] = {**job, **updates}
|
||||
save_jobs(jobs)
|
||||
return jobs[i]
|
||||
if job["id"] != job_id:
|
||||
continue
|
||||
|
||||
updated = _apply_skill_fields({**job, **updates})
|
||||
schedule_changed = "schedule" in updates
|
||||
|
||||
if "skills" in updates or "skill" in updates:
|
||||
normalized_skills = _normalize_skill_list(updated.get("skill"), updated.get("skills"))
|
||||
updated["skills"] = normalized_skills
|
||||
updated["skill"] = normalized_skills[0] if normalized_skills else None
|
||||
|
||||
if schedule_changed:
|
||||
updated_schedule = updated["schedule"]
|
||||
updated["schedule_display"] = updates.get(
|
||||
"schedule_display",
|
||||
updated_schedule.get("display", updated.get("schedule_display")),
|
||||
)
|
||||
if updated.get("state") != "paused":
|
||||
updated["next_run_at"] = compute_next_run(updated_schedule)
|
||||
|
||||
if updated.get("enabled", True) and updated.get("state") != "paused" and not updated.get("next_run_at"):
|
||||
updated["next_run_at"] = compute_next_run(updated["schedule"])
|
||||
|
||||
jobs[i] = updated
|
||||
save_jobs(jobs)
|
||||
return _apply_skill_fields(jobs[i])
|
||||
return None
|
||||
|
||||
|
||||
def pause_job(job_id: str, reason: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Pause a job without deleting it."""
|
||||
return update_job(
|
||||
job_id,
|
||||
{
|
||||
"enabled": False,
|
||||
"state": "paused",
|
||||
"paused_at": _hermes_now().isoformat(),
|
||||
"paused_reason": reason,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resume_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Resume a paused job and compute the next future run from now."""
|
||||
job = get_job(job_id)
|
||||
if not job:
|
||||
return None
|
||||
|
||||
next_run_at = compute_next_run(job["schedule"])
|
||||
return update_job(
|
||||
job_id,
|
||||
{
|
||||
"enabled": True,
|
||||
"state": "scheduled",
|
||||
"paused_at": None,
|
||||
"paused_reason": None,
|
||||
"next_run_at": next_run_at,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def trigger_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Schedule a job to run on the next scheduler tick."""
|
||||
job = get_job(job_id)
|
||||
if not job:
|
||||
return None
|
||||
return update_job(
|
||||
job_id,
|
||||
{
|
||||
"enabled": True,
|
||||
"state": "scheduled",
|
||||
"paused_at": None,
|
||||
"paused_reason": None,
|
||||
"next_run_at": _hermes_now().isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def remove_job(job_id: str) -> bool:
|
||||
"""Remove a job by ID."""
|
||||
jobs = load_jobs()
|
||||
@@ -389,11 +513,14 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
|
||||
# Compute next run
|
||||
job["next_run_at"] = compute_next_run(job["schedule"], now)
|
||||
|
||||
|
||||
# If no next run (one-shot completed), disable
|
||||
if job["next_run_at"] is None:
|
||||
job["enabled"] = False
|
||||
|
||||
job["state"] = "completed"
|
||||
elif job.get("state") != "paused":
|
||||
job["state"] = "scheduled"
|
||||
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
@@ -403,21 +530,21 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now."""
|
||||
now = _hermes_now()
|
||||
jobs = load_jobs()
|
||||
jobs = [_apply_skill_fields(j) for j in load_jobs()]
|
||||
due = []
|
||||
|
||||
|
||||
for job in jobs:
|
||||
if not job.get("enabled", True):
|
||||
continue
|
||||
|
||||
|
||||
next_run = job.get("next_run_at")
|
||||
if not next_run:
|
||||
continue
|
||||
|
||||
|
||||
next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
|
||||
if next_run_dt <= now:
|
||||
due.append(job)
|
||||
|
||||
|
||||
return due
|
||||
|
||||
|
||||
|
||||
+123
-39
@@ -9,6 +9,7 @@ runs at a time if multiple processes overlap.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -56,6 +57,50 @@ def _resolve_origin(job: dict) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
"""Resolve the concrete auto-delivery target for a cron job, if any."""
|
||||
deliver = job.get("deliver", "local")
|
||||
origin = _resolve_origin(job)
|
||||
|
||||
if deliver == "local":
|
||||
return None
|
||||
|
||||
if deliver == "origin":
|
||||
if not origin:
|
||||
return None
|
||||
return {
|
||||
"platform": origin["platform"],
|
||||
"chat_id": str(origin["chat_id"]),
|
||||
"thread_id": origin.get("thread_id"),
|
||||
}
|
||||
|
||||
if ":" in deliver:
|
||||
platform_name, chat_id = deliver.split(":", 1)
|
||||
return {
|
||||
"platform": platform_name,
|
||||
"chat_id": chat_id,
|
||||
"thread_id": None,
|
||||
}
|
||||
|
||||
platform_name = deliver
|
||||
if origin and origin.get("platform") == platform_name:
|
||||
return {
|
||||
"platform": platform_name,
|
||||
"chat_id": str(origin["chat_id"]),
|
||||
"thread_id": origin.get("thread_id"),
|
||||
}
|
||||
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
if not chat_id:
|
||||
return None
|
||||
|
||||
return {
|
||||
"platform": platform_name,
|
||||
"chat_id": chat_id,
|
||||
"thread_id": None,
|
||||
}
|
||||
|
||||
|
||||
def _deliver_result(job: dict, content: str) -> None:
|
||||
"""
|
||||
Deliver job output to the configured target (origin chat, specific platform, etc.).
|
||||
@@ -63,36 +108,19 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
Uses the standalone platform send functions from send_message_tool so delivery
|
||||
works whether or not the gateway is running.
|
||||
"""
|
||||
deliver = job.get("deliver", "local")
|
||||
origin = _resolve_origin(job)
|
||||
|
||||
if deliver == "local":
|
||||
target = _resolve_delivery_target(job)
|
||||
if not target:
|
||||
if job.get("deliver", "local") != "local":
|
||||
logger.warning(
|
||||
"Job '%s' deliver=%s but no concrete delivery target could be resolved",
|
||||
job["id"],
|
||||
job.get("deliver", "local"),
|
||||
)
|
||||
return
|
||||
|
||||
thread_id = None
|
||||
|
||||
# Resolve target platform + chat_id
|
||||
if deliver == "origin":
|
||||
if not origin:
|
||||
logger.warning("Job '%s' deliver=origin but no origin stored, skipping delivery", job["id"])
|
||||
return
|
||||
platform_name = origin["platform"]
|
||||
chat_id = origin["chat_id"]
|
||||
thread_id = origin.get("thread_id")
|
||||
elif ":" in deliver:
|
||||
platform_name, chat_id = deliver.split(":", 1)
|
||||
else:
|
||||
# Bare platform name like "telegram" — need to resolve to origin or home channel
|
||||
platform_name = deliver
|
||||
if origin and origin.get("platform") == platform_name:
|
||||
chat_id = origin["chat_id"]
|
||||
thread_id = origin.get("thread_id")
|
||||
else:
|
||||
# Fall back to home channel
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
if not chat_id:
|
||||
logger.warning("Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL <channel_id>", job["id"], deliver, platform_name.upper())
|
||||
return
|
||||
platform_name = target["platform"]
|
||||
chat_id = target["chat_id"]
|
||||
thread_id = target.get("thread_id")
|
||||
|
||||
from tools.send_message_tool import _send_to_platform
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
@@ -147,6 +175,43 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
logger.warning("Job '%s': mirror_to_session failed: %s", job["id"], e)
|
||||
|
||||
|
||||
def _build_job_prompt(job: dict) -> str:
|
||||
"""Build the effective prompt for a cron job, optionally loading one or more skills first."""
|
||||
prompt = job.get("prompt", "")
|
||||
skills = job.get("skills")
|
||||
if skills is None:
|
||||
legacy = job.get("skill")
|
||||
skills = [legacy] if legacy else []
|
||||
|
||||
skill_names = [str(name).strip() for name in skills if str(name).strip()]
|
||||
if not skill_names:
|
||||
return prompt
|
||||
|
||||
from tools.skills_tool import skill_view
|
||||
|
||||
parts = []
|
||||
for skill_name in skill_names:
|
||||
loaded = json.loads(skill_view(skill_name))
|
||||
if not loaded.get("success"):
|
||||
error = loaded.get("error") or f"Failed to load skill '{skill_name}'"
|
||||
raise RuntimeError(error)
|
||||
|
||||
content = str(loaded.get("content") or "").strip()
|
||||
if parts:
|
||||
parts.append("")
|
||||
parts.extend(
|
||||
[
|
||||
f'[SYSTEM: The user has invoked the "{skill_name}" skill, indicating they want you to follow its instructions. The full skill content is loaded below.]',
|
||||
"",
|
||||
content,
|
||||
]
|
||||
)
|
||||
|
||||
if prompt:
|
||||
parts.extend(["", f"The user has provided the following instruction alongside the skill invocation: {prompt}"])
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
"""
|
||||
Execute a single cron job.
|
||||
@@ -167,9 +232,9 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
job_id = job["id"]
|
||||
job_name = job["name"]
|
||||
prompt = job["prompt"]
|
||||
prompt = _build_job_prompt(job)
|
||||
origin = _resolve_origin(job)
|
||||
|
||||
|
||||
logger.info("Running job '%s' (ID: %s)", job_name, job_id)
|
||||
logger.info("Prompt: %s", prompt[:100])
|
||||
|
||||
@@ -189,7 +254,14 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(str(_hermes_home / ".env"), override=True, encoding="latin-1")
|
||||
|
||||
model = os.getenv("HERMES_MODEL") or "anthropic/claude-opus-4.6"
|
||||
delivery_target = _resolve_delivery_target(job)
|
||||
if delivery_target:
|
||||
os.environ["HERMES_CRON_AUTO_DELIVER_PLATFORM"] = delivery_target["platform"]
|
||||
os.environ["HERMES_CRON_AUTO_DELIVER_CHAT_ID"] = str(delivery_target["chat_id"])
|
||||
if delivery_target.get("thread_id") is not None:
|
||||
os.environ["HERMES_CRON_AUTO_DELIVER_THREAD_ID"] = str(delivery_target["thread_id"])
|
||||
|
||||
model = job.get("model") or os.getenv("HERMES_MODEL") or "anthropic/claude-opus-4.6"
|
||||
|
||||
# Load config.yaml for model, reasoning, prefill, toolsets, provider routing
|
||||
_cfg = {}
|
||||
@@ -200,10 +272,11 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
with open(_cfg_path) as _f:
|
||||
_cfg = yaml.safe_load(_f) or {}
|
||||
_model_cfg = _cfg.get("model", {})
|
||||
if isinstance(_model_cfg, str):
|
||||
model = _model_cfg
|
||||
elif isinstance(_model_cfg, dict):
|
||||
model = _model_cfg.get("default", model)
|
||||
if not job.get("model"):
|
||||
if isinstance(_model_cfg, str):
|
||||
model = _model_cfg
|
||||
elif isinstance(_model_cfg, dict):
|
||||
model = _model_cfg.get("default", model)
|
||||
except Exception as e:
|
||||
logger.warning("Job '%s': failed to load config.yaml, using defaults: %s", job_id, e)
|
||||
|
||||
@@ -248,9 +321,12 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
format_runtime_provider_error,
|
||||
)
|
||||
try:
|
||||
runtime = resolve_runtime_provider(
|
||||
requested=os.getenv("HERMES_INFERENCE_PROVIDER"),
|
||||
)
|
||||
runtime_kwargs = {
|
||||
"requested": job.get("provider") or os.getenv("HERMES_INFERENCE_PROVIDER"),
|
||||
}
|
||||
if job.get("base_url"):
|
||||
runtime_kwargs["explicit_base_url"] = job.get("base_url")
|
||||
runtime = resolve_runtime_provider(**runtime_kwargs)
|
||||
except Exception as exc:
|
||||
message = format_runtime_provider_error(exc)
|
||||
raise RuntimeError(message) from exc
|
||||
@@ -268,6 +344,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
disabled_toolsets=["cronjob"],
|
||||
quiet_mode=True,
|
||||
platform="cron",
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}",
|
||||
@@ -324,7 +401,14 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
finally:
|
||||
# Clean up injected env vars so they don't leak to other jobs
|
||||
for key in ("HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME"):
|
||||
for key in (
|
||||
"HERMES_SESSION_PLATFORM",
|
||||
"HERMES_SESSION_CHAT_ID",
|
||||
"HERMES_SESSION_CHAT_NAME",
|
||||
"HERMES_CRON_AUTO_DELIVER_PLATFORM",
|
||||
"HERMES_CRON_AUTO_DELIVER_CHAT_ID",
|
||||
"HERMES_CRON_AUTO_DELIVER_THREAD_ID",
|
||||
):
|
||||
os.environ.pop(key, None)
|
||||
if _session_db:
|
||||
try:
|
||||
|
||||
@@ -39,7 +39,9 @@ def resize_tool_pool(max_workers: int):
|
||||
Safe to call before any tasks are submitted.
|
||||
"""
|
||||
global _tool_executor
|
||||
old_executor = _tool_executor
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
old_executor.shutdown(wait=False)
|
||||
logger.info("Tool thread pool resized to %d workers", max_workers)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,12 +10,13 @@ Format uses special unicode tokens:
|
||||
<|tool▁call▁end|>
|
||||
<|tool▁calls▁end|>
|
||||
|
||||
Based on VLLM's DeepSeekV3ToolParser.extract_tool_calls()
|
||||
Fixes Issue #989: Support for multiple simultaneous tool calls.
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
@@ -24,6 +25,7 @@ from openai.types.chat.chat_completion_message_tool_call import (
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@register_parser("deepseek_v3")
|
||||
class DeepSeekV3ToolCallParser(ToolCallParser):
|
||||
@@ -32,45 +34,56 @@ class DeepSeekV3ToolCallParser(ToolCallParser):
|
||||
|
||||
Uses special unicode tokens with fullwidth angle brackets and block elements.
|
||||
Extracts type, function name, and JSON arguments from the structured format.
|
||||
Ensures all tool calls are captured when the model executes multiple actions.
|
||||
"""
|
||||
|
||||
START_TOKEN = "<|tool▁calls▁begin|>"
|
||||
|
||||
# Regex captures: type, function_name, function_arguments
|
||||
# Updated PATTERN: Using \s* instead of literal \n for increased robustness
|
||||
# against variations in model formatting (Issue #989).
|
||||
PATTERN = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<type>.*?)<|tool▁sep|>(?P<function_name>.*?)\n```json\n(?P<function_arguments>.*?)\n```<|tool▁call▁end|>",
|
||||
r"<|tool▁call▁begin|>(?P<type>.*?)<|tool▁sep|>(?P<function_name>.*?)\s*```json\s*(?P<function_arguments>.*?)\s*```\s*<|tool▁call▁end|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
"""
|
||||
Parses the input text and extracts all available tool calls.
|
||||
"""
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
# Using finditer to capture ALL tool calls in the sequence
|
||||
matches = list(self.PATTERN.finditer(text))
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
for match in matches:
|
||||
tc_type, func_name, func_args = match
|
||||
func_name = match.group("function_name").strip()
|
||||
func_args = match.group("function_arguments").strip()
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name.strip(),
|
||||
arguments=func_args.strip(),
|
||||
name=func_name,
|
||||
arguments=func_args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
if tool_calls:
|
||||
# Content is text before the first tool call block
|
||||
content_index = text.find(self.START_TOKEN)
|
||||
content = text[:content_index].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
# Content is everything before the tool calls section
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing DeepSeek V3 tool calls: {e}")
|
||||
return text, None
|
||||
|
||||
@@ -21,6 +21,17 @@ from hermes_cli.config import get_hermes_home
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool = True) -> bool:
|
||||
"""Coerce bool-ish config values, preserving a caller-provided default."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in ("true", "1", "yes", "on")
|
||||
return bool(value)
|
||||
|
||||
|
||||
class Platform(Enum):
|
||||
"""Supported messaging platforms."""
|
||||
LOCAL = "local"
|
||||
@@ -160,6 +171,9 @@ class GatewayConfig:
|
||||
|
||||
# Delivery settings
|
||||
always_log_local: bool = True # Always save cron outputs to local files
|
||||
|
||||
# STT settings
|
||||
stt_enabled: bool = True # Whether to auto-transcribe inbound voice messages
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
@@ -224,6 +238,7 @@ class GatewayConfig:
|
||||
"quick_commands": self.quick_commands,
|
||||
"sessions_dir": str(self.sessions_dir),
|
||||
"always_log_local": self.always_log_local,
|
||||
"stt_enabled": self.stt_enabled,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -260,6 +275,10 @@ class GatewayConfig:
|
||||
if not isinstance(quick_commands, dict):
|
||||
quick_commands = {}
|
||||
|
||||
stt_enabled = data.get("stt_enabled")
|
||||
if stt_enabled is None:
|
||||
stt_enabled = data.get("stt", {}).get("enabled") if isinstance(data.get("stt"), dict) else None
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
@@ -269,6 +288,7 @@ class GatewayConfig:
|
||||
quick_commands=quick_commands,
|
||||
sessions_dir=sessions_dir,
|
||||
always_log_local=data.get("always_log_local", True),
|
||||
stt_enabled=_coerce_bool(stt_enabled, True),
|
||||
)
|
||||
|
||||
|
||||
@@ -318,6 +338,12 @@ def load_gateway_config() -> GatewayConfig:
|
||||
else:
|
||||
logger.warning("Ignoring invalid quick_commands in config.yaml (expected mapping, got %s)", type(qc).__name__)
|
||||
|
||||
# Bridge STT enable/disable from config.yaml into gateway runtime.
|
||||
# This keeps the gateway aligned with the user-facing config source.
|
||||
stt_cfg = yaml_cfg.get("stt")
|
||||
if isinstance(stt_cfg, dict) and "enabled" in stt_cfg:
|
||||
config.stt_enabled = _coerce_bool(stt_cfg.get("enabled"), True)
|
||||
|
||||
# Bridge discord settings from config.yaml to env vars
|
||||
# (env vars take precedence — only set if not already defined)
|
||||
discord_cfg = yaml_cfg.get("discord", {})
|
||||
|
||||
+2
-2
@@ -161,7 +161,7 @@ class DeliveryRouter:
|
||||
|
||||
# Always include local if configured
|
||||
if self.config.always_log_local:
|
||||
local_key = (Platform.LOCAL, None)
|
||||
local_key = (Platform.LOCAL, None, None)
|
||||
if local_key not in seen_platforms:
|
||||
targets.append(DeliveryTarget(platform=Platform.LOCAL))
|
||||
|
||||
@@ -315,7 +315,7 @@ def build_delivery_context_for_tool(
|
||||
origin: Optional[SessionSource] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build context for the schedule_cronjob tool to understand delivery options.
|
||||
Build context for the unified cronjob tool to understand delivery options.
|
||||
|
||||
This is passed to the tool so it can validate and explain delivery targets.
|
||||
"""
|
||||
|
||||
@@ -173,7 +173,7 @@ platform_map = {
|
||||
}
|
||||
```
|
||||
|
||||
Without this, `schedule_cronjob(deliver="your_platform")` silently fails.
|
||||
Without this, `cronjob(action="create", deliver="your_platform", ...)` silently fails.
|
||||
|
||||
---
|
||||
|
||||
|
||||
+116
-2
@@ -288,6 +288,7 @@ class MessageEvent:
|
||||
message_id: Optional[str] = None
|
||||
|
||||
# Media attachments
|
||||
# media_urls: local file paths (for vision tool access)
|
||||
media_urls: List[str] = field(default_factory=list)
|
||||
media_types: List[str] = field(default_factory=list)
|
||||
|
||||
@@ -346,13 +347,85 @@ class BasePlatformAdapter(ABC):
|
||||
self.platform = platform
|
||||
self._message_handler: Optional[MessageHandler] = None
|
||||
self._running = False
|
||||
self._fatal_error_code: Optional[str] = None
|
||||
self._fatal_error_message: Optional[str] = None
|
||||
self._fatal_error_retryable = True
|
||||
self._fatal_error_handler: Optional[Callable[["BasePlatformAdapter"], Awaitable[None] | None]] = None
|
||||
|
||||
# Track active message handlers per session for interrupt support
|
||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||
# Background message-processing tasks spawned by handle_message().
|
||||
# Gateway shutdown cancels these so an old gateway instance doesn't keep
|
||||
# working on a task after --replace or manual restarts.
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||
self._auto_tts_disabled_chats: set = set()
|
||||
|
||||
@property
|
||||
def has_fatal_error(self) -> bool:
|
||||
return self._fatal_error_message is not None
|
||||
|
||||
@property
|
||||
def fatal_error_message(self) -> Optional[str]:
|
||||
return self._fatal_error_message
|
||||
|
||||
@property
|
||||
def fatal_error_code(self) -> Optional[str]:
|
||||
return self._fatal_error_code
|
||||
|
||||
@property
|
||||
def fatal_error_retryable(self) -> bool:
|
||||
return self._fatal_error_retryable
|
||||
|
||||
def set_fatal_error_handler(self, handler: Callable[["BasePlatformAdapter"], Awaitable[None] | None]) -> None:
|
||||
self._fatal_error_handler = handler
|
||||
|
||||
def _mark_connected(self) -> None:
|
||||
self._running = True
|
||||
self._fatal_error_code = None
|
||||
self._fatal_error_message = None
|
||||
self._fatal_error_retryable = True
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(platform=self.platform.value, platform_state="connected", error_code=None, error_message=None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _mark_disconnected(self) -> None:
|
||||
self._running = False
|
||||
if self.has_fatal_error:
|
||||
return
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(platform=self.platform.value, platform_state="disconnected", error_code=None, error_message=None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _set_fatal_error(self, code: str, message: str, *, retryable: bool) -> None:
|
||||
self._running = False
|
||||
self._fatal_error_code = code
|
||||
self._fatal_error_message = message
|
||||
self._fatal_error_retryable = retryable
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(
|
||||
platform=self.platform.value,
|
||||
platform_state="fatal",
|
||||
error_code=code,
|
||||
error_message=message,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _notify_fatal_error(self) -> None:
|
||||
handler = self._fatal_error_handler
|
||||
if not handler:
|
||||
return
|
||||
result = handler(self)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -683,7 +756,25 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
# Store this as a pending message - it will interrupt the running agent
|
||||
# Special case: photo bursts/albums frequently arrive as multiple near-
|
||||
# simultaneous messages. Queue them without interrupting the active run,
|
||||
# then process them immediately after the current task finishes.
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
print(f"[{self.name}] 🖼️ Queuing photo follow-up for session {session_key} without interrupt")
|
||||
existing = self._pending_messages.get(session_key)
|
||||
if existing and existing.message_type == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
else:
|
||||
self._pending_messages[session_key] = event
|
||||
return # Don't interrupt now - will run after current task completes
|
||||
|
||||
# Default behavior for non-photo follow-ups: interrupt the running agent
|
||||
print(f"[{self.name}] ⚡ New message while session {session_key} is active - triggering interrupt")
|
||||
self._pending_messages[session_key] = event
|
||||
# Signal the interrupt (the processing task checks this)
|
||||
@@ -691,7 +782,15 @@ class BasePlatformAdapter(ABC):
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
# Spawn background task to process this message
|
||||
asyncio.create_task(self._process_message_background(event, session_key))
|
||||
task = asyncio.create_task(self._process_message_background(event, session_key))
|
||||
try:
|
||||
self._background_tasks.add(task)
|
||||
except TypeError:
|
||||
# Some tests stub create_task() with lightweight sentinels that are not
|
||||
# hashable and do not support lifecycle callbacks.
|
||||
return
|
||||
if hasattr(task, "add_done_callback"):
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
@staticmethod
|
||||
def _get_human_delay() -> float:
|
||||
@@ -901,6 +1000,21 @@ class BasePlatformAdapter(ABC):
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
async def cancel_background_tasks(self) -> None:
|
||||
"""Cancel any in-flight background message-processing tasks.
|
||||
|
||||
Used during gateway shutdown/replacement so active sessions from the old
|
||||
process do not keep running after adapters are being torn down.
|
||||
"""
|
||||
tasks = [task for task in self._background_tasks if not task.done()]
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
self._pending_messages.clear()
|
||||
self._active_sessions.clear()
|
||||
|
||||
def has_pending_interrupt(self, session_key: str) -> bool:
|
||||
"""Check if there's a pending interrupt for a session."""
|
||||
return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
|
||||
|
||||
+156
-36
@@ -87,8 +87,9 @@ class VoiceReceiver:
|
||||
SAMPLE_RATE = 48000 # Discord native rate
|
||||
CHANNELS = 2 # Discord sends stereo
|
||||
|
||||
def __init__(self, voice_client):
|
||||
def __init__(self, voice_client, allowed_user_ids: set = None):
|
||||
self._vc = voice_client
|
||||
self._allowed_user_ids = allowed_user_ids or set()
|
||||
self._running = False
|
||||
|
||||
# Decryption
|
||||
@@ -274,19 +275,21 @@ class VoiceReceiver:
|
||||
if self._dave_session:
|
||||
with self._lock:
|
||||
user_id = self._ssrc_to_user.get(ssrc, 0)
|
||||
if user_id == 0:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc)
|
||||
return # unknown user, can't DAVE-decrypt
|
||||
try:
|
||||
import davey
|
||||
decrypted = self._dave_session.decrypt(
|
||||
user_id, davey.MediaType.audio, decrypted
|
||||
)
|
||||
except Exception as e:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
|
||||
return
|
||||
if user_id:
|
||||
try:
|
||||
import davey
|
||||
decrypted = self._dave_session.decrypt(
|
||||
user_id, davey.MediaType.audio, decrypted
|
||||
)
|
||||
except Exception as e:
|
||||
# Unencrypted passthrough — use NaCl-decrypted data as-is
|
||||
if "Unencrypted" not in str(e):
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
|
||||
return
|
||||
# If SSRC unknown (no SPEAKING event yet), skip DAVE and try
|
||||
# Opus decode directly — audio may be in passthrough mode.
|
||||
# Buffer will get a user_id when SPEAKING event arrives later.
|
||||
|
||||
# --- Opus decode -> PCM ---
|
||||
try:
|
||||
@@ -304,6 +307,32 @@ class VoiceReceiver:
|
||||
# Silence detection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _infer_user_for_ssrc(self, ssrc: int) -> int:
|
||||
"""Try to infer user_id for an unmapped SSRC.
|
||||
|
||||
When the bot rejoins a voice channel, Discord may not resend
|
||||
SPEAKING events for users already speaking. If exactly one
|
||||
allowed user is in the channel, map the SSRC to them.
|
||||
"""
|
||||
try:
|
||||
channel = self._vc.channel
|
||||
if not channel:
|
||||
return 0
|
||||
bot_id = self._vc.user.id if self._vc.user else 0
|
||||
allowed = self._allowed_user_ids
|
||||
candidates = [
|
||||
m.id for m in channel.members
|
||||
if m.id != bot_id and (not allowed or str(m.id) in allowed)
|
||||
]
|
||||
if len(candidates) == 1:
|
||||
uid = candidates[0]
|
||||
self._ssrc_to_user[ssrc] = uid
|
||||
logger.info("Auto-mapped ssrc=%d -> user=%d (sole allowed member)", ssrc, uid)
|
||||
return uid
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
def check_silence(self) -> list:
|
||||
"""Return list of (user_id, pcm_bytes) for completed utterances."""
|
||||
now = time.monotonic()
|
||||
@@ -322,6 +351,10 @@ class VoiceReceiver:
|
||||
|
||||
if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION:
|
||||
user_id = ssrc_user_map.get(ssrc, 0)
|
||||
if not user_id:
|
||||
# SSRC not mapped (SPEAKING event missing after bot rejoin).
|
||||
# Infer from allowed users in the voice channel.
|
||||
user_id = self._infer_user_for_ssrc(ssrc)
|
||||
if user_id:
|
||||
completed.append((user_id, bytes(buf)))
|
||||
self._buffers[ssrc] = bytearray()
|
||||
@@ -400,6 +433,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._voice_listen_tasks: Dict[int, asyncio.Task] = {} # guild_id -> listen loop
|
||||
self._voice_input_callback: Optional[Callable] = None # set by run.py
|
||||
self._on_voice_disconnect: Optional[Callable] = None # set by run.py
|
||||
# Track threads where the bot has participated so follow-up messages
|
||||
# in those threads don't require @mention.
|
||||
self._bot_participated_threads: set = set()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
@@ -580,7 +616,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
"""Send a message to a Discord channel."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Get the channel
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
@@ -605,10 +641,30 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
logger.debug("Could not fetch reply-to message: %s", e)
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
reference=reference if i == 0 else None,
|
||||
)
|
||||
chunk_reference = reference if i == 0 else None
|
||||
try:
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
reference=chunk_reference,
|
||||
)
|
||||
except Exception as e:
|
||||
err_text = str(e)
|
||||
if (
|
||||
chunk_reference is not None
|
||||
and "error code: 50035" in err_text
|
||||
and "Cannot reply to a system message" in err_text
|
||||
):
|
||||
logger.warning(
|
||||
"[%s] Reply target %s is a Discord system message; retrying send without reply reference",
|
||||
self.name,
|
||||
reply_to,
|
||||
)
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
reference=None,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
message_ids.append(str(msg.id))
|
||||
|
||||
return SendResult(
|
||||
@@ -649,6 +705,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local file as a Discord attachment."""
|
||||
if not self._client:
|
||||
@@ -660,7 +717,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
filename = os.path.basename(file_path)
|
||||
filename = file_name or os.path.basename(file_path)
|
||||
with open(file_path, "rb") as fh:
|
||||
file = discord.File(fh, filename=filename)
|
||||
msg = await channel.send(content=caption if caption else None, file=file)
|
||||
@@ -674,13 +731,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
) -> SendResult:
|
||||
"""Play auto-TTS audio.
|
||||
|
||||
When the bot is in a voice channel for this chat's guild, skip the
|
||||
file attachment — the gateway runner plays audio in the VC instead.
|
||||
When the bot is in a voice channel for this chat's guild, play
|
||||
directly in the VC instead of sending as a file attachment.
|
||||
"""
|
||||
for gid, text_ch_id in self._voice_text_channels.items():
|
||||
if str(text_ch_id) == str(chat_id) and self.is_in_voice_channel(gid):
|
||||
logger.debug("[%s] Skipping play_tts for %s — VC playback handled by runner", self.name, chat_id)
|
||||
return SendResult(success=True)
|
||||
logger.info("[%s] Playing TTS in voice channel (guild=%d)", self.name, gid)
|
||||
success = await self.play_in_voice_channel(gid, audio_path)
|
||||
return SendResult(success=success)
|
||||
return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs)
|
||||
|
||||
async def send_voice(
|
||||
@@ -784,7 +842,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Start voice receiver (Phase 2: listen to users)
|
||||
try:
|
||||
receiver = VoiceReceiver(vc)
|
||||
receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids)
|
||||
receiver.start()
|
||||
self._voice_receivers[guild_id] = receiver
|
||||
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
|
||||
@@ -980,14 +1038,32 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Voice listening (Phase 2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# UDP keepalive interval in seconds — prevents Discord from dropping
|
||||
# the UDP route after ~60s of silence.
|
||||
_KEEPALIVE_INTERVAL = 15
|
||||
|
||||
async def _voice_listen_loop(self, guild_id: int):
|
||||
"""Periodically check for completed utterances and process them."""
|
||||
receiver = self._voice_receivers.get(guild_id)
|
||||
if not receiver:
|
||||
return
|
||||
last_keepalive = time.monotonic()
|
||||
try:
|
||||
while receiver._running:
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Send periodic UDP keepalive to prevent Discord from
|
||||
# dropping the UDP session after ~60s of silence.
|
||||
now = time.monotonic()
|
||||
if now - last_keepalive >= self._KEEPALIVE_INTERVAL:
|
||||
last_keepalive = now
|
||||
try:
|
||||
vc = self._voice_clients.get(guild_id)
|
||||
if vc and vc.is_connected():
|
||||
vc._connection.send_packet(b'\xf8\xff\xfe')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
completed = receiver.check_silence()
|
||||
for user_id, pcm_data in completed:
|
||||
if not self._is_allowed_user(str(user_id)):
|
||||
@@ -1121,6 +1197,41 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
exc_info=True,
|
||||
)
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local video file natively as a Discord attachment."""
|
||||
try:
|
||||
return await self._send_file_attachment(chat_id, video_path, caption)
|
||||
except FileNotFoundError:
|
||||
return SendResult(success=False, error=f"Video file not found: {video_path}")
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to send local video, falling back to base adapter: %s", self.name, e, exc_info=True)
|
||||
return await super().send_video(chat_id, video_path, caption, reply_to, metadata=metadata)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send an arbitrary file natively as a Discord attachment."""
|
||||
try:
|
||||
return await self._send_file_attachment(chat_id, file_path, caption, file_name=file_name)
|
||||
except FileNotFoundError:
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to send document, falling back to base adapter: %s", self.name, e, exc_info=True)
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to, metadata=metadata)
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Send typing indicator."""
|
||||
@@ -1690,14 +1801,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
# UNLESS the channel is in the free-response list.
|
||||
# UNLESS the channel is in the free-response list or the message is
|
||||
# in a thread where the bot has already participated.
|
||||
#
|
||||
# Config:
|
||||
# DISCORD_FREE_RESPONSE_CHANNELS: Comma-separated channel IDs where the
|
||||
# bot responds to every message without needing a mention.
|
||||
# DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement
|
||||
# globally (all channels become free-response). Default: "true".
|
||||
# Can also be set via discord.require_mention in config.yaml.
|
||||
# Config (all settable via discord.* in config.yaml):
|
||||
# discord.require_mention: Require @mention in server channels (default: true)
|
||||
# discord.free_response_channels: Channel IDs where bot responds without mention
|
||||
# discord.auto_thread: Auto-create thread on @mention in channels (default: true)
|
||||
|
||||
thread_id = None
|
||||
parent_channel_id = None
|
||||
@@ -1716,7 +1826,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
is_free_channel = bool(channel_ids & free_channels)
|
||||
|
||||
if require_mention and not is_free_channel:
|
||||
# Skip the mention check if the message is in a thread where
|
||||
# the bot has previously participated (auto-created or replied in).
|
||||
in_bot_thread = is_thread and thread_id in self._bot_participated_threads
|
||||
|
||||
if require_mention and not is_free_channel and not in_bot_thread:
|
||||
if self._client.user not in message.mentions:
|
||||
return
|
||||
|
||||
@@ -1725,17 +1839,18 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||||
|
||||
# Auto-thread: when enabled, automatically create a thread for every
|
||||
# new message in a text channel so each conversation is isolated.
|
||||
# @mention in a text channel so each conversation is isolated (like Slack).
|
||||
# Messages already inside threads or DMs are unaffected.
|
||||
auto_threaded_channel = None
|
||||
if not is_thread and not isinstance(message.channel, discord.DMChannel):
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "").lower() in ("true", "1", "yes")
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in ("true", "1", "yes")
|
||||
if auto_thread:
|
||||
thread = await self._auto_create_thread(message)
|
||||
if thread:
|
||||
is_thread = True
|
||||
thread_id = str(thread.id)
|
||||
auto_threaded_channel = thread
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
@@ -1835,7 +1950,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
|
||||
timestamp=message.created_at,
|
||||
)
|
||||
|
||||
|
||||
# Track thread participation so the bot won't require @mention for
|
||||
# follow-up messages in threads it has already engaged in.
|
||||
if thread_id:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
|
||||
+203
-11
@@ -105,12 +105,48 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
# Telegram message limits
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
MEDIA_GROUP_WAIT_SECONDS = 0.8
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.TELEGRAM)
|
||||
self._app: Optional[Application] = None
|
||||
self._bot: Optional[Bot] = None
|
||||
|
||||
# Buffer rapid/album photo updates so Telegram image bursts are handled
|
||||
# as a single MessageEvent instead of self-interrupting multiple turns.
|
||||
self._media_batch_delay_seconds = float(os.getenv("HERMES_TELEGRAM_MEDIA_BATCH_DELAY_SECONDS", "0.8"))
|
||||
self._pending_photo_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_photo_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._media_group_events: Dict[str, MessageEvent] = {}
|
||||
self._media_group_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._polling_error_task: Optional[asyncio.Task] = None
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_polling_conflict(error: Exception) -> bool:
|
||||
text = str(error).lower()
|
||||
return (
|
||||
error.__class__.__name__.lower() == "conflict"
|
||||
or "terminated by other getupdates request" in text
|
||||
or "another bot instance is running" in text
|
||||
)
|
||||
|
||||
async def _handle_polling_conflict(self, error: Exception) -> None:
|
||||
if self.has_fatal_error and self.fatal_error_code == "telegram_polling_conflict":
|
||||
return
|
||||
message = (
|
||||
"Another Telegram bot poller is already using this token. "
|
||||
"Hermes stopped Telegram polling to avoid endless retry spam. "
|
||||
"Make sure only one gateway instance is running for this bot token."
|
||||
)
|
||||
logger.error("[%s] %s Original error: %s", self.name, message, error)
|
||||
self._set_fatal_error("telegram_polling_conflict", message, retryable=False)
|
||||
try:
|
||||
if self._app and self._app.updater:
|
||||
await self._app.updater.stop()
|
||||
except Exception as stop_error:
|
||||
logger.warning("[%s] Failed stopping Telegram polling after conflict: %s", self.name, stop_error, exc_info=True)
|
||||
await self._notify_fatal_error()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Telegram and start polling for updates."""
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
@@ -125,6 +161,25 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._token_lock_identity = self.config.token
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"telegram-bot-token",
|
||||
self._token_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Telegram bot token"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Telegram poller."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("telegram_token_lock", message, retryable=False)
|
||||
return False
|
||||
|
||||
# Build the application
|
||||
self._app = Application.builder().token(self.config.token).build()
|
||||
self._bot = self._app.bot
|
||||
@@ -150,9 +205,20 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _polling_error_callback(error: Exception) -> None:
|
||||
if not self._looks_like_polling_conflict(error):
|
||||
logger.error("[%s] Telegram polling error: %s", self.name, error, exc_info=True)
|
||||
return
|
||||
if self._polling_error_task and not self._polling_error_task.done():
|
||||
return
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_conflict(error))
|
||||
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
drop_pending_updates=True,
|
||||
error_callback=_polling_error_callback,
|
||||
)
|
||||
|
||||
# Register bot commands so Telegram shows a hint menu when users type /
|
||||
@@ -188,29 +254,59 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
self._running = True
|
||||
self._mark_connected()
|
||||
logger.info("[%s] Connected and polling for Telegram updates", self.name)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True)
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop polling and disconnect."""
|
||||
"""Stop polling, cancel pending album flushes, and disconnect."""
|
||||
pending_media_group_tasks = list(self._media_group_tasks.values())
|
||||
for task in pending_media_group_tasks:
|
||||
task.cancel()
|
||||
if pending_media_group_tasks:
|
||||
await asyncio.gather(*pending_media_group_tasks, return_exceptions=True)
|
||||
self._media_group_tasks.clear()
|
||||
self._media_group_events.clear()
|
||||
|
||||
if self._app:
|
||||
try:
|
||||
await self._app.updater.stop()
|
||||
await self._app.stop()
|
||||
# Only stop the updater if it's running
|
||||
if self._app.updater and self._app.updater.running:
|
||||
await self._app.updater.stop()
|
||||
if self._app.running:
|
||||
await self._app.stop()
|
||||
await self._app.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True)
|
||||
|
||||
self._running = False
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True)
|
||||
|
||||
for task in self._pending_photo_batch_tasks.values():
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
self._pending_photo_batch_tasks.clear()
|
||||
self._pending_photo_batches.clear()
|
||||
|
||||
self._mark_disconnected()
|
||||
self._app = None
|
||||
self._bot = None
|
||||
self._token_lock_identity = None
|
||||
logger.info("[%s] Disconnected from Telegram", self.name)
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -722,6 +818,49 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
event.text = "\n".join(parts)
|
||||
await self.handle_message(event)
|
||||
|
||||
def _photo_batch_key(self, event: MessageEvent, msg: Message) -> str:
|
||||
"""Return a batching key for Telegram photos/albums."""
|
||||
from gateway.session import build_session_key
|
||||
session_key = build_session_key(event.source)
|
||||
media_group_id = getattr(msg, "media_group_id", None)
|
||||
if media_group_id:
|
||||
return f"{session_key}:album:{media_group_id}"
|
||||
return f"{session_key}:photo-burst"
|
||||
|
||||
async def _flush_photo_batch(self, batch_key: str) -> None:
|
||||
"""Send a buffered photo burst/album as a single MessageEvent."""
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
await asyncio.sleep(self._media_batch_delay_seconds)
|
||||
event = self._pending_photo_batches.pop(batch_key, None)
|
||||
if not event:
|
||||
return
|
||||
logger.info("[Telegram] Flushing photo batch %s with %d image(s)", batch_key, len(event.media_urls))
|
||||
await self.handle_message(event)
|
||||
finally:
|
||||
if self._pending_photo_batch_tasks.get(batch_key) is current_task:
|
||||
self._pending_photo_batch_tasks.pop(batch_key, None)
|
||||
|
||||
def _enqueue_photo_event(self, batch_key: str, event: MessageEvent) -> None:
|
||||
"""Merge photo events into a pending batch and schedule flush."""
|
||||
existing = self._pending_photo_batches.get(batch_key)
|
||||
if existing is None:
|
||||
self._pending_photo_batches[batch_key] = event
|
||||
else:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
|
||||
prior_task = self._pending_photo_batch_tasks.get(batch_key)
|
||||
if prior_task and not prior_task.done():
|
||||
prior_task.cancel()
|
||||
|
||||
self._pending_photo_batch_tasks[batch_key] = asyncio.create_task(self._flush_photo_batch(batch_key))
|
||||
|
||||
async def _handle_media_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming media messages, downloading images to local cache."""
|
||||
if not update.message:
|
||||
@@ -773,14 +912,22 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if file_obj.file_path.lower().endswith(candidate):
|
||||
ext = candidate
|
||||
break
|
||||
# Save to cache and populate media_urls with the local path
|
||||
# Save to local cache (for vision tool access)
|
||||
cached_path = cache_image_from_bytes(bytes(image_bytes), ext=ext)
|
||||
event.media_urls = [cached_path]
|
||||
event.media_types = [f"image/{ext.lstrip('.')}"]
|
||||
event.media_types = [f"image/{ext.lstrip('.')}" ]
|
||||
logger.info("[Telegram] Cached user photo at %s", cached_path)
|
||||
media_group_id = getattr(msg, "media_group_id", None)
|
||||
if media_group_id:
|
||||
await self._queue_media_group_event(str(media_group_id), event)
|
||||
else:
|
||||
batch_key = self._photo_batch_key(event, msg)
|
||||
self._enqueue_photo_event(batch_key, event)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("[Telegram] Failed to cache photo: %s", e, exc_info=True)
|
||||
|
||||
|
||||
# Download voice/audio messages to cache for STT transcription
|
||||
if msg.voice:
|
||||
try:
|
||||
@@ -872,8 +1019,53 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
logger.warning("[Telegram] Failed to cache document: %s", e, exc_info=True)
|
||||
|
||||
media_group_id = getattr(msg, "media_group_id", None)
|
||||
if media_group_id:
|
||||
await self._queue_media_group_event(str(media_group_id), event)
|
||||
return
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _queue_media_group_event(self, media_group_id: str, event: MessageEvent) -> None:
|
||||
"""Buffer Telegram media-group items so albums arrive as one logical event.
|
||||
|
||||
Telegram delivers albums as multiple updates with a shared media_group_id.
|
||||
If we forward each item immediately, the gateway thinks the second image is a
|
||||
new user message and interrupts the first. We debounce briefly and merge the
|
||||
attachments into a single MessageEvent.
|
||||
"""
|
||||
existing = self._media_group_events.get(media_group_id)
|
||||
if existing is None:
|
||||
self._media_group_events[media_group_id] = event
|
||||
else:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if existing.text:
|
||||
if event.text not in existing.text.split("\n\n"):
|
||||
existing.text = f"{existing.text}\n\n{event.text}"
|
||||
else:
|
||||
existing.text = event.text
|
||||
|
||||
prior_task = self._media_group_tasks.get(media_group_id)
|
||||
if prior_task:
|
||||
prior_task.cancel()
|
||||
|
||||
self._media_group_tasks[media_group_id] = asyncio.create_task(
|
||||
self._flush_media_group_event(media_group_id)
|
||||
)
|
||||
|
||||
async def _flush_media_group_event(self, media_group_id: str) -> None:
|
||||
try:
|
||||
await asyncio.sleep(self.MEDIA_GROUP_WAIT_SECONDS)
|
||||
event = self._media_group_events.pop(media_group_id, None)
|
||||
if event is not None:
|
||||
await self.handle_message(event)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
self._media_group_tasks.pop(media_group_id, None)
|
||||
|
||||
async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None:
|
||||
"""
|
||||
Describe a Telegram sticker via vision analysis, with caching.
|
||||
|
||||
+238
-46
@@ -35,16 +35,12 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
# Resolve Hermes home directory (respects HERMES_HOME override)
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
# Load environment variables from ~/.hermes/.env first
|
||||
from dotenv import load_dotenv
|
||||
# Load environment variables from ~/.hermes/.env first.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
from dotenv import load_dotenv # backward-compat for tests that monkeypatch this symbol
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
_env_path = _hermes_home / '.env'
|
||||
if _env_path.exists():
|
||||
try:
|
||||
load_dotenv(_env_path, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(_env_path, encoding="latin-1")
|
||||
# Also try project .env as fallback
|
||||
load_dotenv()
|
||||
load_hermes_dotenv(hermes_home=_hermes_home, project_env=Path(__file__).resolve().parents[1] / '.env')
|
||||
|
||||
# Bridge config.yaml values into the environment so os.getenv() picks them up.
|
||||
# config.yaml is authoritative for terminal settings — overrides .env.
|
||||
@@ -68,6 +64,7 @@ if _config_path.exists():
|
||||
"timeout": "TERMINAL_TIMEOUT",
|
||||
"lifetime_seconds": "TERMINAL_LIFETIME_SECONDS",
|
||||
"docker_image": "TERMINAL_DOCKER_IMAGE",
|
||||
"docker_forward_env": "TERMINAL_DOCKER_FORWARD_ENV",
|
||||
"singularity_image": "TERMINAL_SINGULARITY_IMAGE",
|
||||
"modal_image": "TERMINAL_MODAL_IMAGE",
|
||||
"daytona_image": "TERMINAL_DAYTONA_IMAGE",
|
||||
@@ -100,24 +97,40 @@ if _config_path.exists():
|
||||
for _cfg_key, _env_var in _compression_env_map.items():
|
||||
if _cfg_key in _compression_cfg:
|
||||
os.environ[_env_var] = str(_compression_cfg[_cfg_key])
|
||||
# Auxiliary model overrides (vision, web_extract).
|
||||
# Each task has provider + model; bridge non-default values to env vars.
|
||||
# Auxiliary model/direct-endpoint overrides (vision, web_extract).
|
||||
# Each task has provider/model/base_url/api_key; bridge non-default values to env vars.
|
||||
_auxiliary_cfg = _cfg.get("auxiliary", {})
|
||||
if _auxiliary_cfg and isinstance(_auxiliary_cfg, dict):
|
||||
_aux_task_env = {
|
||||
"vision": ("AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL"),
|
||||
"web_extract": ("AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL"),
|
||||
"vision": {
|
||||
"provider": "AUXILIARY_VISION_PROVIDER",
|
||||
"model": "AUXILIARY_VISION_MODEL",
|
||||
"base_url": "AUXILIARY_VISION_BASE_URL",
|
||||
"api_key": "AUXILIARY_VISION_API_KEY",
|
||||
},
|
||||
"web_extract": {
|
||||
"provider": "AUXILIARY_WEB_EXTRACT_PROVIDER",
|
||||
"model": "AUXILIARY_WEB_EXTRACT_MODEL",
|
||||
"base_url": "AUXILIARY_WEB_EXTRACT_BASE_URL",
|
||||
"api_key": "AUXILIARY_WEB_EXTRACT_API_KEY",
|
||||
},
|
||||
}
|
||||
for _task_key, (_prov_env, _model_env) in _aux_task_env.items():
|
||||
for _task_key, _env_map in _aux_task_env.items():
|
||||
_task_cfg = _auxiliary_cfg.get(_task_key, {})
|
||||
if not isinstance(_task_cfg, dict):
|
||||
continue
|
||||
_prov = str(_task_cfg.get("provider", "")).strip()
|
||||
_model = str(_task_cfg.get("model", "")).strip()
|
||||
_base_url = str(_task_cfg.get("base_url", "")).strip()
|
||||
_api_key = str(_task_cfg.get("api_key", "")).strip()
|
||||
if _prov and _prov != "auto":
|
||||
os.environ[_prov_env] = _prov
|
||||
os.environ[_env_map["provider"]] = _prov
|
||||
if _model:
|
||||
os.environ[_model_env] = _model
|
||||
os.environ[_env_map["model"]] = _model
|
||||
if _base_url:
|
||||
os.environ[_env_map["base_url"]] = _base_url
|
||||
if _api_key:
|
||||
os.environ[_env_map["api_key"]] = _api_key
|
||||
_agent_cfg = _cfg.get("agent", {})
|
||||
if _agent_cfg and isinstance(_agent_cfg, dict):
|
||||
if "max_turns" in _agent_cfg:
|
||||
@@ -215,6 +228,33 @@ def _resolve_gateway_model() -> str:
|
||||
return model
|
||||
|
||||
|
||||
def _resolve_hermes_bin() -> Optional[list[str]]:
|
||||
"""Resolve the Hermes update command as argv parts.
|
||||
|
||||
Tries in order:
|
||||
1. ``shutil.which("hermes")`` — standard PATH lookup
|
||||
2. ``sys.executable -m hermes_cli.main`` — fallback when Hermes is running
|
||||
from a venv/module invocation and the ``hermes`` shim is not on PATH
|
||||
|
||||
Returns argv parts ready for quoting/joining, or ``None`` if neither works.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
hermes_bin = shutil.which("hermes")
|
||||
if hermes_bin:
|
||||
return [hermes_bin]
|
||||
|
||||
try:
|
||||
import importlib.util
|
||||
|
||||
if importlib.util.find_spec("hermes_cli") is not None:
|
||||
return [sys.executable, "-m", "hermes_cli.main"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class GatewayRunner:
|
||||
"""
|
||||
Main gateway controller.
|
||||
@@ -245,6 +285,8 @@ class GatewayRunner:
|
||||
self.delivery_router = DeliveryRouter(self.config)
|
||||
self._running = False
|
||||
self._shutdown_event = asyncio.Event()
|
||||
self._exit_cleanly = False
|
||||
self._exit_reason: Optional[str] = None
|
||||
|
||||
# Track running agents per session for interrupt support
|
||||
# Key: session_key, Value: AIAgent instance
|
||||
@@ -463,6 +505,41 @@ class GatewayRunner:
|
||||
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id)
|
||||
|
||||
@property
|
||||
def should_exit_cleanly(self) -> bool:
|
||||
return self._exit_cleanly
|
||||
|
||||
@property
|
||||
def exit_reason(self) -> Optional[str]:
|
||||
return self._exit_reason
|
||||
|
||||
async def _handle_adapter_fatal_error(self, adapter: BasePlatformAdapter) -> None:
|
||||
"""React to a non-retryable adapter failure after startup."""
|
||||
logger.error(
|
||||
"Fatal %s adapter error (%s): %s",
|
||||
adapter.platform.value,
|
||||
adapter.fatal_error_code or "unknown",
|
||||
adapter.fatal_error_message or "unknown error",
|
||||
)
|
||||
|
||||
existing = self.adapters.get(adapter.platform)
|
||||
if existing is adapter:
|
||||
try:
|
||||
await adapter.disconnect()
|
||||
finally:
|
||||
self.adapters.pop(adapter.platform, None)
|
||||
self.delivery_router.adapters = self.adapters
|
||||
|
||||
if not self.adapters:
|
||||
self._exit_reason = adapter.fatal_error_message or "All messaging adapters disconnected"
|
||||
logger.error("No connected messaging platforms remain. Shutting down gateway cleanly.")
|
||||
await self.stop()
|
||||
|
||||
def _request_clean_exit(self, reason: str) -> None:
|
||||
self._exit_cleanly = True
|
||||
self._exit_reason = reason
|
||||
self._shutdown_event.set()
|
||||
|
||||
@staticmethod
|
||||
def _load_prefill_messages() -> List[Dict[str, Any]]:
|
||||
@@ -647,6 +724,11 @@ class GatewayRunner:
|
||||
"""
|
||||
logger.info("Starting Hermes Gateway...")
|
||||
logger.info("Session storage: %s", self.config.sessions_dir)
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(gateway_state="starting", exit_reason=None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Warn if no user allowlists are configured and open access is not opted in
|
||||
_any_allowlist = any(
|
||||
@@ -676,6 +758,7 @@ class GatewayRunner:
|
||||
logger.warning("Process checkpoint recovery: %s", e)
|
||||
|
||||
connected_count = 0
|
||||
startup_nonretryable_errors: list[str] = []
|
||||
|
||||
# Initialize and connect each configured platform
|
||||
for platform, platform_config in self.config.platforms.items():
|
||||
@@ -687,8 +770,9 @@ class GatewayRunner:
|
||||
logger.warning("No adapter available for %s", platform.value)
|
||||
continue
|
||||
|
||||
# Set up message handler
|
||||
# Set up message + fatal error handlers
|
||||
adapter.set_message_handler(self._handle_message)
|
||||
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
||||
|
||||
# Try to connect
|
||||
logger.info("Connecting to %s...", platform.value)
|
||||
@@ -701,10 +785,24 @@ class GatewayRunner:
|
||||
logger.info("✓ %s connected", platform.value)
|
||||
else:
|
||||
logger.warning("✗ %s failed to connect", platform.value)
|
||||
if adapter.has_fatal_error and not adapter.fatal_error_retryable:
|
||||
startup_nonretryable_errors.append(
|
||||
f"{platform.value}: {adapter.fatal_error_message}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("✗ %s error: %s", platform.value, e)
|
||||
|
||||
if connected_count == 0:
|
||||
if startup_nonretryable_errors:
|
||||
reason = "; ".join(startup_nonretryable_errors)
|
||||
logger.error("Gateway hit a non-retryable startup conflict: %s", reason)
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(gateway_state="startup_failed", exit_reason=reason)
|
||||
except Exception:
|
||||
pass
|
||||
self._request_clean_exit(reason)
|
||||
return True
|
||||
logger.warning("No messaging platforms connected.")
|
||||
logger.info("Gateway will continue running for cron job execution.")
|
||||
|
||||
@@ -712,6 +810,11 @@ class GatewayRunner:
|
||||
self.delivery_router.adapters = self.adapters
|
||||
|
||||
self._running = True
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(gateway_state="running", exit_reason=None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Emit gateway:startup hook
|
||||
hook_count = len(self.hooks.loaded_hooks)
|
||||
@@ -794,8 +897,19 @@ class GatewayRunner:
|
||||
"""Stop the gateway and disconnect all adapters."""
|
||||
logger.info("Stopping gateway...")
|
||||
self._running = False
|
||||
|
||||
|
||||
for session_key, agent in list(self._running_agents.items()):
|
||||
try:
|
||||
agent.interrupt("Gateway shutting down")
|
||||
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
|
||||
except Exception as e:
|
||||
logger.debug("Failed interrupting agent during shutdown: %s", e)
|
||||
|
||||
for platform, adapter in list(self.adapters.items()):
|
||||
try:
|
||||
await adapter.cancel_background_tasks()
|
||||
except Exception as e:
|
||||
logger.debug("✗ %s background-task cancel error: %s", platform.value, e)
|
||||
try:
|
||||
await adapter.disconnect()
|
||||
logger.info("✓ %s disconnected", platform.value)
|
||||
@@ -803,11 +917,18 @@ class GatewayRunner:
|
||||
logger.error("✗ %s disconnect error: %s", platform.value, e)
|
||||
|
||||
self.adapters.clear()
|
||||
self._running_agents.clear()
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
self._shutdown_all_gateway_honcho()
|
||||
self._shutdown_event.set()
|
||||
|
||||
from gateway.status import remove_pid_file
|
||||
from gateway.status import remove_pid_file, write_runtime_status
|
||||
remove_pid_file()
|
||||
try:
|
||||
write_runtime_status(gateway_state="stopped", exit_reason=self._exit_reason)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("Gateway stopped")
|
||||
|
||||
@@ -985,11 +1106,36 @@ class GatewayRunner:
|
||||
)
|
||||
return None
|
||||
|
||||
# PRIORITY: If an agent is already running for this session, interrupt it
|
||||
# immediately. This is before command parsing to minimize latency -- the
|
||||
# user's "stop" message reaches the agent as fast as possible.
|
||||
# PRIORITY handling when an agent is already running for this session.
|
||||
# Default behavior is to interrupt immediately so user text/stop messages
|
||||
# are handled with minimal latency.
|
||||
#
|
||||
# Special case: Telegram/photo bursts often arrive as multiple near-
|
||||
# simultaneous updates. Do NOT interrupt for photo-only follow-ups here;
|
||||
# let the adapter-level batching/queueing logic absorb them.
|
||||
_quick_key = build_session_key(source)
|
||||
if _quick_key in self._running_agents:
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
# Reuse adapter queue semantics so photo bursts merge cleanly.
|
||||
if _quick_key in adapter._pending_messages:
|
||||
existing = adapter._pending_messages[_quick_key]
|
||||
if getattr(existing, "message_type", None) == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
if not existing.text:
|
||||
existing.text = event.text
|
||||
elif event.text not in existing.text:
|
||||
existing.text = f"{existing.text}\n\n{event.text}".strip()
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
return None
|
||||
|
||||
running_agent = self._running_agents[_quick_key]
|
||||
logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
|
||||
running_agent.interrupt(event.text)
|
||||
@@ -1004,7 +1150,7 @@ class GatewayRunner:
|
||||
|
||||
# Emit command:* hook for any recognized slash command
|
||||
_known_commands = {"new", "reset", "help", "status", "stop", "model", "reasoning",
|
||||
"personality", "retry", "undo", "sethome", "set-home",
|
||||
"personality", "plan", "retry", "undo", "sethome", "set-home",
|
||||
"compress", "usage", "insights", "reload-mcp", "reload_mcp",
|
||||
"update", "title", "resume", "provider", "rollback",
|
||||
"background", "reasoning", "voice"}
|
||||
@@ -1039,6 +1185,28 @@ class GatewayRunner:
|
||||
|
||||
if command == "personality":
|
||||
return await self._handle_personality_command(event)
|
||||
|
||||
if command == "plan":
|
||||
try:
|
||||
from agent.skill_commands import build_plan_path, build_skill_invocation_message
|
||||
|
||||
user_instruction = event.get_command_args().strip()
|
||||
plan_path = build_plan_path(user_instruction)
|
||||
event.text = build_skill_invocation_message(
|
||||
"/plan",
|
||||
user_instruction,
|
||||
task_id=_quick_key,
|
||||
runtime_note=(
|
||||
"Save the markdown plan with write_file to this exact relative path "
|
||||
f"inside the active workspace/backend cwd: {plan_path}"
|
||||
),
|
||||
)
|
||||
if not event.text:
|
||||
return "Failed to load the bundled /plan skill."
|
||||
command = None
|
||||
except Exception as e:
|
||||
logger.exception("Failed to prepare /plan command")
|
||||
return f"Failed to enter plan mode: {e}"
|
||||
|
||||
if command == "retry":
|
||||
return await self._handle_retry_command(event)
|
||||
@@ -2264,6 +2432,13 @@ class GatewayRunner:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to join voice channel: %s", e)
|
||||
adapter._voice_input_callback = None
|
||||
err_lower = str(e).lower()
|
||||
if "pynacl" in err_lower or "nacl" in err_lower or "davey" in err_lower:
|
||||
return (
|
||||
"Voice dependencies are missing (PyNaCl / davey). "
|
||||
"Install or reinstall Hermes with the messaging extra, e.g. "
|
||||
"`pip install hermes-agent[messaging]`."
|
||||
)
|
||||
return f"Failed to join voice channel: {e}"
|
||||
|
||||
if success:
|
||||
@@ -2404,18 +2579,9 @@ class GatewayRunner:
|
||||
if has_agent_tts:
|
||||
return False
|
||||
|
||||
# Dedup: base adapter auto-TTS already handles voice input.
|
||||
# Exception: Discord voice channel — play_tts override is a no-op,
|
||||
# so the runner must handle VC playback.
|
||||
skip_double = is_voice_input
|
||||
if skip_double:
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if (guild_id and adapter
|
||||
and hasattr(adapter, "is_in_voice_channel")
|
||||
and adapter.is_in_voice_channel(guild_id)):
|
||||
skip_double = False
|
||||
if skip_double:
|
||||
# Dedup: base adapter auto-TTS already handles voice input
|
||||
# (play_tts plays in VC when connected, so runner can skip).
|
||||
if is_voice_input:
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -3155,9 +3321,14 @@ class GatewayRunner:
|
||||
if not git_dir.exists():
|
||||
return "✗ Not a git repository — cannot update."
|
||||
|
||||
hermes_bin = shutil.which("hermes")
|
||||
if not hermes_bin:
|
||||
return "✗ `hermes` command not found on PATH."
|
||||
hermes_cmd = _resolve_hermes_bin()
|
||||
if not hermes_cmd:
|
||||
return (
|
||||
"✗ Could not locate the `hermes` command. "
|
||||
"Hermes is running, but the update command could not find the "
|
||||
"executable on PATH or via the current Python interpreter. "
|
||||
"Try running `hermes update` manually in your terminal."
|
||||
)
|
||||
|
||||
pending_path = _hermes_home / ".update_pending.json"
|
||||
output_path = _hermes_home / ".update_output.txt"
|
||||
@@ -3173,8 +3344,9 @@ class GatewayRunner:
|
||||
|
||||
# Spawn `hermes update` in a separate cgroup so it survives gateway
|
||||
# restart. systemd-run --user --scope creates a transient scope unit.
|
||||
hermes_cmd_str = " ".join(shlex.quote(part) for part in hermes_cmd)
|
||||
update_cmd = (
|
||||
f"{shlex.quote(hermes_bin)} update > {shlex.quote(str(output_path))} 2>&1; "
|
||||
f"{hermes_cmd_str} update > {shlex.quote(str(output_path))} 2>&1; "
|
||||
f"status=$?; printf '%s' \"$status\" > {shlex.quote(str(exit_code_path))}"
|
||||
)
|
||||
try:
|
||||
@@ -3331,10 +3503,12 @@ class GatewayRunner:
|
||||
os.environ["HERMES_SESSION_CHAT_ID"] = context.source.chat_id
|
||||
if context.source.chat_name:
|
||||
os.environ["HERMES_SESSION_CHAT_NAME"] = context.source.chat_name
|
||||
if context.source.thread_id:
|
||||
os.environ["HERMES_SESSION_THREAD_ID"] = str(context.source.thread_id)
|
||||
|
||||
def _clear_session_env(self) -> None:
|
||||
"""Clear session environment variables."""
|
||||
for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME"]:
|
||||
for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME", "HERMES_SESSION_THREAD_ID"]:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
@@ -3352,9 +3526,13 @@ class GatewayRunner:
|
||||
1. Immediately understand what the user sent (no extra tool call).
|
||||
2. Re-examine the image with vision_analyze if it needs more detail.
|
||||
|
||||
Athabasca persistence should happen through Athabasca's own POST
|
||||
/api/uploads flow, using the returned asset.publicUrl rather than local
|
||||
cache paths.
|
||||
|
||||
Args:
|
||||
user_text: The user's original caption / message text.
|
||||
image_paths: List of local file paths to cached images.
|
||||
user_text: The user's original caption / message text.
|
||||
image_paths: List of local file paths to cached images.
|
||||
|
||||
Returns:
|
||||
The enriched message string with vision descriptions prepended.
|
||||
@@ -3379,10 +3557,16 @@ class GatewayRunner:
|
||||
result = _json.loads(result_json)
|
||||
if result.get("success"):
|
||||
description = result.get("analysis", "")
|
||||
athabasca_note = (
|
||||
"\n[If this image needs to persist in Athabasca state, upload the cached file "
|
||||
"through Athabasca POST /api/uploads and use the returned asset.publicUrl. "
|
||||
"Do not store the local cache path as the canonical imageUrl.]"
|
||||
)
|
||||
enriched_parts.append(
|
||||
f"[The user sent an image~ Here's what I can see:\n{description}]\n"
|
||||
f"[If you need a closer look, use vision_analyze with "
|
||||
f"image_url: {path} ~]"
|
||||
f"{athabasca_note}"
|
||||
)
|
||||
else:
|
||||
enriched_parts.append(
|
||||
@@ -3412,7 +3596,7 @@ class GatewayRunner:
|
||||
audio_paths: List[str],
|
||||
) -> str:
|
||||
"""
|
||||
Auto-transcribe user voice/audio messages using OpenAI Whisper API
|
||||
Auto-transcribe user voice/audio messages using the configured STT provider
|
||||
and prepend the transcript to the message text.
|
||||
|
||||
Args:
|
||||
@@ -3422,6 +3606,12 @@ class GatewayRunner:
|
||||
Returns:
|
||||
The enriched message string with transcriptions prepended.
|
||||
"""
|
||||
if not getattr(self.config, "stt_enabled", True):
|
||||
disabled_note = "[The user sent voice message(s), but transcription is disabled in config.]"
|
||||
if user_text:
|
||||
return f"{disabled_note}\n\n{user_text}"
|
||||
return disabled_note
|
||||
|
||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||
import asyncio
|
||||
|
||||
@@ -3694,9 +3884,7 @@ class GatewayRunner:
|
||||
"memory": "🧠",
|
||||
"session_search": "🔍",
|
||||
"send_message": "📨",
|
||||
"schedule_cronjob": "⏰",
|
||||
"list_cronjobs": "⏰",
|
||||
"remove_cronjob": "⏰",
|
||||
"cronjob": "⏰",
|
||||
"execute_code": "🐍",
|
||||
"delegate_task": "🔀",
|
||||
"clarify": "❓",
|
||||
@@ -4340,6 +4528,10 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
success = await runner.start()
|
||||
if not success:
|
||||
return False
|
||||
if runner.should_exit_cleanly:
|
||||
if runner.exit_reason:
|
||||
logger.error("Gateway exiting cleanly: %s", runner.exit_reason)
|
||||
return True
|
||||
|
||||
# Write PID file so CLI can detect gateway is running
|
||||
import atexit
|
||||
|
||||
+17
-10
@@ -321,25 +321,32 @@ def build_session_key(source: SessionSource) -> str:
|
||||
This is the single source of truth for session key construction.
|
||||
|
||||
DM rules:
|
||||
- WhatsApp DMs include chat_id (multi-user support).
|
||||
- Other DMs include thread_id when present (e.g. Slack threaded DMs),
|
||||
so each DM thread gets its own session while top-level DMs share one.
|
||||
- Without thread_id or chat_id, all DMs share a single session.
|
||||
- DMs include chat_id when present, so each private conversation is isolated.
|
||||
- thread_id further differentiates threaded DMs within the same DM chat.
|
||||
- Without chat_id, thread_id is used as a best-effort fallback.
|
||||
- Without thread_id or chat_id, DMs share a single session.
|
||||
|
||||
Group/channel rules:
|
||||
- thread_id differentiates threads within a channel.
|
||||
- Without thread_id, all messages in a channel share one session.
|
||||
- chat_id identifies the parent group/channel.
|
||||
- thread_id differentiates threads within that parent chat.
|
||||
- Without identifiers, messages fall back to one session per platform/chat_type.
|
||||
"""
|
||||
platform = source.platform.value
|
||||
if source.chat_type == "dm":
|
||||
if source.chat_id:
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:dm:{source.chat_id}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:dm:{source.chat_id}"
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:dm:{source.thread_id}"
|
||||
if platform == "whatsapp" and source.chat_id:
|
||||
return f"agent:main:{platform}:dm:{source.chat_id}"
|
||||
return f"agent:main:{platform}:dm"
|
||||
if source.chat_id:
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}"
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}"
|
||||
|
||||
|
||||
class SessionStore:
|
||||
|
||||
+187
-4
@@ -11,13 +11,17 @@ that will be useful when we add named profiles (multiple agents running
|
||||
concurrently under distinct configurations).
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
_GATEWAY_KIND = "hermes-gateway"
|
||||
_RUNTIME_STATUS_FILE = "gateway_state.json"
|
||||
_LOCKS_DIRNAME = "gateway-locks"
|
||||
|
||||
|
||||
def _get_pid_path() -> Path:
|
||||
@@ -26,6 +30,32 @@ def _get_pid_path() -> Path:
|
||||
return home / "gateway.pid"
|
||||
|
||||
|
||||
def _get_runtime_status_path() -> Path:
|
||||
"""Return the persisted runtime health/status file path."""
|
||||
return _get_pid_path().with_name(_RUNTIME_STATUS_FILE)
|
||||
|
||||
|
||||
def _get_lock_dir() -> Path:
|
||||
"""Return the machine-local directory for token-scoped gateway locks."""
|
||||
override = os.getenv("HERMES_GATEWAY_LOCK_DIR")
|
||||
if override:
|
||||
return Path(override)
|
||||
state_home = Path(os.getenv("XDG_STATE_HOME", Path.home() / ".local" / "state"))
|
||||
return state_home / "hermes" / _LOCKS_DIRNAME
|
||||
|
||||
|
||||
def _utc_now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _scope_hash(identity: str) -> str:
|
||||
return hashlib.sha256(identity.encode("utf-8")).hexdigest()[:16]
|
||||
|
||||
|
||||
def _get_scope_lock_path(scope: str, identity: str) -> Path:
|
||||
return _get_lock_dir() / f"{scope}-{_scope_hash(identity)}.lock"
|
||||
|
||||
|
||||
def _get_process_start_time(pid: int) -> Optional[int]:
|
||||
"""Return the kernel start time for a process when available."""
|
||||
stat_path = Path(f"/proc/{pid}/stat")
|
||||
@@ -73,6 +103,38 @@ def _build_pid_record() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _build_runtime_status_record() -> dict[str, Any]:
|
||||
payload = _build_pid_record()
|
||||
payload.update({
|
||||
"gateway_state": "starting",
|
||||
"exit_reason": None,
|
||||
"platforms": {},
|
||||
"updated_at": _utc_now_iso(),
|
||||
})
|
||||
return payload
|
||||
|
||||
|
||||
def _read_json_file(path: Path) -> Optional[dict[str, Any]]:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
raw = path.read_text().strip()
|
||||
except OSError:
|
||||
return None
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return payload if isinstance(payload, dict) else None
|
||||
|
||||
|
||||
def _write_json_file(path: Path, payload: dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(payload))
|
||||
|
||||
|
||||
def _read_pid_record() -> Optional[dict]:
|
||||
pid_path = _get_pid_path()
|
||||
if not pid_path.exists():
|
||||
@@ -99,9 +161,49 @@ def _read_pid_record() -> Optional[dict]:
|
||||
|
||||
def write_pid_file() -> None:
|
||||
"""Write the current process PID and metadata to the gateway PID file."""
|
||||
pid_path = _get_pid_path()
|
||||
pid_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pid_path.write_text(json.dumps(_build_pid_record()))
|
||||
_write_json_file(_get_pid_path(), _build_pid_record())
|
||||
|
||||
|
||||
def write_runtime_status(
|
||||
*,
|
||||
gateway_state: Optional[str] = None,
|
||||
exit_reason: Optional[str] = None,
|
||||
platform: Optional[str] = None,
|
||||
platform_state: Optional[str] = None,
|
||||
error_code: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Persist gateway runtime health information for diagnostics/status."""
|
||||
path = _get_runtime_status_path()
|
||||
payload = _read_json_file(path) or _build_runtime_status_record()
|
||||
payload.setdefault("platforms", {})
|
||||
payload.setdefault("kind", _GATEWAY_KIND)
|
||||
payload.setdefault("pid", os.getpid())
|
||||
payload.setdefault("start_time", _get_process_start_time(os.getpid()))
|
||||
payload["updated_at"] = _utc_now_iso()
|
||||
|
||||
if gateway_state is not None:
|
||||
payload["gateway_state"] = gateway_state
|
||||
if exit_reason is not None:
|
||||
payload["exit_reason"] = exit_reason
|
||||
|
||||
if platform is not None:
|
||||
platform_payload = payload["platforms"].get(platform, {})
|
||||
if platform_state is not None:
|
||||
platform_payload["state"] = platform_state
|
||||
if error_code is not None:
|
||||
platform_payload["error_code"] = error_code
|
||||
if error_message is not None:
|
||||
platform_payload["error_message"] = error_message
|
||||
platform_payload["updated_at"] = _utc_now_iso()
|
||||
payload["platforms"][platform] = platform_payload
|
||||
|
||||
_write_json_file(path, payload)
|
||||
|
||||
|
||||
def read_runtime_status() -> Optional[dict[str, Any]]:
|
||||
"""Read the persisted gateway runtime health/status information."""
|
||||
return _read_json_file(_get_runtime_status_path())
|
||||
|
||||
|
||||
def remove_pid_file() -> None:
|
||||
@@ -112,6 +214,87 @@ def remove_pid_file() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def acquire_scoped_lock(scope: str, identity: str, metadata: Optional[dict[str, Any]] = None) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""Acquire a machine-local lock keyed by scope + identity.
|
||||
|
||||
Used to prevent multiple local gateways from using the same external identity
|
||||
at once (e.g. the same Telegram bot token across different HERMES_HOME dirs).
|
||||
"""
|
||||
lock_path = _get_scope_lock_path(scope, identity)
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
record = {
|
||||
**_build_pid_record(),
|
||||
"scope": scope,
|
||||
"identity_hash": _scope_hash(identity),
|
||||
"metadata": metadata or {},
|
||||
"updated_at": _utc_now_iso(),
|
||||
}
|
||||
|
||||
existing = _read_json_file(lock_path)
|
||||
if existing:
|
||||
try:
|
||||
existing_pid = int(existing["pid"])
|
||||
except (KeyError, TypeError, ValueError):
|
||||
existing_pid = None
|
||||
|
||||
if existing_pid == os.getpid() and existing.get("start_time") == record.get("start_time"):
|
||||
_write_json_file(lock_path, record)
|
||||
return True, existing
|
||||
|
||||
stale = existing_pid is None
|
||||
if not stale:
|
||||
try:
|
||||
os.kill(existing_pid, 0)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
stale = True
|
||||
else:
|
||||
current_start = _get_process_start_time(existing_pid)
|
||||
if (
|
||||
existing.get("start_time") is not None
|
||||
and current_start is not None
|
||||
and current_start != existing.get("start_time")
|
||||
):
|
||||
stale = True
|
||||
if stale:
|
||||
try:
|
||||
lock_path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
return False, existing
|
||||
|
||||
try:
|
||||
fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
|
||||
except FileExistsError:
|
||||
return False, _read_json_file(lock_path)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as handle:
|
||||
json.dump(record, handle)
|
||||
except Exception:
|
||||
try:
|
||||
lock_path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
return True, None
|
||||
|
||||
|
||||
def release_scoped_lock(scope: str, identity: str) -> None:
|
||||
"""Release a previously-acquired scope lock when owned by this process."""
|
||||
lock_path = _get_scope_lock_path(scope, identity)
|
||||
existing = _read_json_file(lock_path)
|
||||
if not existing:
|
||||
return
|
||||
if existing.get("pid") != os.getpid():
|
||||
return
|
||||
if existing.get("start_time") != _get_process_start_time(os.getpid()):
|
||||
return
|
||||
try:
|
||||
lock_path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def get_running_pid() -> Optional[int]:
|
||||
"""Return the PID of a running gateway instance, or ``None``.
|
||||
|
||||
|
||||
+45
-6
@@ -6,7 +6,9 @@ Pure display functions with no HermesCLI state dependency.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional
|
||||
@@ -143,7 +145,9 @@ def check_for_updates() -> Optional[int]:
|
||||
repo_dir = hermes_home / "hermes-agent"
|
||||
cache_file = hermes_home / ".update_check"
|
||||
|
||||
# Must be a git repo
|
||||
# Must be a git repo — fall back to project root for dev installs
|
||||
if not (repo_dir / ".git").exists():
|
||||
repo_dir = Path(__file__).parent.parent.resolve()
|
||||
if not (repo_dir / ".git").exists():
|
||||
return None
|
||||
|
||||
@@ -190,6 +194,30 @@ def check_for_updates() -> Optional[int]:
|
||||
return behind
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Non-blocking update check
|
||||
# =========================================================================
|
||||
|
||||
_update_result: Optional[int] = None
|
||||
_update_check_done = threading.Event()
|
||||
|
||||
|
||||
def prefetch_update_check():
|
||||
"""Kick off update check in a background daemon thread."""
|
||||
def _run():
|
||||
global _update_result
|
||||
_update_result = check_for_updates()
|
||||
_update_check_done.set()
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
|
||||
|
||||
def get_update_result(timeout: float = 0.5) -> Optional[int]:
|
||||
"""Get result of prefetched check. Returns None if not ready."""
|
||||
_update_check_done.wait(timeout=timeout)
|
||||
return _update_result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Welcome banner
|
||||
# =========================================================================
|
||||
@@ -245,7 +273,15 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
text = _skin_color("banner_text", "#FFF8DC")
|
||||
session_color = _skin_color("session_border", "#8B8682")
|
||||
|
||||
left_lines = ["", HERMES_CADUCEUS, ""]
|
||||
# Use skin's custom caduceus art if provided
|
||||
try:
|
||||
from hermes_cli.skin_engine import get_active_skin
|
||||
_bskin = get_active_skin()
|
||||
_hero = _bskin.banner_hero if hasattr(_bskin, 'banner_hero') and _bskin.banner_hero else HERMES_CADUCEUS
|
||||
except Exception:
|
||||
_bskin = None
|
||||
_hero = HERMES_CADUCEUS
|
||||
left_lines = ["", _hero, ""]
|
||||
model_short = model.split("/")[-1] if "/" in model else model
|
||||
if len(model_short) > 28:
|
||||
model_short = model_short[:25] + "..."
|
||||
@@ -360,9 +396,9 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
summary_parts.append("/help for commands")
|
||||
right_lines.append(f"[dim {dim}]{' · '.join(summary_parts)}[/]")
|
||||
|
||||
# Update check — show if behind origin/main
|
||||
# Update check — use prefetched result if available
|
||||
try:
|
||||
behind = check_for_updates()
|
||||
behind = get_update_result(timeout=0.5)
|
||||
if behind and behind > 0:
|
||||
commits_word = "commit" if behind == 1 else "commits"
|
||||
right_lines.append(
|
||||
@@ -386,6 +422,9 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
)
|
||||
|
||||
console.print()
|
||||
console.print(HERMES_AGENT_LOGO)
|
||||
console.print()
|
||||
term_width = shutil.get_terminal_size().columns
|
||||
if term_width >= 95:
|
||||
_logo = _bskin.banner_logo if _bskin and hasattr(_bskin, 'banner_logo') and _bskin.banner_logo else HERMES_AGENT_LOGO
|
||||
console.print(_logo)
|
||||
console.print()
|
||||
console.print(outer_panel)
|
||||
|
||||
@@ -43,7 +43,7 @@ COMMANDS_BY_CATEGORY = {
|
||||
"/tools": "List available tools",
|
||||
"/toolsets": "List available toolsets",
|
||||
"/skills": "Search, install, inspect, or manage skills from online registries",
|
||||
"/cron": "Manage scheduled tasks (list, add, remove)",
|
||||
"/cron": "Manage scheduled tasks (list, add/create, edit, pause, resume, run, remove)",
|
||||
"/reload-mcp": "Reload MCP servers from config.yaml",
|
||||
},
|
||||
"Info": {
|
||||
|
||||
+32
-6
@@ -106,6 +106,7 @@ DEFAULT_CONFIG = {
|
||||
"cwd": ".", # Use current directory
|
||||
"timeout": 180,
|
||||
"docker_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"docker_forward_env": [],
|
||||
"singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"modal_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
@@ -150,30 +151,44 @@ DEFAULT_CONFIG = {
|
||||
"vision": {
|
||||
"provider": "auto", # auto | openrouter | nous | codex | custom
|
||||
"model": "", # e.g. "google/gemini-2.5-flash", "gpt-4o"
|
||||
"base_url": "", # direct OpenAI-compatible endpoint (takes precedence over provider)
|
||||
"api_key": "", # API key for base_url (falls back to OPENAI_API_KEY)
|
||||
},
|
||||
"web_extract": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
"compression": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
"session_search": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
"skills_hub": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
"mcp": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
"flush_memories": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -205,7 +220,8 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
|
||||
"stt": {
|
||||
"provider": "local", # "local" (free, faster-whisper) | "openai" (Whisper API)
|
||||
"enabled": True,
|
||||
"provider": "local", # "local" (free, faster-whisper) | "groq" | "openai" (Whisper API)
|
||||
"local": {
|
||||
"model": "base", # tiny, base, small, medium, large-v3
|
||||
},
|
||||
@@ -243,6 +259,8 @@ DEFAULT_CONFIG = {
|
||||
"delegation": {
|
||||
"model": "", # e.g. "google/gemini-3-flash-preview" (empty = inherit parent model)
|
||||
"provider": "", # e.g. "openrouter" (empty = inherit parent provider + credentials)
|
||||
"base_url": "", # direct OpenAI-compatible endpoint for subagents
|
||||
"api_key": "", # API key for delegation.base_url (falls back to OPENAI_API_KEY)
|
||||
},
|
||||
|
||||
# Ephemeral prefill messages file — JSON list of {role, content} dicts
|
||||
@@ -263,6 +281,7 @@ DEFAULT_CONFIG = {
|
||||
"discord": {
|
||||
"require_mention": True, # Require @mention to respond in server channels
|
||||
"free_response_channels": "", # Comma-separated channel IDs where bot responds without mention
|
||||
"auto_thread": True, # Auto-create threads on @mention in channels (like Slack)
|
||||
},
|
||||
|
||||
# Permanently allowed dangerous command patterns (added via "always" approval)
|
||||
@@ -284,7 +303,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 7,
|
||||
"_config_version": 9,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -821,7 +840,7 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
||||
print(f" ✓ Saved {name}")
|
||||
print()
|
||||
else:
|
||||
print(" Set later with: hermes config set KEY VALUE")
|
||||
print(" Set later with: hermes config set <key> <value>")
|
||||
|
||||
# Check for missing config fields
|
||||
missing_config = get_missing_config_fields()
|
||||
@@ -1092,6 +1111,13 @@ def save_anthropic_oauth_token(value: str, save_fn=None):
|
||||
writer("ANTHROPIC_API_KEY", "")
|
||||
|
||||
|
||||
def use_anthropic_claude_code_credentials(save_fn=None):
|
||||
"""Use Claude Code's own credential files instead of persisting env tokens."""
|
||||
writer = save_fn or save_env_value
|
||||
writer("ANTHROPIC_TOKEN", "")
|
||||
writer("ANTHROPIC_API_KEY", "")
|
||||
|
||||
|
||||
def save_anthropic_api_key(value: str, save_fn=None):
|
||||
"""Persist an Anthropic API key and clear the OAuth/setup-token slot."""
|
||||
writer = save_fn or save_env_value
|
||||
@@ -1265,7 +1291,7 @@ def show_config():
|
||||
print()
|
||||
print(color("─" * 60, Colors.DIM))
|
||||
print(color(" hermes config edit # Edit config file", Colors.DIM))
|
||||
print(color(" hermes config set KEY VALUE", Colors.DIM))
|
||||
print(color(" hermes config set <key> <value>", Colors.DIM))
|
||||
print(color(" hermes setup # Run setup wizard", Colors.DIM))
|
||||
print()
|
||||
|
||||
@@ -1391,7 +1417,7 @@ def config_command(args):
|
||||
key = getattr(args, 'key', None)
|
||||
value = getattr(args, 'value', None)
|
||||
if not key or not value:
|
||||
print("Usage: hermes config set KEY VALUE")
|
||||
print("Usage: hermes config set <key> <value>")
|
||||
print()
|
||||
print("Examples:")
|
||||
print(" hermes config set model anthropic/claude-sonnet-4")
|
||||
@@ -1506,7 +1532,7 @@ def config_command(args):
|
||||
print("Available commands:")
|
||||
print(" hermes config Show current configuration")
|
||||
print(" hermes config edit Open config in editor")
|
||||
print(" hermes config set K V Set a config value")
|
||||
print(" hermes config set <key> <value> Set a config value")
|
||||
print(" hermes config check Check for missing/outdated config")
|
||||
print(" hermes config migrate Update config with new options")
|
||||
print(" hermes config path Show config file path")
|
||||
|
||||
+173
-42
@@ -1,15 +1,14 @@
|
||||
"""
|
||||
Cron subcommand for hermes CLI.
|
||||
|
||||
Handles: hermes cron [list|status|tick]
|
||||
|
||||
Cronjobs are executed automatically by the gateway daemon (hermes gateway).
|
||||
Install the gateway as a service for background execution:
|
||||
hermes gateway install
|
||||
Handles standalone cron management commands like list, create, edit,
|
||||
pause/resume/run/remove, status, and tick.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
@@ -17,62 +16,87 @@ sys.path.insert(0, str(PROJECT_ROOT))
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
def _normalize_skills(single_skill=None, skills: Optional[Iterable[str]] = None) -> Optional[List[str]]:
|
||||
if skills is None:
|
||||
if single_skill is None:
|
||||
return None
|
||||
raw_items = [single_skill]
|
||||
else:
|
||||
raw_items = list(skills)
|
||||
|
||||
normalized: List[str] = []
|
||||
for item in raw_items:
|
||||
text = str(item or "").strip()
|
||||
if text and text not in normalized:
|
||||
normalized.append(text)
|
||||
return normalized
|
||||
|
||||
|
||||
def _cron_api(**kwargs):
|
||||
from tools.cronjob_tools import cronjob as cronjob_tool
|
||||
|
||||
return json.loads(cronjob_tool(**kwargs))
|
||||
|
||||
|
||||
def cron_list(show_all: bool = False):
|
||||
"""List all scheduled jobs."""
|
||||
from cron.jobs import list_jobs
|
||||
|
||||
|
||||
jobs = list_jobs(include_disabled=show_all)
|
||||
|
||||
|
||||
if not jobs:
|
||||
print(color("No scheduled jobs.", Colors.DIM))
|
||||
print(color("Create one with the /cron add command in chat, or via Telegram.", Colors.DIM))
|
||||
print(color("Create one with 'hermes cron create ...' or the /cron command in chat.", Colors.DIM))
|
||||
return
|
||||
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ Scheduled Jobs │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
print()
|
||||
|
||||
|
||||
for job in jobs:
|
||||
job_id = job.get("id", "?")[:8]
|
||||
name = job.get("name", "(unnamed)")
|
||||
schedule = job.get("schedule_display", job.get("schedule", {}).get("value", "?"))
|
||||
enabled = job.get("enabled", True)
|
||||
state = job.get("state", "scheduled" if job.get("enabled", True) else "paused")
|
||||
next_run = job.get("next_run_at", "?")
|
||||
|
||||
|
||||
repeat_info = job.get("repeat", {})
|
||||
repeat_times = repeat_info.get("times")
|
||||
repeat_completed = repeat_info.get("completed", 0)
|
||||
|
||||
if repeat_times:
|
||||
repeat_str = f"{repeat_completed}/{repeat_times}"
|
||||
else:
|
||||
repeat_str = "∞"
|
||||
|
||||
repeat_str = f"{repeat_completed}/{repeat_times}" if repeat_times else "∞"
|
||||
|
||||
deliver = job.get("deliver", ["local"])
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
deliver_str = ", ".join(deliver)
|
||||
|
||||
if not enabled:
|
||||
status = color("[disabled]", Colors.RED)
|
||||
else:
|
||||
|
||||
skills = job.get("skills") or ([job["skill"]] if job.get("skill") else [])
|
||||
if state == "paused":
|
||||
status = color("[paused]", Colors.YELLOW)
|
||||
elif state == "completed":
|
||||
status = color("[completed]", Colors.BLUE)
|
||||
elif job.get("enabled", True):
|
||||
status = color("[active]", Colors.GREEN)
|
||||
|
||||
else:
|
||||
status = color("[disabled]", Colors.RED)
|
||||
|
||||
print(f" {color(job_id, Colors.YELLOW)} {status}")
|
||||
print(f" Name: {name}")
|
||||
print(f" Schedule: {schedule}")
|
||||
print(f" Repeat: {repeat_str}")
|
||||
print(f" Next run: {next_run}")
|
||||
print(f" Deliver: {deliver_str}")
|
||||
if skills:
|
||||
print(f" Skills: {', '.join(skills)}")
|
||||
print()
|
||||
|
||||
# Warn if gateway isn't running
|
||||
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
if not find_gateway_pids():
|
||||
print(color(" ⚠ Gateway is not running — jobs won't fire automatically.", Colors.YELLOW))
|
||||
print(color(" Start it with: hermes gateway install", Colors.DIM))
|
||||
print(color(" sudo hermes gateway install --system # Linux servers", Colors.DIM))
|
||||
print()
|
||||
|
||||
|
||||
@@ -86,9 +110,9 @@ def cron_status():
|
||||
"""Show cron execution status."""
|
||||
from cron.jobs import list_jobs
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
|
||||
|
||||
print()
|
||||
|
||||
|
||||
pids = find_gateway_pids()
|
||||
if pids:
|
||||
print(color("✓ Gateway is running — cron jobs will fire automatically", Colors.GREEN))
|
||||
@@ -97,11 +121,12 @@ def cron_status():
|
||||
print(color("✗ Gateway is not running — cron jobs will NOT fire", Colors.RED))
|
||||
print()
|
||||
print(" To enable automatic execution:")
|
||||
print(" hermes gateway install # Install as system service (recommended)")
|
||||
print(" hermes gateway install # Install as a user service")
|
||||
print(" sudo hermes gateway install --system # Linux servers: boot-time system service")
|
||||
print(" hermes gateway # Or run in foreground")
|
||||
|
||||
|
||||
print()
|
||||
|
||||
|
||||
jobs = list_jobs(include_disabled=False)
|
||||
if jobs:
|
||||
next_runs = [j.get("next_run_at") for j in jobs if j.get("next_run_at")]
|
||||
@@ -110,25 +135,131 @@ def cron_status():
|
||||
print(f" Next run: {min(next_runs)}")
|
||||
else:
|
||||
print(" No active jobs")
|
||||
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def cron_create(args):
|
||||
result = _cron_api(
|
||||
action="create",
|
||||
schedule=args.schedule,
|
||||
prompt=args.prompt,
|
||||
name=getattr(args, "name", None),
|
||||
deliver=getattr(args, "deliver", None),
|
||||
repeat=getattr(args, "repeat", None),
|
||||
skill=getattr(args, "skill", None),
|
||||
skills=_normalize_skills(getattr(args, "skill", None), getattr(args, "skills", None)),
|
||||
)
|
||||
if not result.get("success"):
|
||||
print(color(f"Failed to create job: {result.get('error', 'unknown error')}", Colors.RED))
|
||||
return 1
|
||||
print(color(f"Created job: {result['job_id']}", Colors.GREEN))
|
||||
print(f" Name: {result['name']}")
|
||||
print(f" Schedule: {result['schedule']}")
|
||||
if result.get("skills"):
|
||||
print(f" Skills: {', '.join(result['skills'])}")
|
||||
print(f" Next run: {result['next_run_at']}")
|
||||
return 0
|
||||
|
||||
|
||||
def cron_edit(args):
|
||||
from cron.jobs import get_job
|
||||
|
||||
job = get_job(args.job_id)
|
||||
if not job:
|
||||
print(color(f"Job not found: {args.job_id}", Colors.RED))
|
||||
return 1
|
||||
|
||||
existing_skills = list(job.get("skills") or ([] if not job.get("skill") else [job.get("skill")]))
|
||||
replacement_skills = _normalize_skills(getattr(args, "skill", None), getattr(args, "skills", None))
|
||||
add_skills = _normalize_skills(None, getattr(args, "add_skills", None)) or []
|
||||
remove_skills = set(_normalize_skills(None, getattr(args, "remove_skills", None)) or [])
|
||||
|
||||
final_skills = None
|
||||
if getattr(args, "clear_skills", False):
|
||||
final_skills = []
|
||||
elif replacement_skills is not None:
|
||||
final_skills = replacement_skills
|
||||
elif add_skills or remove_skills:
|
||||
final_skills = [skill for skill in existing_skills if skill not in remove_skills]
|
||||
for skill in add_skills:
|
||||
if skill not in final_skills:
|
||||
final_skills.append(skill)
|
||||
|
||||
result = _cron_api(
|
||||
action="update",
|
||||
job_id=args.job_id,
|
||||
schedule=getattr(args, "schedule", None),
|
||||
prompt=getattr(args, "prompt", None),
|
||||
name=getattr(args, "name", None),
|
||||
deliver=getattr(args, "deliver", None),
|
||||
repeat=getattr(args, "repeat", None),
|
||||
skills=final_skills,
|
||||
)
|
||||
if not result.get("success"):
|
||||
print(color(f"Failed to update job: {result.get('error', 'unknown error')}", Colors.RED))
|
||||
return 1
|
||||
|
||||
updated = result["job"]
|
||||
print(color(f"Updated job: {updated['job_id']}", Colors.GREEN))
|
||||
print(f" Name: {updated['name']}")
|
||||
print(f" Schedule: {updated['schedule']}")
|
||||
if updated.get("skills"):
|
||||
print(f" Skills: {', '.join(updated['skills'])}")
|
||||
else:
|
||||
print(" Skills: none")
|
||||
return 0
|
||||
|
||||
|
||||
def _job_action(action: str, job_id: str, success_verb: str) -> int:
|
||||
result = _cron_api(action=action, job_id=job_id)
|
||||
if not result.get("success"):
|
||||
print(color(f"Failed to {action} job: {result.get('error', 'unknown error')}", Colors.RED))
|
||||
return 1
|
||||
job = result.get("job") or result.get("removed_job") or {}
|
||||
print(color(f"{success_verb} job: {job.get('name', job_id)} ({job_id})", Colors.GREEN))
|
||||
if action in {"resume", "run"} and result.get("job", {}).get("next_run_at"):
|
||||
print(f" Next run: {result['job']['next_run_at']}")
|
||||
if action == "run":
|
||||
print(" It will run on the next scheduler tick.")
|
||||
return 0
|
||||
|
||||
|
||||
def cron_command(args):
|
||||
"""Handle cron subcommands."""
|
||||
subcmd = getattr(args, 'cron_command', None)
|
||||
|
||||
|
||||
if subcmd is None or subcmd == "list":
|
||||
show_all = getattr(args, 'all', False)
|
||||
cron_list(show_all)
|
||||
|
||||
elif subcmd == "tick":
|
||||
cron_tick()
|
||||
|
||||
elif subcmd == "status":
|
||||
return 0
|
||||
|
||||
if subcmd == "status":
|
||||
cron_status()
|
||||
|
||||
else:
|
||||
print(f"Unknown cron command: {subcmd}")
|
||||
print("Usage: hermes cron [list|status|tick]")
|
||||
sys.exit(1)
|
||||
return 0
|
||||
|
||||
if subcmd == "tick":
|
||||
cron_tick()
|
||||
return 0
|
||||
|
||||
if subcmd in {"create", "add"}:
|
||||
return cron_create(args)
|
||||
|
||||
if subcmd == "edit":
|
||||
return cron_edit(args)
|
||||
|
||||
if subcmd == "pause":
|
||||
return _job_action("pause", args.job_id, "Paused")
|
||||
|
||||
if subcmd == "resume":
|
||||
return _job_action("resume", args.job_id, "Resumed")
|
||||
|
||||
if subcmd == "run":
|
||||
return _job_action("run", args.job_id, "Triggered")
|
||||
|
||||
if subcmd in {"remove", "rm", "delete"}:
|
||||
return _job_action("remove", args.job_id, "Removed")
|
||||
|
||||
print(f"Unknown cron command: {subcmd}")
|
||||
print("Usage: hermes cron [list|create|edit|pause|resume|run|remove|status|tick]")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Helpers for loading Hermes .env files consistently across entrypoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
def _load_dotenv_with_fallback(path: Path, *, override: bool) -> None:
|
||||
try:
|
||||
load_dotenv(dotenv_path=path, override=override, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=path, override=override, encoding="latin-1")
|
||||
|
||||
|
||||
def load_hermes_dotenv(
|
||||
*,
|
||||
hermes_home: str | os.PathLike | None = None,
|
||||
project_env: str | os.PathLike | None = None,
|
||||
) -> list[Path]:
|
||||
"""Load Hermes environment files with user config taking precedence.
|
||||
|
||||
Behavior:
|
||||
- `~/.hermes/.env` overrides stale shell-exported values when present.
|
||||
- project `.env` acts as a dev fallback and only fills missing values when
|
||||
the user env exists.
|
||||
- if no user env exists, the project `.env` also overrides stale shell vars.
|
||||
"""
|
||||
loaded: list[Path] = []
|
||||
|
||||
home_path = Path(hermes_home or os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
user_env = home_path / ".env"
|
||||
project_env_path = Path(project_env) if project_env else None
|
||||
|
||||
if user_env.exists():
|
||||
_load_dotenv_with_fallback(user_env, override=True)
|
||||
loaded.append(user_env)
|
||||
|
||||
if project_env_path and project_env_path.exists():
|
||||
_load_dotenv_with_fallback(project_env_path, override=not loaded)
|
||||
loaded.append(project_env_path)
|
||||
|
||||
return loaded
|
||||
+439
-92
@@ -123,10 +123,143 @@ SERVICE_NAME = "hermes-gateway"
|
||||
SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
|
||||
|
||||
|
||||
def get_systemd_unit_path() -> Path:
|
||||
def get_systemd_unit_path(system: bool = False) -> Path:
|
||||
if system:
|
||||
return Path("/etc/systemd/system") / f"{SERVICE_NAME}.service"
|
||||
return Path.home() / ".config" / "systemd" / "user" / f"{SERVICE_NAME}.service"
|
||||
|
||||
|
||||
def _systemctl_cmd(system: bool = False) -> list[str]:
|
||||
return ["systemctl"] if system else ["systemctl", "--user"]
|
||||
|
||||
|
||||
def _journalctl_cmd(system: bool = False) -> list[str]:
|
||||
return ["journalctl"] if system else ["journalctl", "--user"]
|
||||
|
||||
|
||||
def _service_scope_label(system: bool = False) -> str:
|
||||
return "system" if system else "user"
|
||||
|
||||
|
||||
def get_installed_systemd_scopes() -> list[str]:
|
||||
scopes = []
|
||||
seen_paths: set[Path] = set()
|
||||
for system, label in ((False, "user"), (True, "system")):
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
if unit_path in seen_paths:
|
||||
continue
|
||||
if unit_path.exists():
|
||||
scopes.append(label)
|
||||
seen_paths.add(unit_path)
|
||||
return scopes
|
||||
|
||||
|
||||
def has_conflicting_systemd_units() -> bool:
|
||||
return len(get_installed_systemd_scopes()) > 1
|
||||
|
||||
|
||||
def print_systemd_scope_conflict_warning() -> None:
|
||||
scopes = get_installed_systemd_scopes()
|
||||
if len(scopes) < 2:
|
||||
return
|
||||
|
||||
rendered_scopes = " + ".join(scopes)
|
||||
print_warning(f"Both user and system gateway services are installed ({rendered_scopes}).")
|
||||
print_info(" This is confusing and can make start/stop/status behavior ambiguous.")
|
||||
print_info(" Default gateway commands target the user service unless you pass --system.")
|
||||
print_info(" Keep one of these:")
|
||||
print_info(" hermes gateway uninstall")
|
||||
print_info(" sudo hermes gateway uninstall --system")
|
||||
|
||||
|
||||
def _require_root_for_system_service(action: str) -> None:
|
||||
if os.geteuid() != 0:
|
||||
print(f"System gateway {action} requires root. Re-run with sudo.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _system_service_identity(run_as_user: str | None = None) -> tuple[str, str, str]:
|
||||
import getpass
|
||||
import grp
|
||||
import pwd
|
||||
|
||||
username = (run_as_user or os.getenv("SUDO_USER") or os.getenv("USER") or os.getenv("LOGNAME") or getpass.getuser()).strip()
|
||||
if not username:
|
||||
raise ValueError("Could not determine which user the gateway service should run as")
|
||||
if username == "root":
|
||||
raise ValueError("Refusing to install the gateway system service as root; pass --run-as USER")
|
||||
|
||||
try:
|
||||
user_info = pwd.getpwnam(username)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Unknown user: {username}") from e
|
||||
|
||||
group_name = grp.getgrgid(user_info.pw_gid).gr_name
|
||||
return username, group_name, user_info.pw_dir
|
||||
|
||||
|
||||
def _read_systemd_user_from_unit(unit_path: Path) -> str | None:
|
||||
if not unit_path.exists():
|
||||
return None
|
||||
|
||||
for line in unit_path.read_text(encoding="utf-8").splitlines():
|
||||
if line.startswith("User="):
|
||||
value = line.split("=", 1)[1].strip()
|
||||
return value or None
|
||||
return None
|
||||
|
||||
|
||||
def _default_system_service_user() -> str | None:
|
||||
for candidate in (os.getenv("SUDO_USER"), os.getenv("USER"), os.getenv("LOGNAME")):
|
||||
if candidate and candidate.strip() and candidate.strip() != "root":
|
||||
return candidate.strip()
|
||||
return None
|
||||
|
||||
|
||||
def prompt_linux_gateway_install_scope() -> str | None:
|
||||
choice = prompt_choice(
|
||||
" Choose how the gateway should run in the background:",
|
||||
[
|
||||
"User service (no sudo; best for laptops/dev boxes; may need linger after logout)",
|
||||
"System service (starts on boot; requires sudo; still runs as your user)",
|
||||
"Skip service install for now",
|
||||
],
|
||||
default=0,
|
||||
)
|
||||
return {0: "user", 1: "system", 2: None}[choice]
|
||||
|
||||
|
||||
def install_linux_gateway_from_setup(force: bool = False) -> tuple[str | None, bool]:
|
||||
scope = prompt_linux_gateway_install_scope()
|
||||
if scope is None:
|
||||
return None, False
|
||||
|
||||
if scope == "system":
|
||||
run_as_user = _default_system_service_user()
|
||||
if os.geteuid() != 0:
|
||||
print_warning(" System service install requires sudo, so Hermes can't create it from this user session.")
|
||||
if run_as_user:
|
||||
print_info(f" After setup, run: sudo hermes gateway install --system --run-as-user {run_as_user}")
|
||||
else:
|
||||
print_info(" After setup, run: sudo hermes gateway install --system --run-as-user <your-user>")
|
||||
print_info(" Then start it with: sudo hermes gateway start --system")
|
||||
return scope, False
|
||||
|
||||
if not run_as_user:
|
||||
while True:
|
||||
run_as_user = prompt(" Run the system gateway service as which user?", default="")
|
||||
run_as_user = (run_as_user or "").strip()
|
||||
if run_as_user and run_as_user != "root":
|
||||
break
|
||||
print_error(" Enter a non-root username.")
|
||||
|
||||
systemd_install(force=force, system=True, run_as_user=run_as_user)
|
||||
return scope, True
|
||||
|
||||
systemd_install(force=force, system=False)
|
||||
return scope, True
|
||||
|
||||
|
||||
def get_systemd_linger_status() -> tuple[bool | None, str]:
|
||||
"""Return whether systemd user lingering is enabled for the current user.
|
||||
|
||||
@@ -216,8 +349,9 @@ def get_hermes_cli_path() -> str:
|
||||
# Systemd (Linux)
|
||||
# =============================================================================
|
||||
|
||||
def generate_systemd_unit() -> str:
|
||||
def generate_systemd_unit(system: bool = False, run_as_user: str | None = None) -> str:
|
||||
import shutil
|
||||
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
venv_dir = str(PROJECT_ROOT / "venv")
|
||||
@@ -226,8 +360,38 @@ def generate_systemd_unit() -> str:
|
||||
|
||||
# Build a PATH that includes the venv, node_modules, and standard system dirs
|
||||
sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
|
||||
hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
|
||||
|
||||
if system:
|
||||
username, group_name, home_dir = _system_service_identity(run_as_user)
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User={username}
|
||||
Group={group_name}
|
||||
ExecStart={python_path} -m hermes_cli.main gateway run --replace
|
||||
WorkingDirectory={working_dir}
|
||||
Environment="HOME={home_dir}"
|
||||
Environment="USER={username}"
|
||||
Environment="LOGNAME={username}"
|
||||
Environment="PATH={sane_path}"
|
||||
Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=15
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
"""
|
||||
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network.target
|
||||
@@ -251,123 +415,236 @@ StandardError=journal
|
||||
WantedBy=default.target
|
||||
"""
|
||||
|
||||
|
||||
def _normalize_service_definition(text: str) -> str:
|
||||
return "\n".join(line.rstrip() for line in text.strip().splitlines())
|
||||
|
||||
|
||||
def systemd_unit_is_current() -> bool:
|
||||
unit_path = get_systemd_unit_path()
|
||||
def systemd_unit_is_current(system: bool = False) -> bool:
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
if not unit_path.exists():
|
||||
return False
|
||||
|
||||
installed = unit_path.read_text(encoding="utf-8")
|
||||
expected = generate_systemd_unit()
|
||||
expected_user = _read_systemd_user_from_unit(unit_path) if system else None
|
||||
expected = generate_systemd_unit(system=system, run_as_user=expected_user)
|
||||
return _normalize_service_definition(installed) == _normalize_service_definition(expected)
|
||||
|
||||
|
||||
|
||||
def refresh_systemd_unit_if_needed() -> bool:
|
||||
"""Rewrite the installed user unit when the generated definition has changed."""
|
||||
unit_path = get_systemd_unit_path()
|
||||
if not unit_path.exists() or systemd_unit_is_current():
|
||||
def refresh_systemd_unit_if_needed(system: bool = False) -> bool:
|
||||
"""Rewrite the installed systemd unit when the generated definition has changed."""
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
if not unit_path.exists() or systemd_unit_is_current(system=system):
|
||||
return False
|
||||
|
||||
unit_path.write_text(generate_systemd_unit(), encoding="utf-8")
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
print("↻ Updated gateway service definition to match the current Hermes install")
|
||||
expected_user = _read_systemd_user_from_unit(unit_path) if system else None
|
||||
unit_path.write_text(generate_systemd_unit(system=system, run_as_user=expected_user), encoding="utf-8")
|
||||
subprocess.run(_systemctl_cmd(system) + ["daemon-reload"], check=True)
|
||||
print(f"↻ Updated gateway {_service_scope_label(system)} service definition to match the current Hermes install")
|
||||
return True
|
||||
|
||||
|
||||
def systemd_install(force: bool = False):
|
||||
unit_path = get_systemd_unit_path()
|
||||
|
||||
|
||||
def _print_linger_enable_warning(username: str, detail: str | None = None) -> None:
|
||||
print()
|
||||
print("⚠ Linger not enabled — gateway may stop when you close this terminal.")
|
||||
if detail:
|
||||
print(f" Auto-enable failed: {detail}")
|
||||
print()
|
||||
print(" On headless servers (VPS, cloud instances) run:")
|
||||
print(f" sudo loginctl enable-linger {username}")
|
||||
print()
|
||||
print(" Then restart the gateway:")
|
||||
print(f" systemctl --user restart {SERVICE_NAME}.service")
|
||||
print()
|
||||
|
||||
|
||||
|
||||
def _ensure_linger_enabled() -> None:
|
||||
"""Enable linger when possible so the user gateway survives logout."""
|
||||
if not is_linux():
|
||||
return
|
||||
|
||||
import getpass
|
||||
import shutil
|
||||
|
||||
username = getpass.getuser()
|
||||
linger_file = Path(f"/var/lib/systemd/linger/{username}")
|
||||
if linger_file.exists():
|
||||
print("✓ Systemd linger is enabled (service survives logout)")
|
||||
return
|
||||
|
||||
linger_enabled, linger_detail = get_systemd_linger_status()
|
||||
if linger_enabled is True:
|
||||
print("✓ Systemd linger is enabled (service survives logout)")
|
||||
return
|
||||
|
||||
if not shutil.which("loginctl"):
|
||||
_print_linger_enable_warning(username, linger_detail or "loginctl not found")
|
||||
return
|
||||
|
||||
print("Enabling linger so the gateway survives SSH logout...")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["loginctl", "enable-linger", username],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except Exception as e:
|
||||
_print_linger_enable_warning(username, str(e))
|
||||
return
|
||||
|
||||
if result.returncode == 0:
|
||||
print("✓ Linger enabled — gateway will persist after logout")
|
||||
return
|
||||
|
||||
detail = (result.stderr or result.stdout or f"exit {result.returncode}").strip()
|
||||
_print_linger_enable_warning(username, detail or linger_detail)
|
||||
|
||||
|
||||
def _select_systemd_scope(system: bool = False) -> bool:
|
||||
if system:
|
||||
return True
|
||||
return get_systemd_unit_path(system=True).exists() and not get_systemd_unit_path(system=False).exists()
|
||||
|
||||
|
||||
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):
|
||||
if system:
|
||||
_require_root_for_system_service("install")
|
||||
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
scope_flag = " --system" if system else ""
|
||||
|
||||
if unit_path.exists() and not force:
|
||||
print(f"Service already installed at: {unit_path}")
|
||||
print("Use --force to reinstall")
|
||||
return
|
||||
|
||||
|
||||
unit_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Installing systemd service to: {unit_path}")
|
||||
unit_path.write_text(generate_systemd_unit())
|
||||
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
subprocess.run(["systemctl", "--user", "enable", SERVICE_NAME], check=True)
|
||||
|
||||
print(f"Installing {_service_scope_label(system)} systemd service to: {unit_path}")
|
||||
unit_path.write_text(generate_systemd_unit(system=system, run_as_user=run_as_user), encoding="utf-8")
|
||||
|
||||
subprocess.run(_systemctl_cmd(system) + ["daemon-reload"], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["enable", SERVICE_NAME], check=True)
|
||||
|
||||
print()
|
||||
print("✓ Service installed and enabled!")
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service installed and enabled!")
|
||||
print()
|
||||
print("Next steps:")
|
||||
print(f" hermes gateway start # Start the service")
|
||||
print(f" hermes gateway status # Check status")
|
||||
print(f" journalctl --user -u {SERVICE_NAME} -f # View logs")
|
||||
print(f" {'sudo ' if system else ''}hermes gateway start{scope_flag} # Start the service")
|
||||
print(f" {'sudo ' if system else ''}hermes gateway status{scope_flag} # Check status")
|
||||
print(f" {'journalctl' if system else 'journalctl --user'} -u {SERVICE_NAME} -f # View logs")
|
||||
print()
|
||||
print_systemd_linger_guidance()
|
||||
|
||||
def systemd_uninstall():
|
||||
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=False)
|
||||
subprocess.run(["systemctl", "--user", "disable", SERVICE_NAME], check=False)
|
||||
|
||||
unit_path = get_systemd_unit_path()
|
||||
if system:
|
||||
configured_user = _read_systemd_user_from_unit(unit_path)
|
||||
if configured_user:
|
||||
print(f"Configured to run as: {configured_user}")
|
||||
else:
|
||||
_ensure_linger_enabled()
|
||||
|
||||
print_systemd_scope_conflict_warning()
|
||||
|
||||
|
||||
def systemd_uninstall(system: bool = False):
|
||||
system = _select_systemd_scope(system)
|
||||
if system:
|
||||
_require_root_for_system_service("uninstall")
|
||||
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", SERVICE_NAME], check=False)
|
||||
subprocess.run(_systemctl_cmd(system) + ["disable", SERVICE_NAME], check=False)
|
||||
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
if unit_path.exists():
|
||||
unit_path.unlink()
|
||||
print(f"✓ Removed {unit_path}")
|
||||
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
print("✓ Service uninstalled")
|
||||
|
||||
def systemd_start():
|
||||
refresh_systemd_unit_if_needed()
|
||||
subprocess.run(["systemctl", "--user", "start", SERVICE_NAME], check=True)
|
||||
print("✓ Service started")
|
||||
subprocess.run(_systemctl_cmd(system) + ["daemon-reload"], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service uninstalled")
|
||||
|
||||
|
||||
def systemd_stop():
|
||||
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=True)
|
||||
print("✓ Service stopped")
|
||||
def systemd_start(system: bool = False):
|
||||
system = _select_systemd_scope(system)
|
||||
if system:
|
||||
_require_root_for_system_service("start")
|
||||
refresh_systemd_unit_if_needed(system=system)
|
||||
subprocess.run(_systemctl_cmd(system) + ["start", SERVICE_NAME], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service started")
|
||||
|
||||
|
||||
def systemd_restart():
|
||||
refresh_systemd_unit_if_needed()
|
||||
subprocess.run(["systemctl", "--user", "restart", SERVICE_NAME], check=True)
|
||||
print("✓ Service restarted")
|
||||
|
||||
def systemd_stop(system: bool = False):
|
||||
system = _select_systemd_scope(system)
|
||||
if system:
|
||||
_require_root_for_system_service("stop")
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", SERVICE_NAME], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service stopped")
|
||||
|
||||
|
||||
def systemd_status(deep: bool = False):
|
||||
# Check if service unit file exists
|
||||
unit_path = get_systemd_unit_path()
|
||||
|
||||
def systemd_restart(system: bool = False):
|
||||
system = _select_systemd_scope(system)
|
||||
if system:
|
||||
_require_root_for_system_service("restart")
|
||||
refresh_systemd_unit_if_needed(system=system)
|
||||
subprocess.run(_systemctl_cmd(system) + ["restart", SERVICE_NAME], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service restarted")
|
||||
|
||||
|
||||
|
||||
def systemd_status(deep: bool = False, system: bool = False):
|
||||
system = _select_systemd_scope(system)
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
scope_flag = " --system" if system else ""
|
||||
|
||||
if not unit_path.exists():
|
||||
print("✗ Gateway service is not installed")
|
||||
print(" Run: hermes gateway install")
|
||||
print(f" Run: {'sudo ' if system else ''}hermes gateway install{scope_flag}")
|
||||
return
|
||||
|
||||
if not systemd_unit_is_current():
|
||||
print("⚠ Installed gateway service definition is outdated")
|
||||
print(" Run: hermes gateway restart # auto-refreshes the unit")
|
||||
if has_conflicting_systemd_units():
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
# Show detailed status first
|
||||
|
||||
if not systemd_unit_is_current(system=system):
|
||||
print("⚠ Installed gateway service definition is outdated")
|
||||
print(f" Run: {'sudo ' if system else ''}hermes gateway restart{scope_flag} # auto-refreshes the unit")
|
||||
print()
|
||||
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"],
|
||||
capture_output=False
|
||||
_systemctl_cmd(system) + ["status", SERVICE_NAME, "--no-pager"],
|
||||
capture_output=False,
|
||||
)
|
||||
|
||||
# Check if service is active
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", SERVICE_NAME],
|
||||
_systemctl_cmd(system) + ["is-active", SERVICE_NAME],
|
||||
capture_output=True,
|
||||
text=True
|
||||
text=True,
|
||||
)
|
||||
|
||||
status = result.stdout.strip()
|
||||
|
||||
if status == "active":
|
||||
print("✓ Gateway service is running")
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} gateway service is running")
|
||||
else:
|
||||
print("✗ Gateway service is stopped")
|
||||
print(" Run: hermes gateway start")
|
||||
print(f"✗ {_service_scope_label(system).capitalize()} gateway service is stopped")
|
||||
print(f" Run: {'sudo ' if system else ''}hermes gateway start{scope_flag}")
|
||||
|
||||
if deep:
|
||||
configured_user = _read_systemd_user_from_unit(unit_path) if system else None
|
||||
if configured_user:
|
||||
print(f"Configured to run as: {configured_user}")
|
||||
|
||||
runtime_lines = _runtime_health_lines()
|
||||
if runtime_lines:
|
||||
print()
|
||||
print("Recent gateway health:")
|
||||
for line in runtime_lines:
|
||||
print(f" {line}")
|
||||
|
||||
if system:
|
||||
print("✓ System service starts at boot without requiring systemd linger")
|
||||
elif deep:
|
||||
print_systemd_linger_guidance()
|
||||
else:
|
||||
linger_enabled, _ = get_systemd_linger_status()
|
||||
@@ -380,10 +657,7 @@ def systemd_status(deep: bool = False):
|
||||
if deep:
|
||||
print()
|
||||
print("Recent logs:")
|
||||
subprocess.run([
|
||||
"journalctl", "--user", "-u", SERVICE_NAME,
|
||||
"-n", "20", "--no-pager"
|
||||
])
|
||||
subprocess.run(_journalctl_cmd(system) + ["-u", SERVICE_NAME, "-n", "20", "--no-pager"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -693,6 +967,35 @@ def _platform_status(platform: dict) -> str:
|
||||
return "not configured"
|
||||
|
||||
|
||||
def _runtime_health_lines() -> list[str]:
|
||||
"""Summarize the latest persisted gateway runtime health state."""
|
||||
try:
|
||||
from gateway.status import read_runtime_status
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
state = read_runtime_status()
|
||||
if not state:
|
||||
return []
|
||||
|
||||
lines: list[str] = []
|
||||
gateway_state = state.get("gateway_state")
|
||||
exit_reason = state.get("exit_reason")
|
||||
platforms = state.get("platforms", {}) or {}
|
||||
|
||||
for platform, pdata in platforms.items():
|
||||
if pdata.get("state") == "fatal":
|
||||
message = pdata.get("error_message") or "unknown error"
|
||||
lines.append(f"⚠ {platform}: {message}")
|
||||
|
||||
if gateway_state == "startup_failed" and exit_reason:
|
||||
lines.append(f"⚠ Last startup issue: {exit_reason}")
|
||||
elif gateway_state == "stopped" and exit_reason:
|
||||
lines.append(f"⚠ Last shutdown reason: {exit_reason}")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _setup_standard_platform(platform: dict):
|
||||
"""Interactive setup for Telegram, Discord, or Slack."""
|
||||
emoji = platform["emoji"]
|
||||
@@ -801,7 +1104,7 @@ def _setup_whatsapp():
|
||||
def _is_service_installed() -> bool:
|
||||
"""Check if the gateway is installed as a system service."""
|
||||
if is_linux():
|
||||
return get_systemd_unit_path().exists()
|
||||
return get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()
|
||||
elif is_macos():
|
||||
return get_launchd_plist_path().exists()
|
||||
return False
|
||||
@@ -809,12 +1112,27 @@ def _is_service_installed() -> bool:
|
||||
|
||||
def _is_service_running() -> bool:
|
||||
"""Check if the gateway service is currently running."""
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", SERVICE_NAME],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return result.stdout.strip() == "active"
|
||||
if is_linux():
|
||||
user_unit_exists = get_systemd_unit_path(system=False).exists()
|
||||
system_unit_exists = get_systemd_unit_path(system=True).exists()
|
||||
|
||||
if user_unit_exists:
|
||||
result = subprocess.run(
|
||||
_systemctl_cmd(False) + ["is-active", SERVICE_NAME],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.stdout.strip() == "active":
|
||||
return True
|
||||
|
||||
if system_unit_exists:
|
||||
result = subprocess.run(
|
||||
_systemctl_cmd(True) + ["is-active", SERVICE_NAME],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.stdout.strip() == "active":
|
||||
return True
|
||||
|
||||
return False
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
@@ -956,6 +1274,10 @@ def gateway_setup():
|
||||
service_installed = _is_service_installed()
|
||||
service_running = _is_service_running()
|
||||
|
||||
if is_linux() and has_conflicting_systemd_units():
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if service_installed and service_running:
|
||||
print_success("Gateway service is installed and running.")
|
||||
elif service_installed:
|
||||
@@ -1037,16 +1359,18 @@ def gateway_setup():
|
||||
platform_name = "systemd" if is_linux() else "launchd"
|
||||
if prompt_yes_no(f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True):
|
||||
try:
|
||||
force = False
|
||||
installed_scope = None
|
||||
did_install = False
|
||||
if is_linux():
|
||||
systemd_install(force)
|
||||
installed_scope, did_install = install_linux_gateway_from_setup(force=False)
|
||||
else:
|
||||
launchd_install(force)
|
||||
launchd_install(force=False)
|
||||
did_install = True
|
||||
print()
|
||||
if prompt_yes_no(" Start the service now?", True):
|
||||
if did_install and prompt_yes_no(" Start the service now?", True):
|
||||
try:
|
||||
if is_linux():
|
||||
systemd_start()
|
||||
systemd_start(system=installed_scope == "system")
|
||||
else:
|
||||
launchd_start()
|
||||
except subprocess.CalledProcessError as e:
|
||||
@@ -1056,6 +1380,8 @@ def gateway_setup():
|
||||
print_info(" You can try manually: hermes gateway install")
|
||||
else:
|
||||
print_info(" You can install later: hermes gateway install")
|
||||
if is_linux():
|
||||
print_info(" Or as a boot-time service: sudo hermes gateway install --system")
|
||||
print_info(" Or run in foreground: hermes gateway")
|
||||
else:
|
||||
print_info(" Service install not supported on this platform.")
|
||||
@@ -1089,8 +1415,10 @@ def gateway_command(args):
|
||||
# Service management commands
|
||||
if subcmd == "install":
|
||||
force = getattr(args, 'force', False)
|
||||
system = getattr(args, 'system', False)
|
||||
run_as_user = getattr(args, 'run_as_user', None)
|
||||
if is_linux():
|
||||
systemd_install(force)
|
||||
systemd_install(force=force, system=system, run_as_user=run_as_user)
|
||||
elif is_macos():
|
||||
launchd_install(force)
|
||||
else:
|
||||
@@ -1099,8 +1427,9 @@ def gateway_command(args):
|
||||
sys.exit(1)
|
||||
|
||||
elif subcmd == "uninstall":
|
||||
system = getattr(args, 'system', False)
|
||||
if is_linux():
|
||||
systemd_uninstall()
|
||||
systemd_uninstall(system=system)
|
||||
elif is_macos():
|
||||
launchd_uninstall()
|
||||
else:
|
||||
@@ -1108,8 +1437,9 @@ def gateway_command(args):
|
||||
sys.exit(1)
|
||||
|
||||
elif subcmd == "start":
|
||||
system = getattr(args, 'system', False)
|
||||
if is_linux():
|
||||
systemd_start()
|
||||
systemd_start(system=system)
|
||||
elif is_macos():
|
||||
launchd_start()
|
||||
else:
|
||||
@@ -1119,10 +1449,11 @@ def gateway_command(args):
|
||||
elif subcmd == "stop":
|
||||
# Try service first, then sweep any stray/manual gateway processes.
|
||||
service_available = False
|
||||
system = getattr(args, 'system', False)
|
||||
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
if is_linux() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
try:
|
||||
systemd_stop()
|
||||
systemd_stop(system=system)
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass # Fall through to process kill
|
||||
@@ -1145,10 +1476,11 @@ def gateway_command(args):
|
||||
elif subcmd == "restart":
|
||||
# Try service first, fall back to killing and restarting
|
||||
service_available = False
|
||||
system = getattr(args, 'system', False)
|
||||
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
if is_linux() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
try:
|
||||
systemd_restart()
|
||||
systemd_restart(system=system)
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
@@ -1174,10 +1506,11 @@ def gateway_command(args):
|
||||
|
||||
elif subcmd == "status":
|
||||
deep = getattr(args, 'deep', False)
|
||||
system = getattr(args, 'system', False)
|
||||
|
||||
# Check for service first
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
systemd_status(deep)
|
||||
if is_linux() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
systemd_status(deep, system=system)
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
launchd_status(deep)
|
||||
else:
|
||||
@@ -1186,12 +1519,26 @@ def gateway_command(args):
|
||||
if pids:
|
||||
print(f"✓ Gateway is running (PID: {', '.join(map(str, pids))})")
|
||||
print(" (Running manually, not as a system service)")
|
||||
runtime_lines = _runtime_health_lines()
|
||||
if runtime_lines:
|
||||
print()
|
||||
print("Recent gateway health:")
|
||||
for line in runtime_lines:
|
||||
print(f" {line}")
|
||||
print()
|
||||
print("To install as a service:")
|
||||
print(" hermes gateway install")
|
||||
print(" sudo hermes gateway install --system")
|
||||
else:
|
||||
print("✗ Gateway is not running")
|
||||
runtime_lines = _runtime_health_lines()
|
||||
if runtime_lines:
|
||||
print()
|
||||
print("Recent gateway health:")
|
||||
for line in runtime_lines:
|
||||
print(f" {line}")
|
||||
print()
|
||||
print("To start:")
|
||||
print(" hermes gateway # Run in foreground")
|
||||
print(" hermes gateway install # Install as service")
|
||||
print(" hermes gateway install # Install as user service")
|
||||
print(" sudo hermes gateway install --system # Install as boot-time system service")
|
||||
|
||||
+198
-28
@@ -54,16 +54,11 @@ from typing import Optional
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback
|
||||
from dotenv import load_dotenv
|
||||
from hermes_cli.config import get_env_path, get_hermes_home
|
||||
_user_env = get_env_path()
|
||||
if _user_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="latin-1")
|
||||
load_dotenv(dotenv_path=PROJECT_ROOT / '.env', override=False)
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
from hermes_cli.config import get_hermes_home
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
load_hermes_dotenv(project_env=PROJECT_ROOT / '.env')
|
||||
|
||||
# Point mini-swe-agent at ~/.hermes/ so it shares our config
|
||||
os.environ.setdefault("MSWEA_GLOBAL_CONFIG_DIR", str(get_hermes_home()))
|
||||
@@ -480,6 +475,13 @@ def cmd_chat(args):
|
||||
print("You can run 'hermes setup' at any time to configure.")
|
||||
sys.exit(1)
|
||||
|
||||
# Start update check in background (runs while other init happens)
|
||||
try:
|
||||
from hermes_cli.banner import prefetch_update_check
|
||||
prefetch_update_check()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Sync bundled skills on every CLI launch (fast -- skips unchanged skills)
|
||||
try:
|
||||
from tools.skills_sync import sync_skills
|
||||
@@ -499,6 +501,7 @@ def cmd_chat(args):
|
||||
"model": args.model,
|
||||
"provider": getattr(args, "provider", None),
|
||||
"toolsets": args.toolsets,
|
||||
"skills": getattr(args, "skills", None),
|
||||
"verbose": args.verbose,
|
||||
"quiet": getattr(args, "quiet", False),
|
||||
"query": args.query,
|
||||
@@ -510,7 +513,11 @@ def cmd_chat(args):
|
||||
# Filter out None values
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
cli_main(**kwargs)
|
||||
try:
|
||||
cli_main(**kwargs)
|
||||
except ValueError as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_gateway(args):
|
||||
@@ -1368,6 +1375,12 @@ _PROVIDER_MODELS = {
|
||||
"kimi-k2-turbo-preview",
|
||||
"kimi-k2-0905-preview",
|
||||
],
|
||||
"moonshot": [
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-turbo-preview",
|
||||
"kimi-k2-0905-preview",
|
||||
],
|
||||
"minimax": [
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
@@ -1449,8 +1462,8 @@ def _model_flow_kimi(config, current_model=""):
|
||||
"kimi-k2-thinking-turbo",
|
||||
]
|
||||
else:
|
||||
# Legacy Moonshot models
|
||||
model_list = _PROVIDER_MODELS.get(provider_id, [])
|
||||
# Legacy Moonshot models (excludes Coding Plan-only models)
|
||||
model_list = _PROVIDER_MODELS.get("moonshot", [])
|
||||
|
||||
if model_list:
|
||||
selected = _prompt_model_selection(model_list, current_model=current_model)
|
||||
@@ -1586,8 +1599,30 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""):
|
||||
|
||||
def _run_anthropic_oauth_flow(save_env_value):
|
||||
"""Run the Claude OAuth setup-token flow. Returns True if credentials were saved."""
|
||||
from agent.anthropic_adapter import run_oauth_setup_token
|
||||
from hermes_cli.config import save_anthropic_oauth_token
|
||||
from agent.anthropic_adapter import (
|
||||
run_oauth_setup_token,
|
||||
read_claude_code_credentials,
|
||||
is_claude_code_token_valid,
|
||||
)
|
||||
from hermes_cli.config import (
|
||||
save_anthropic_oauth_token,
|
||||
use_anthropic_claude_code_credentials,
|
||||
)
|
||||
|
||||
def _activate_claude_code_credentials_if_available() -> bool:
|
||||
try:
|
||||
creds = read_claude_code_credentials()
|
||||
except Exception:
|
||||
creds = None
|
||||
if creds and (
|
||||
is_claude_code_token_valid(creds)
|
||||
or bool(creds.get("refreshToken"))
|
||||
):
|
||||
use_anthropic_claude_code_credentials(save_fn=save_env_value)
|
||||
print(" ✓ Claude Code credentials linked.")
|
||||
print(" Hermes will use Claude's credential store directly instead of copying a setup-token into ~/.hermes/.env.")
|
||||
return True
|
||||
return False
|
||||
|
||||
try:
|
||||
print()
|
||||
@@ -1596,6 +1631,8 @@ def _run_anthropic_oauth_flow(save_env_value):
|
||||
print()
|
||||
token = run_oauth_setup_token()
|
||||
if token:
|
||||
if _activate_claude_code_credentials_if_available():
|
||||
return True
|
||||
save_anthropic_oauth_token(token, save_fn=save_env_value)
|
||||
print(" ✓ OAuth credentials saved.")
|
||||
return True
|
||||
@@ -1828,6 +1865,18 @@ def cmd_version(args):
|
||||
except ImportError:
|
||||
print("OpenAI SDK: Not installed")
|
||||
|
||||
# Show update status (synchronous — acceptable since user asked for version info)
|
||||
try:
|
||||
from hermes_cli.banner import check_for_updates
|
||||
behind = check_for_updates()
|
||||
if behind and behind > 0:
|
||||
commits_word = "commit" if behind == 1 else "commits"
|
||||
print(f"Update available: {behind} {commits_word} behind — run 'hermes update'")
|
||||
elif behind == 0:
|
||||
print("Up to date")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def cmd_uninstall(args):
|
||||
"""Uninstall Hermes Agent."""
|
||||
@@ -1962,6 +2011,32 @@ def _stash_local_changes_if_needed(git_cmd: list[str], cwd: Path) -> Optional[st
|
||||
|
||||
|
||||
|
||||
def _resolve_stash_selector(git_cmd: list[str], cwd: Path, stash_ref: str) -> Optional[str]:
|
||||
stash_list = subprocess.run(
|
||||
git_cmd + ["stash", "list", "--format=%gd %H"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
for line in stash_list.stdout.splitlines():
|
||||
selector, _, commit = line.partition(" ")
|
||||
if commit.strip() == stash_ref:
|
||||
return selector.strip()
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def _print_stash_cleanup_guidance(stash_ref: str, stash_selector: Optional[str] = None) -> None:
|
||||
print(" Check `git status` first so you don't accidentally reapply the same change twice.")
|
||||
print(" Find the saved entry with: git stash list --format='%gd %H %s'")
|
||||
if stash_selector:
|
||||
print(f" Remove it with: git stash drop {stash_selector}")
|
||||
else:
|
||||
print(f" Look for commit {stash_ref}, then drop its selector with: git stash drop stash@{{N}}")
|
||||
|
||||
|
||||
|
||||
def _restore_stashed_changes(
|
||||
git_cmd: list[str],
|
||||
cwd: Path,
|
||||
@@ -1998,7 +2073,27 @@ def _restore_stashed_changes(
|
||||
print(f"Resolve manually with: git stash apply {stash_ref}")
|
||||
sys.exit(1)
|
||||
|
||||
subprocess.run(git_cmd + ["stash", "drop", stash_ref], cwd=cwd, check=True)
|
||||
stash_selector = _resolve_stash_selector(git_cmd, cwd, stash_ref)
|
||||
if stash_selector is None:
|
||||
print("⚠ Local changes were restored, but Hermes couldn't find the stash entry to drop.")
|
||||
print(" The stash was left in place. You can remove it manually after checking the result.")
|
||||
_print_stash_cleanup_guidance(stash_ref)
|
||||
else:
|
||||
drop = subprocess.run(
|
||||
git_cmd + ["stash", "drop", stash_selector],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if drop.returncode != 0:
|
||||
print("⚠ Local changes were restored, but Hermes couldn't drop the saved stash entry.")
|
||||
if drop.stdout.strip():
|
||||
print(drop.stdout.strip())
|
||||
if drop.stderr.strip():
|
||||
print(drop.stderr.strip())
|
||||
print(" The stash was left in place. You can remove it manually after checking the result.")
|
||||
_print_stash_cleanup_guidance(stash_ref, stash_selector)
|
||||
|
||||
print("⚠ Local changes were restored on top of the updated codebase.")
|
||||
print(" Review `git diff` / `git status` if Hermes behaves unexpectedly.")
|
||||
return True
|
||||
@@ -2056,7 +2151,15 @@ def cmd_update(args):
|
||||
check=True
|
||||
)
|
||||
branch = result.stdout.strip()
|
||||
|
||||
|
||||
# Fall back to main if the current branch doesn't exist on the remote
|
||||
verify = subprocess.run(
|
||||
git_cmd + ["rev-parse", "--verify", f"origin/{branch}"],
|
||||
cwd=PROJECT_ROOT, capture_output=True, text=True,
|
||||
)
|
||||
if verify.returncode != 0:
|
||||
branch = "main"
|
||||
|
||||
# Check if there are updates
|
||||
result = subprocess.run(
|
||||
git_cmd + ["rev-list", f"HEAD..origin/{branch}", "--count"],
|
||||
@@ -2268,8 +2371,9 @@ Examples:
|
||||
hermes config edit Edit config in $EDITOR
|
||||
hermes config set model gpt-4 Set a config value
|
||||
hermes gateway Run messaging gateway
|
||||
hermes -s hermes-agent-dev,github-auth
|
||||
hermes -w Start in isolated git worktree
|
||||
hermes gateway install Install as system service
|
||||
hermes gateway install Install gateway background service
|
||||
hermes sessions list List past sessions
|
||||
hermes sessions browse Interactive session picker
|
||||
hermes sessions rename ID T Rename/title a session
|
||||
@@ -2306,6 +2410,12 @@ For more help on a command:
|
||||
default=False,
|
||||
help="Run in an isolated git worktree (for parallel agents)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skills", "-s",
|
||||
action="append",
|
||||
default=None,
|
||||
help="Preload one or more skills for the session (repeat flag or comma-separate)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--yolo",
|
||||
action="store_true",
|
||||
@@ -2341,6 +2451,12 @@ For more help on a command:
|
||||
"-t", "--toolsets",
|
||||
help="Comma-separated toolsets to enable"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"-s", "--skills",
|
||||
action="append",
|
||||
default=None,
|
||||
help="Preload one or more skills for the session (repeat flag or comma-separate)"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"--provider",
|
||||
choices=["auto", "openrouter", "nous", "openai-codex", "anthropic", "zai", "kimi-coding", "minimax", "minimax-cn"],
|
||||
@@ -2425,23 +2541,30 @@ For more help on a command:
|
||||
|
||||
# gateway start
|
||||
gateway_start = gateway_subparsers.add_parser("start", help="Start gateway service")
|
||||
gateway_start.add_argument("--system", action="store_true", help="Target the Linux system-level gateway service")
|
||||
|
||||
# gateway stop
|
||||
gateway_stop = gateway_subparsers.add_parser("stop", help="Stop gateway service")
|
||||
gateway_stop.add_argument("--system", action="store_true", help="Target the Linux system-level gateway service")
|
||||
|
||||
# gateway restart
|
||||
gateway_restart = gateway_subparsers.add_parser("restart", help="Restart gateway service")
|
||||
gateway_restart.add_argument("--system", action="store_true", help="Target the Linux system-level gateway service")
|
||||
|
||||
# gateway status
|
||||
gateway_status = gateway_subparsers.add_parser("status", help="Show gateway status")
|
||||
gateway_status.add_argument("--deep", action="store_true", help="Deep status check")
|
||||
gateway_status.add_argument("--system", action="store_true", help="Target the Linux system-level gateway service")
|
||||
|
||||
# gateway install
|
||||
gateway_install = gateway_subparsers.add_parser("install", help="Install gateway as service")
|
||||
gateway_install.add_argument("--force", action="store_true", help="Force reinstall")
|
||||
gateway_install.add_argument("--system", action="store_true", help="Install as a Linux system-level service (starts at boot)")
|
||||
gateway_install.add_argument("--run-as-user", dest="run_as_user", help="User account the Linux system service should run as")
|
||||
|
||||
# gateway uninstall
|
||||
gateway_uninstall = gateway_subparsers.add_parser("uninstall", help="Uninstall gateway service")
|
||||
gateway_uninstall.add_argument("--system", action="store_true", help="Target the Linux system-level gateway service")
|
||||
|
||||
# gateway setup
|
||||
gateway_setup = gateway_subparsers.add_parser("setup", help="Configure messaging platforms")
|
||||
@@ -2589,13 +2712,48 @@ For more help on a command:
|
||||
# cron list
|
||||
cron_list = cron_subparsers.add_parser("list", help="List scheduled jobs")
|
||||
cron_list.add_argument("--all", action="store_true", help="Include disabled jobs")
|
||||
|
||||
|
||||
# cron create/add
|
||||
cron_create = cron_subparsers.add_parser("create", aliases=["add"], help="Create a scheduled job")
|
||||
cron_create.add_argument("schedule", help="Schedule like '30m', 'every 2h', or '0 9 * * *'")
|
||||
cron_create.add_argument("prompt", nargs="?", help="Optional self-contained prompt or task instruction")
|
||||
cron_create.add_argument("--name", help="Optional human-friendly job name")
|
||||
cron_create.add_argument("--deliver", help="Delivery target: origin, local, telegram, discord, signal, or platform:chat_id")
|
||||
cron_create.add_argument("--repeat", type=int, help="Optional repeat count")
|
||||
cron_create.add_argument("--skill", dest="skills", action="append", help="Attach a skill. Repeat to add multiple skills.")
|
||||
|
||||
# cron edit
|
||||
cron_edit = cron_subparsers.add_parser("edit", help="Edit an existing scheduled job")
|
||||
cron_edit.add_argument("job_id", help="Job ID to edit")
|
||||
cron_edit.add_argument("--schedule", help="New schedule")
|
||||
cron_edit.add_argument("--prompt", help="New prompt/task instruction")
|
||||
cron_edit.add_argument("--name", help="New job name")
|
||||
cron_edit.add_argument("--deliver", help="New delivery target")
|
||||
cron_edit.add_argument("--repeat", type=int, help="New repeat count")
|
||||
cron_edit.add_argument("--skill", dest="skills", action="append", help="Replace the job's skills with this set. Repeat to attach multiple skills.")
|
||||
cron_edit.add_argument("--add-skill", dest="add_skills", action="append", help="Append a skill without replacing the existing list. Repeatable.")
|
||||
cron_edit.add_argument("--remove-skill", dest="remove_skills", action="append", help="Remove a specific attached skill. Repeatable.")
|
||||
cron_edit.add_argument("--clear-skills", action="store_true", help="Remove all attached skills from the job")
|
||||
|
||||
# lifecycle actions
|
||||
cron_pause = cron_subparsers.add_parser("pause", help="Pause a scheduled job")
|
||||
cron_pause.add_argument("job_id", help="Job ID to pause")
|
||||
|
||||
cron_resume = cron_subparsers.add_parser("resume", help="Resume a paused job")
|
||||
cron_resume.add_argument("job_id", help="Job ID to resume")
|
||||
|
||||
cron_run = cron_subparsers.add_parser("run", help="Run a job on the next scheduler tick")
|
||||
cron_run.add_argument("job_id", help="Job ID to trigger")
|
||||
|
||||
cron_remove = cron_subparsers.add_parser("remove", aliases=["rm", "delete"], help="Remove a scheduled job")
|
||||
cron_remove.add_argument("job_id", help="Job ID to remove")
|
||||
|
||||
# cron status
|
||||
cron_subparsers.add_parser("status", help="Check if cron scheduler is running")
|
||||
|
||||
|
||||
# cron tick (mostly for debugging)
|
||||
cron_subparsers.add_parser("tick", help="Run due jobs once and exit")
|
||||
|
||||
|
||||
cron_parser.set_defaults(func=cmd_cron)
|
||||
|
||||
# =========================================================================
|
||||
@@ -2701,7 +2859,7 @@ For more help on a command:
|
||||
skills_install = skills_subparsers.add_parser("install", help="Install a skill")
|
||||
skills_install.add_argument("identifier", help="Skill identifier (e.g. openai/skills/skill-creator)")
|
||||
skills_install.add_argument("--category", default="", help="Category folder to install into")
|
||||
skills_install.add_argument("--force", action="store_true", help="Install despite caution verdict")
|
||||
skills_install.add_argument("--force", "--yes", "-y", dest="force", action="store_true", help="Install despite blocked scan verdict")
|
||||
|
||||
skills_inspect = skills_subparsers.add_parser("inspect", help="Preview a skill without installing")
|
||||
skills_inspect.add_argument("identifier", help="Skill identifier")
|
||||
@@ -2940,7 +3098,11 @@ For more help on a command:
|
||||
|
||||
elif action == "export":
|
||||
if args.session_id:
|
||||
data = db.export_session(args.session_id)
|
||||
resolved_session_id = db.resolve_session_id(args.session_id)
|
||||
if not resolved_session_id:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
return
|
||||
data = db.export_session(resolved_session_id)
|
||||
if not data:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
return
|
||||
@@ -2955,13 +3117,17 @@ For more help on a command:
|
||||
print(f"Exported {len(sessions)} sessions to {args.output}")
|
||||
|
||||
elif action == "delete":
|
||||
resolved_session_id = db.resolve_session_id(args.session_id)
|
||||
if not resolved_session_id:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
return
|
||||
if not args.yes:
|
||||
confirm = input(f"Delete session '{args.session_id}' and all its messages? [y/N] ")
|
||||
confirm = input(f"Delete session '{resolved_session_id}' and all its messages? [y/N] ")
|
||||
if confirm.lower() not in ("y", "yes"):
|
||||
print("Cancelled.")
|
||||
return
|
||||
if db.delete_session(args.session_id):
|
||||
print(f"Deleted session '{args.session_id}'.")
|
||||
if db.delete_session(resolved_session_id):
|
||||
print(f"Deleted session '{resolved_session_id}'.")
|
||||
else:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
|
||||
@@ -2977,10 +3143,14 @@ For more help on a command:
|
||||
print(f"Pruned {count} session(s).")
|
||||
|
||||
elif action == "rename":
|
||||
resolved_session_id = db.resolve_session_id(args.session_id)
|
||||
if not resolved_session_id:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
return
|
||||
title = " ".join(args.title)
|
||||
try:
|
||||
if db.set_session_title(args.session_id, title):
|
||||
print(f"Session '{args.session_id}' renamed to: {title}")
|
||||
if db.set_session_title(resolved_session_id, title):
|
||||
print(f"Session '{resolved_session_id}' renamed to: {title}")
|
||||
else:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
except ValueError as e:
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_cli import auth as auth_mod
|
||||
from hermes_cli.auth import (
|
||||
AuthError,
|
||||
PROVIDER_REGISTRY,
|
||||
@@ -18,6 +19,10 @@ from hermes_cli.config import load_config
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
|
||||
def _normalize_custom_provider_name(value: str) -> str:
|
||||
return value.strip().lower().replace(" ", "-")
|
||||
|
||||
|
||||
def _get_model_config() -> Dict[str, Any]:
|
||||
config = load_config()
|
||||
model_cfg = config.get("model")
|
||||
@@ -47,6 +52,82 @@ def resolve_requested_provider(requested: Optional[str] = None) -> str:
|
||||
return "auto"
|
||||
|
||||
|
||||
def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, Any]]:
|
||||
requested_norm = _normalize_custom_provider_name(requested_provider or "")
|
||||
if not requested_norm or requested_norm == "custom":
|
||||
return None
|
||||
|
||||
# Raw names should only map to custom providers when they are not already
|
||||
# valid built-in providers or aliases. Explicit menu keys like
|
||||
# ``custom:local`` always target the saved custom provider.
|
||||
if requested_norm == "auto":
|
||||
return None
|
||||
if not requested_norm.startswith("custom:"):
|
||||
try:
|
||||
auth_mod.resolve_provider(requested_norm)
|
||||
except AuthError:
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
|
||||
config = load_config()
|
||||
custom_providers = config.get("custom_providers")
|
||||
if not isinstance(custom_providers, list):
|
||||
return None
|
||||
|
||||
for entry in custom_providers:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
name = entry.get("name")
|
||||
base_url = entry.get("base_url")
|
||||
if not isinstance(name, str) or not isinstance(base_url, str):
|
||||
continue
|
||||
name_norm = _normalize_custom_provider_name(name)
|
||||
menu_key = f"custom:{name_norm}"
|
||||
if requested_norm not in {name_norm, menu_key}:
|
||||
continue
|
||||
return {
|
||||
"name": name.strip(),
|
||||
"base_url": base_url.strip(),
|
||||
"api_key": str(entry.get("api_key", "") or "").strip(),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_named_custom_runtime(
|
||||
*,
|
||||
requested_provider: str,
|
||||
explicit_api_key: Optional[str] = None,
|
||||
explicit_base_url: Optional[str] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
custom_provider = _get_named_custom_provider(requested_provider)
|
||||
if not custom_provider:
|
||||
return None
|
||||
|
||||
base_url = (
|
||||
(explicit_base_url or "").strip()
|
||||
or custom_provider.get("base_url", "")
|
||||
).rstrip("/")
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
api_key = (
|
||||
(explicit_api_key or "").strip()
|
||||
or custom_provider.get("api_key", "")
|
||||
or os.getenv("OPENAI_API_KEY", "").strip()
|
||||
or os.getenv("OPENROUTER_API_KEY", "").strip()
|
||||
)
|
||||
|
||||
return {
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": base_url,
|
||||
"api_key": api_key,
|
||||
"source": f"custom_provider:{custom_provider.get('name', requested_provider)}",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_openrouter_runtime(
|
||||
*,
|
||||
requested_provider: str,
|
||||
@@ -63,10 +144,16 @@ def _resolve_openrouter_runtime(
|
||||
env_openrouter_base_url = os.getenv("OPENROUTER_BASE_URL", "").strip()
|
||||
|
||||
use_config_base_url = False
|
||||
if requested_norm == "auto":
|
||||
if cfg_base_url.strip() and not explicit_base_url and not env_openai_base_url:
|
||||
if cfg_base_url.strip() and not explicit_base_url and not env_openai_base_url:
|
||||
if requested_norm == "auto":
|
||||
if not cfg_provider or cfg_provider == "auto":
|
||||
use_config_base_url = True
|
||||
elif requested_norm == "custom":
|
||||
# Persisted custom endpoints store their base URL in config.yaml.
|
||||
# If OPENAI_BASE_URL is not currently set in the environment, keep
|
||||
# honoring that saved endpoint instead of falling back to OpenRouter.
|
||||
if cfg_provider == "custom":
|
||||
use_config_base_url = True
|
||||
|
||||
# When the user explicitly requested the openrouter provider, skip
|
||||
# OPENAI_BASE_URL — it typically points to a custom / non-OpenRouter
|
||||
@@ -122,6 +209,15 @@ def resolve_runtime_provider(
|
||||
"""Resolve runtime provider credentials for agent execution."""
|
||||
requested_provider = resolve_requested_provider(requested)
|
||||
|
||||
custom_runtime = _resolve_named_custom_runtime(
|
||||
requested_provider=requested_provider,
|
||||
explicit_api_key=explicit_api_key,
|
||||
explicit_base_url=explicit_base_url,
|
||||
)
|
||||
if custom_runtime:
|
||||
custom_runtime["requested_provider"] = requested_provider
|
||||
return custom_runtime
|
||||
|
||||
provider = resolve_provider(
|
||||
requested_provider,
|
||||
explicit_api_key=explicit_api_key,
|
||||
|
||||
+121
-45
@@ -460,12 +460,23 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
|
||||
tool_status = []
|
||||
|
||||
# OpenRouter (required for vision, moa)
|
||||
if get_env_value("OPENROUTER_API_KEY"):
|
||||
# Vision — use the same runtime resolver as the actual vision tools
|
||||
try:
|
||||
from agent.auxiliary_client import get_available_vision_backends
|
||||
|
||||
_vision_backends = get_available_vision_backends()
|
||||
except Exception:
|
||||
_vision_backends = []
|
||||
|
||||
if _vision_backends:
|
||||
tool_status.append(("Vision (image analysis)", True, None))
|
||||
else:
|
||||
tool_status.append(("Vision (image analysis)", False, "run 'hermes setup' to configure"))
|
||||
|
||||
# Mixture of Agents — requires OpenRouter specifically (calls multiple models)
|
||||
if get_env_value("OPENROUTER_API_KEY"):
|
||||
tool_status.append(("Mixture of Agents", True, None))
|
||||
else:
|
||||
tool_status.append(("Vision (image analysis)", False, "OPENROUTER_API_KEY"))
|
||||
tool_status.append(("Mixture of Agents", False, "OPENROUTER_API_KEY"))
|
||||
|
||||
# Firecrawl (web tools)
|
||||
@@ -602,7 +613,7 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
print(
|
||||
f" {color('hermes config edit', Colors.GREEN)} Open config in your editor"
|
||||
)
|
||||
print(f" {color('hermes config set KEY VALUE', Colors.GREEN)}")
|
||||
print(f" {color('hermes config set <key> <value>', Colors.GREEN)}")
|
||||
print(f" Set a specific value")
|
||||
print()
|
||||
print(f" Or edit the files directly:")
|
||||
@@ -1246,35 +1257,84 @@ def setup_model_provider(config: dict):
|
||||
elif existing_or:
|
||||
selected_provider = "openrouter"
|
||||
|
||||
# ── OpenRouter API Key for tools (if not already set) ──
|
||||
# Tools (vision, web, MoA) use OpenRouter independently of the main provider.
|
||||
# Prompt for OpenRouter key if not set and a non-OpenRouter provider was chosen.
|
||||
if selected_provider in (
|
||||
"nous",
|
||||
"openai-codex",
|
||||
"custom",
|
||||
"zai",
|
||||
"kimi-coding",
|
||||
"minimax",
|
||||
"minimax-cn",
|
||||
"anthropic",
|
||||
) and not get_env_value("OPENROUTER_API_KEY"):
|
||||
print()
|
||||
print_header("OpenRouter API Key (for tools)")
|
||||
print_info("Tools like vision analysis, web search, and MoA use OpenRouter")
|
||||
print_info("independently of your main inference provider.")
|
||||
print_info("Get your API key at: https://openrouter.ai/keys")
|
||||
# ── Vision & Image Analysis Setup ──
|
||||
# Keep setup aligned with the actual runtime resolver the vision tools use.
|
||||
try:
|
||||
from agent.auxiliary_client import get_available_vision_backends
|
||||
|
||||
api_key = prompt(
|
||||
" OpenRouter API key (optional, press Enter to skip)", password=True
|
||||
)
|
||||
if api_key:
|
||||
save_env_value("OPENROUTER_API_KEY", api_key)
|
||||
print_success("OpenRouter API key saved (for tools)")
|
||||
_vision_backends = set(get_available_vision_backends())
|
||||
except Exception:
|
||||
_vision_backends = set()
|
||||
|
||||
_vision_needs_setup = not bool(_vision_backends)
|
||||
|
||||
if selected_provider in _vision_backends:
|
||||
# If the user just selected a backend Hermes can already use for
|
||||
# vision, treat it as covered. Auth/setup failure returns earlier.
|
||||
_vision_needs_setup = False
|
||||
|
||||
if _vision_needs_setup:
|
||||
_prov_names = {
|
||||
"nous-api": "Nous Portal API key",
|
||||
"zai": "Z.AI / GLM",
|
||||
"kimi-coding": "Kimi / Moonshot",
|
||||
"minimax": "MiniMax",
|
||||
"minimax-cn": "MiniMax CN",
|
||||
"anthropic": "Anthropic",
|
||||
"custom": "your custom endpoint",
|
||||
}
|
||||
_prov_display = _prov_names.get(selected_provider, selected_provider or "your provider")
|
||||
|
||||
print()
|
||||
print_header("Vision & Image Analysis (optional)")
|
||||
print_info(f"Vision uses a separate multimodal backend. {_prov_display}")
|
||||
print_info("doesn't currently provide one Hermes can auto-use for vision,")
|
||||
print_info("so choose a backend now or skip and configure later.")
|
||||
print()
|
||||
|
||||
_vision_choices = [
|
||||
"OpenRouter — uses Gemini (free tier at openrouter.ai/keys)",
|
||||
"OpenAI-compatible endpoint — base URL, API key, and vision model",
|
||||
"Skip for now",
|
||||
]
|
||||
_vision_idx = prompt_choice("Configure vision:", _vision_choices, 2)
|
||||
|
||||
if _vision_idx == 0: # OpenRouter
|
||||
_or_key = prompt(" OpenRouter API key", password=True).strip()
|
||||
if _or_key:
|
||||
save_env_value("OPENROUTER_API_KEY", _or_key)
|
||||
print_success("OpenRouter key saved — vision will use Gemini")
|
||||
else:
|
||||
print_info("Skipped — vision won't be available")
|
||||
elif _vision_idx == 1: # OpenAI-compatible endpoint
|
||||
_base_url = prompt(" Base URL (blank for OpenAI)").strip() or "https://api.openai.com/v1"
|
||||
_api_key_label = " API key"
|
||||
if "api.openai.com" in _base_url.lower():
|
||||
_api_key_label = " OpenAI API key"
|
||||
_oai_key = prompt(_api_key_label, password=True).strip()
|
||||
if _oai_key:
|
||||
save_env_value("OPENAI_API_KEY", _oai_key)
|
||||
save_env_value("OPENAI_BASE_URL", _base_url)
|
||||
if "api.openai.com" in _base_url.lower():
|
||||
_oai_vision_models = ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano"]
|
||||
_vm_choices = _oai_vision_models + ["Use default (gpt-4o-mini)"]
|
||||
_vm_idx = prompt_choice("Select vision model:", _vm_choices, 0)
|
||||
_selected_vision_model = (
|
||||
_oai_vision_models[_vm_idx]
|
||||
if _vm_idx < len(_oai_vision_models)
|
||||
else "gpt-4o-mini"
|
||||
)
|
||||
else:
|
||||
_selected_vision_model = prompt(" Vision model (blank = use main/custom default)").strip()
|
||||
save_env_value("AUXILIARY_VISION_MODEL", _selected_vision_model)
|
||||
print_success(
|
||||
f"Vision configured with {_base_url}"
|
||||
+ (f" ({_selected_vision_model})" if _selected_vision_model else "")
|
||||
)
|
||||
else:
|
||||
print_info("Skipped — vision won't be available")
|
||||
else:
|
||||
print_info(
|
||||
"Skipped - some tools (vision, web scraping) won't work without this"
|
||||
)
|
||||
print_info("Skipped — add later with 'hermes setup' or configure AUXILIARY_VISION_* settings")
|
||||
|
||||
# ── Model Selection (adapts based on provider) ──
|
||||
if selected_provider != "custom": # Custom already prompted for model name
|
||||
@@ -2080,20 +2140,22 @@ def setup_gateway(config: dict):
|
||||
print_info(" • Create an App-Level Token with 'connections:write' scope")
|
||||
print_info(" 3. Add Bot Token Scopes: Features → OAuth & Permissions")
|
||||
print_info(" Required scopes: chat:write, app_mentions:read,")
|
||||
print_info(" channels:history, channels:read, groups:history,")
|
||||
print_info(" im:history, im:read, im:write, users:read, files:write")
|
||||
print_info(" channels:history, channels:read, im:history,")
|
||||
print_info(" im:read, im:write, users:read, files:write")
|
||||
print_info(" Optional for private channels: groups:history")
|
||||
print_info(" 4. Subscribe to Events: Features → Event Subscriptions → Enable")
|
||||
print_info(" Required events: message.im, message.channels,")
|
||||
print_info(" message.groups, app_mention")
|
||||
print_warning(" ⚠ Without message.channels/message.groups events,")
|
||||
print_warning(" the bot will ONLY work in DMs, not channels!")
|
||||
print_info(" Required events: message.im, message.channels, app_mention")
|
||||
print_info(" Optional for private channels: message.groups")
|
||||
print_warning(" ⚠ Without message.channels the bot will ONLY work in DMs,")
|
||||
print_warning(" not public channels.")
|
||||
print_info(" 5. Install to Workspace: Settings → Install App")
|
||||
print_info(" 6. Reinstall the app after any scope or event changes")
|
||||
print_info(
|
||||
" 6. After installing, invite the bot to channels: /invite @YourBot"
|
||||
" 7. After installing, invite the bot to channels: /invite @YourBot"
|
||||
)
|
||||
print()
|
||||
print_info(
|
||||
" Full guide: https://hermes-agent.ai/docs/user-guide/messaging/slack"
|
||||
" Full guide: https://hermes-agent.nousresearch.com/docs/user-guide/messaging/slack/"
|
||||
)
|
||||
print()
|
||||
bot_token = prompt("Slack Bot Token (xoxb-...)", password=True)
|
||||
@@ -2111,14 +2173,17 @@ def setup_gateway(config: dict):
|
||||
)
|
||||
print()
|
||||
allowed_users = prompt(
|
||||
"Allowed user IDs (comma-separated, leave empty for open access)"
|
||||
"Allowed user IDs (comma-separated, leave empty to deny everyone except paired users)"
|
||||
)
|
||||
if allowed_users:
|
||||
save_env_value("SLACK_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("Slack allowlist configured")
|
||||
else:
|
||||
print_warning(
|
||||
"⚠️ No Slack allowlist set - unpaired users will be denied by default."
|
||||
)
|
||||
print_info(
|
||||
"⚠️ No allowlist set - anyone in your workspace can use the bot!"
|
||||
" Set SLACK_ALLOW_ALL_USERS=true or GATEWAY_ALLOW_ALL_USERS=true only if you intentionally want open workspace access."
|
||||
)
|
||||
|
||||
# ── WhatsApp ──
|
||||
@@ -2178,7 +2243,9 @@ def setup_gateway(config: dict):
|
||||
from hermes_cli.gateway import (
|
||||
_is_service_installed,
|
||||
_is_service_running,
|
||||
systemd_install,
|
||||
has_conflicting_systemd_units,
|
||||
install_linux_gateway_from_setup,
|
||||
print_systemd_scope_conflict_warning,
|
||||
systemd_start,
|
||||
systemd_restart,
|
||||
launchd_install,
|
||||
@@ -2190,6 +2257,10 @@ def setup_gateway(config: dict):
|
||||
service_running = _is_service_running()
|
||||
|
||||
print()
|
||||
if _is_linux and has_conflicting_systemd_units():
|
||||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if service_running:
|
||||
if prompt_yes_no(" Restart the gateway to pick up changes?", True):
|
||||
try:
|
||||
@@ -2215,15 +2286,18 @@ def setup_gateway(config: dict):
|
||||
True,
|
||||
):
|
||||
try:
|
||||
installed_scope = None
|
||||
did_install = False
|
||||
if _is_linux:
|
||||
systemd_install(force=False)
|
||||
installed_scope, did_install = install_linux_gateway_from_setup(force=False)
|
||||
else:
|
||||
launchd_install(force=False)
|
||||
did_install = True
|
||||
print()
|
||||
if prompt_yes_no(" Start the service now?", True):
|
||||
if did_install and prompt_yes_no(" Start the service now?", True):
|
||||
try:
|
||||
if _is_linux:
|
||||
systemd_start()
|
||||
systemd_start(system=installed_scope == "system")
|
||||
elif _is_macos:
|
||||
launchd_start()
|
||||
except Exception as e:
|
||||
@@ -2233,6 +2307,8 @@ def setup_gateway(config: dict):
|
||||
print_info(" You can try manually: hermes gateway install")
|
||||
else:
|
||||
print_info(" You can install later: hermes gateway install")
|
||||
if _is_linux:
|
||||
print_info(" Or as a boot-time service: sudo hermes gateway install --system")
|
||||
print_info(" Or run in foreground: hermes gateway")
|
||||
else:
|
||||
print_info("Start the gateway to bring your bots online:")
|
||||
|
||||
@@ -1050,11 +1050,11 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
|
||||
|
||||
elif action == "install":
|
||||
if not args:
|
||||
c.print("[bold red]Usage:[/] /skills install <identifier> [--category <cat>] [--force]\n")
|
||||
c.print("[bold red]Usage:[/] /skills install <identifier> [--category <cat>] [--force|--yes]\n")
|
||||
return
|
||||
identifier = args[0]
|
||||
category = ""
|
||||
force = "--force" in args
|
||||
force = any(flag in args for flag in ("--force", "--yes", "-y"))
|
||||
for i, a in enumerate(args):
|
||||
if a == "--category" and i + 1 < len(args):
|
||||
category = args[i + 1]
|
||||
|
||||
@@ -91,7 +91,7 @@ CONFIGURABLE_TOOLSETS = [
|
||||
("session_search", "🔎 Session Search", "search past conversations"),
|
||||
("clarify", "❓ Clarifying Questions", "clarify"),
|
||||
("delegation", "👥 Task Delegation", "delegate_task"),
|
||||
("cronjob", "⏰ Cron Jobs", "schedule, list, remove"),
|
||||
("cronjob", "⏰ Cron Jobs", "create/list/update/pause/resume/run, with optional attached skills"),
|
||||
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
|
||||
("homeassistant", "🏠 Home Assistant", "smart home device control"),
|
||||
]
|
||||
@@ -354,22 +354,49 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
|
||||
|
||||
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[str]):
|
||||
"""Save the selected toolset keys for a platform to config."""
|
||||
"""Save the selected toolset keys for a platform to config.
|
||||
|
||||
Preserves any non-configurable toolset entries (like MCP server names)
|
||||
that were already in the config for this platform.
|
||||
"""
|
||||
config.setdefault("platform_toolsets", {})
|
||||
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys)
|
||||
|
||||
# Get the set of all configurable toolset keys
|
||||
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
|
||||
# Get existing toolsets for this platform
|
||||
existing_toolsets = config.get("platform_toolsets", {}).get(platform, [])
|
||||
if not isinstance(existing_toolsets, list):
|
||||
existing_toolsets = []
|
||||
|
||||
# Preserve any entries that are NOT configurable toolsets (i.e. MCP server names)
|
||||
preserved_entries = {
|
||||
entry for entry in existing_toolsets
|
||||
if entry not in configurable_keys
|
||||
}
|
||||
|
||||
# Merge preserved entries with new enabled toolsets
|
||||
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys | preserved_entries)
|
||||
save_config(config)
|
||||
|
||||
|
||||
def _toolset_has_keys(ts_key: str) -> bool:
|
||||
"""Check if a toolset's required API keys are configured."""
|
||||
if ts_key == "vision":
|
||||
try:
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
_provider, client, _model = resolve_vision_provider_client()
|
||||
return client is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Check TOOL_CATEGORIES first (provider-aware)
|
||||
cat = TOOL_CATEGORIES.get(ts_key)
|
||||
if cat:
|
||||
for provider in cat["providers"]:
|
||||
for provider in cat.get("providers", []):
|
||||
env_vars = provider.get("env_vars", [])
|
||||
if not env_vars:
|
||||
return True # Free provider (e.g., Edge TTS)
|
||||
if all(get_env_value(v["key"]) for v in env_vars):
|
||||
if env_vars and all(get_env_value(e["key"]) for e in env_vars):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -628,6 +655,39 @@ def _configure_provider(provider: dict, config: dict):
|
||||
|
||||
def _configure_simple_requirements(ts_key: str):
|
||||
"""Simple fallback for toolsets that just need env vars (no provider selection)."""
|
||||
if ts_key == "vision":
|
||||
if _toolset_has_keys("vision"):
|
||||
return
|
||||
print()
|
||||
print(color(" Vision / Image Analysis requires a multimodal backend:", Colors.YELLOW))
|
||||
choices = [
|
||||
"OpenRouter — uses Gemini",
|
||||
"OpenAI-compatible endpoint — base URL, API key, and vision model",
|
||||
"Skip",
|
||||
]
|
||||
idx = _prompt_choice(" Configure vision backend", choices, 2)
|
||||
if idx == 0:
|
||||
_print_info(" Get key at: https://openrouter.ai/keys")
|
||||
value = _prompt(" OPENROUTER_API_KEY", password=True)
|
||||
if value and value.strip():
|
||||
save_env_value("OPENROUTER_API_KEY", value.strip())
|
||||
_print_success(" Saved")
|
||||
else:
|
||||
_print_warning(" Skipped")
|
||||
elif idx == 1:
|
||||
base_url = _prompt(" OPENAI_BASE_URL (blank for OpenAI)").strip() or "https://api.openai.com/v1"
|
||||
key_label = " OPENAI_API_KEY" if "api.openai.com" in base_url.lower() else " API key"
|
||||
api_key = _prompt(key_label, password=True)
|
||||
if api_key and api_key.strip():
|
||||
save_env_value("OPENAI_BASE_URL", base_url)
|
||||
save_env_value("OPENAI_API_KEY", api_key.strip())
|
||||
if "api.openai.com" in base_url.lower():
|
||||
save_env_value("AUXILIARY_VISION_MODEL", "gpt-4o-mini")
|
||||
_print_success(" Saved")
|
||||
else:
|
||||
_print_warning(" Skipped")
|
||||
return
|
||||
|
||||
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
@@ -249,6 +249,32 @@ class SessionDB:
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
|
||||
"""Resolve an exact or uniquely prefixed session ID to the full ID.
|
||||
|
||||
Returns the exact ID when it exists. Otherwise treats the input as a
|
||||
prefix and returns the single matching session ID if the prefix is
|
||||
unambiguous. Returns None for no matches or ambiguous prefixes.
|
||||
"""
|
||||
exact = self.get_session(session_id_or_prefix)
|
||||
if exact:
|
||||
return exact["id"]
|
||||
|
||||
escaped = (
|
||||
session_id_or_prefix
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\' ORDER BY started_at DESC LIMIT 2",
|
||||
(f"{escaped}%",),
|
||||
)
|
||||
matches = [row["id"] for row in cursor.fetchall()]
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
return None
|
||||
|
||||
# Maximum length for session titles
|
||||
MAX_TITLE_LENGTH = 100
|
||||
|
||||
|
||||
+1
-1
@@ -144,7 +144,7 @@ _LEGACY_TOOLSET_MAP = {
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
],
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"cronjob_tools": ["cronjob"],
|
||||
"rl_tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
|
||||
@@ -27,25 +27,16 @@ from pathlib import Path
|
||||
import fire
|
||||
import yaml
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
_user_env = _hermes_home / ".env"
|
||||
_project_env = Path(__file__).parent / '.env'
|
||||
|
||||
if _user_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="latin-1")
|
||||
print(f"✅ Loaded environment variables from {_user_env}")
|
||||
elif _project_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="latin-1")
|
||||
print(f"✅ Loaded environment variables from {_project_env}")
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
_loaded_env_paths = load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
for _env_path in _loaded_env_paths:
|
||||
print(f"✅ Loaded environment variables from {_env_path}")
|
||||
|
||||
# Set terminal working directory to tinker-atropos submodule
|
||||
# This ensures terminal commands run in the right context for RL work
|
||||
|
||||
+390
-80
@@ -21,6 +21,8 @@ Usage:
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import hashlib
|
||||
@@ -31,6 +33,7 @@ import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import threading
|
||||
import weakref
|
||||
@@ -42,24 +45,16 @@ import fire
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback
|
||||
from dotenv import load_dotenv
|
||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
||||
# User-managed env files should override stale shell exports on restart.
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
_user_env = _hermes_home / ".env"
|
||||
_project_env = Path(__file__).parent / '.env'
|
||||
if _user_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_user_env, encoding="latin-1")
|
||||
logger.info("Loaded environment variables from %s", _user_env)
|
||||
elif _project_env.exists():
|
||||
try:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(dotenv_path=_project_env, encoding="latin-1")
|
||||
logger.info("Loaded environment variables from %s", _project_env)
|
||||
_loaded_env_paths = load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
if _loaded_env_paths:
|
||||
for _env_path in _loaded_env_paths:
|
||||
logger.info("Loaded environment variables from %s", _env_path)
|
||||
else:
|
||||
logger.info("No .env file found. Using system environment variables.")
|
||||
|
||||
@@ -377,6 +372,7 @@ class AIAgent:
|
||||
# Interrupt mechanism for breaking out of tool loops
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None # Optional message that triggered interrupt
|
||||
self._client_lock = threading.RLock()
|
||||
|
||||
# Subagent delegation state
|
||||
self._delegate_depth = 0 # 0 = top-level agent, incremented for children
|
||||
@@ -503,6 +499,11 @@ class AIAgent:
|
||||
self._persist_user_message_idx = None
|
||||
self._persist_user_message_override = None
|
||||
|
||||
# Cache anthropic image-to-text fallbacks per image payload/URL so a
|
||||
# single tool loop does not repeatedly re-run auxiliary vision on the
|
||||
# same image history.
|
||||
self._anthropic_image_fallback_cache: Dict[str, str] = {}
|
||||
|
||||
# Initialize LLM client via centralized provider router.
|
||||
# The router handles auth resolution, base URL, headers, and
|
||||
# Codex/Anthropic wrapping for all known providers.
|
||||
@@ -566,7 +567,7 @@ class AIAgent:
|
||||
|
||||
self._client_kwargs = client_kwargs # stored for rebuilding after interrupt
|
||||
try:
|
||||
self.client = OpenAI(**client_kwargs)
|
||||
self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
|
||||
if not self.quiet_mode:
|
||||
print(f"🤖 AI Agent initialized with model: {self.model}")
|
||||
if base_url:
|
||||
@@ -2406,7 +2407,7 @@ class AIAgent:
|
||||
fn_name = getattr(item, "name", "") or ""
|
||||
arguments = getattr(item, "arguments", "{}")
|
||||
if not isinstance(arguments, str):
|
||||
arguments = str(arguments)
|
||||
arguments = json.dumps(arguments, ensure_ascii=False)
|
||||
raw_call_id = getattr(item, "call_id", None)
|
||||
raw_item_id = getattr(item, "id", None)
|
||||
embedded_call_id, _ = self._split_responses_tool_id(raw_item_id)
|
||||
@@ -2427,7 +2428,7 @@ class AIAgent:
|
||||
fn_name = getattr(item, "name", "") or ""
|
||||
arguments = getattr(item, "input", "{}")
|
||||
if not isinstance(arguments, str):
|
||||
arguments = str(arguments)
|
||||
arguments = json.dumps(arguments, ensure_ascii=False)
|
||||
raw_call_id = getattr(item, "call_id", None)
|
||||
raw_item_id = getattr(item, "id", None)
|
||||
embedded_call_id, _ = self._split_responses_tool_id(raw_item_id)
|
||||
@@ -2468,12 +2469,118 @@ class AIAgent:
|
||||
finish_reason = "stop"
|
||||
return assistant_message, finish_reason
|
||||
|
||||
def _run_codex_stream(self, api_kwargs: dict):
|
||||
def _thread_identity(self) -> str:
|
||||
thread = threading.current_thread()
|
||||
return f"{thread.name}:{thread.ident}"
|
||||
|
||||
def _client_log_context(self) -> str:
|
||||
provider = getattr(self, "provider", "unknown")
|
||||
base_url = getattr(self, "base_url", "unknown")
|
||||
model = getattr(self, "model", "unknown")
|
||||
return (
|
||||
f"thread={self._thread_identity()} provider={provider} "
|
||||
f"base_url={base_url} model={model}"
|
||||
)
|
||||
|
||||
def _openai_client_lock(self) -> threading.RLock:
|
||||
lock = getattr(self, "_client_lock", None)
|
||||
if lock is None:
|
||||
lock = threading.RLock()
|
||||
self._client_lock = lock
|
||||
return lock
|
||||
|
||||
@staticmethod
|
||||
def _is_openai_client_closed(client: Any) -> bool:
|
||||
from unittest.mock import Mock
|
||||
|
||||
if isinstance(client, Mock):
|
||||
return False
|
||||
http_client = getattr(client, "_client", None)
|
||||
return bool(getattr(http_client, "is_closed", False))
|
||||
|
||||
def _create_openai_client(self, client_kwargs: dict, *, reason: str, shared: bool) -> Any:
|
||||
client = OpenAI(**client_kwargs)
|
||||
logger.info(
|
||||
"OpenAI client created (%s, shared=%s) %s",
|
||||
reason,
|
||||
shared,
|
||||
self._client_log_context(),
|
||||
)
|
||||
return client
|
||||
|
||||
def _close_openai_client(self, client: Any, *, reason: str, shared: bool) -> None:
|
||||
if client is None:
|
||||
return
|
||||
try:
|
||||
client.close()
|
||||
logger.info(
|
||||
"OpenAI client closed (%s, shared=%s) %s",
|
||||
reason,
|
||||
shared,
|
||||
self._client_log_context(),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"OpenAI client close failed (%s, shared=%s) %s error=%s",
|
||||
reason,
|
||||
shared,
|
||||
self._client_log_context(),
|
||||
exc,
|
||||
)
|
||||
|
||||
def _replace_primary_openai_client(self, *, reason: str) -> bool:
|
||||
with self._openai_client_lock():
|
||||
old_client = getattr(self, "client", None)
|
||||
try:
|
||||
new_client = self._create_openai_client(self._client_kwargs, reason=reason, shared=True)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to rebuild shared OpenAI client (%s) %s error=%s",
|
||||
reason,
|
||||
self._client_log_context(),
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
self.client = new_client
|
||||
self._close_openai_client(old_client, reason=f"replace:{reason}", shared=True)
|
||||
return True
|
||||
|
||||
def _ensure_primary_openai_client(self, *, reason: str) -> Any:
|
||||
with self._openai_client_lock():
|
||||
client = getattr(self, "client", None)
|
||||
if client is not None and not self._is_openai_client_closed(client):
|
||||
return client
|
||||
|
||||
logger.warning(
|
||||
"Detected closed shared OpenAI client; recreating before use (%s) %s",
|
||||
reason,
|
||||
self._client_log_context(),
|
||||
)
|
||||
if not self._replace_primary_openai_client(reason=f"recreate_closed:{reason}"):
|
||||
raise RuntimeError("Failed to recreate closed OpenAI client")
|
||||
with self._openai_client_lock():
|
||||
return self.client
|
||||
|
||||
def _create_request_openai_client(self, *, reason: str) -> Any:
|
||||
from unittest.mock import Mock
|
||||
|
||||
primary_client = self._ensure_primary_openai_client(reason=reason)
|
||||
if isinstance(primary_client, Mock):
|
||||
return primary_client
|
||||
with self._openai_client_lock():
|
||||
request_kwargs = dict(self._client_kwargs)
|
||||
return self._create_openai_client(request_kwargs, reason=reason, shared=False)
|
||||
|
||||
def _close_request_openai_client(self, client: Any, *, reason: str) -> None:
|
||||
self._close_openai_client(client, reason=reason, shared=False)
|
||||
|
||||
def _run_codex_stream(self, api_kwargs: dict, client: Any = None):
|
||||
"""Execute one streaming Responses API request and return the final response."""
|
||||
active_client = client or self._ensure_primary_openai_client(reason="codex_stream_direct")
|
||||
max_stream_retries = 1
|
||||
for attempt in range(max_stream_retries + 1):
|
||||
try:
|
||||
with self.client.responses.stream(**api_kwargs) as stream:
|
||||
with active_client.responses.stream(**api_kwargs) as stream:
|
||||
for _ in stream:
|
||||
pass
|
||||
return stream.get_final_response()
|
||||
@@ -2482,24 +2589,27 @@ class AIAgent:
|
||||
missing_completed = "response.completed" in err_text
|
||||
if missing_completed and attempt < max_stream_retries:
|
||||
logger.debug(
|
||||
"Responses stream closed before completion (attempt %s/%s); retrying.",
|
||||
"Responses stream closed before completion (attempt %s/%s); retrying. %s",
|
||||
attempt + 1,
|
||||
max_stream_retries + 1,
|
||||
self._client_log_context(),
|
||||
)
|
||||
continue
|
||||
if missing_completed:
|
||||
logger.debug(
|
||||
"Responses stream did not emit response.completed; falling back to create(stream=True)."
|
||||
"Responses stream did not emit response.completed; falling back to create(stream=True). %s",
|
||||
self._client_log_context(),
|
||||
)
|
||||
return self._run_codex_create_stream_fallback(api_kwargs)
|
||||
return self._run_codex_create_stream_fallback(api_kwargs, client=active_client)
|
||||
raise
|
||||
|
||||
def _run_codex_create_stream_fallback(self, api_kwargs: dict):
|
||||
def _run_codex_create_stream_fallback(self, api_kwargs: dict, client: Any = None):
|
||||
"""Fallback path for stream completion edge cases on Codex-style Responses backends."""
|
||||
active_client = client or self._ensure_primary_openai_client(reason="codex_create_stream_fallback")
|
||||
fallback_kwargs = dict(api_kwargs)
|
||||
fallback_kwargs["stream"] = True
|
||||
fallback_kwargs = self._preflight_codex_api_kwargs(fallback_kwargs, allow_stream=True)
|
||||
stream_or_response = self.client.responses.create(**fallback_kwargs)
|
||||
stream_or_response = active_client.responses.create(**fallback_kwargs)
|
||||
|
||||
# Compatibility shim for mocks or providers that still return a concrete response.
|
||||
if hasattr(stream_or_response, "output"):
|
||||
@@ -2557,15 +2667,7 @@ class AIAgent:
|
||||
self._client_kwargs["api_key"] = self.api_key
|
||||
self._client_kwargs["base_url"] = self.base_url
|
||||
|
||||
try:
|
||||
self.client.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.client = OpenAI(**self._client_kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to rebuild OpenAI client after Codex refresh: %s", exc)
|
||||
if not self._replace_primary_openai_client(reason="codex_credential_refresh"):
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -2600,61 +2702,101 @@ class AIAgent:
|
||||
# Nous requests should not inherit OpenRouter-only attribution headers.
|
||||
self._client_kwargs.pop("default_headers", None)
|
||||
|
||||
if not self._replace_primary_openai_client(reason="nous_credential_refresh"):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _try_refresh_anthropic_client_credentials(self) -> bool:
|
||||
if self.api_mode != "anthropic_messages" or not hasattr(self, "_anthropic_api_key"):
|
||||
return False
|
||||
|
||||
try:
|
||||
self.client.close()
|
||||
from agent.anthropic_adapter import resolve_anthropic_token, build_anthropic_client
|
||||
|
||||
new_token = resolve_anthropic_token()
|
||||
except Exception as exc:
|
||||
logger.debug("Anthropic credential refresh failed: %s", exc)
|
||||
return False
|
||||
|
||||
if not isinstance(new_token, str) or not new_token.strip():
|
||||
return False
|
||||
new_token = new_token.strip()
|
||||
if new_token == self._anthropic_api_key:
|
||||
return False
|
||||
|
||||
try:
|
||||
self._anthropic_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.client = OpenAI(**self._client_kwargs)
|
||||
self._anthropic_client = build_anthropic_client(new_token, getattr(self, "_anthropic_base_url", None))
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to rebuild OpenAI client after Nous refresh: %s", exc)
|
||||
logger.warning("Failed to rebuild Anthropic client after credential refresh: %s", exc)
|
||||
return False
|
||||
|
||||
self._anthropic_api_key = new_token
|
||||
return True
|
||||
|
||||
def _anthropic_messages_create(self, api_kwargs: dict):
|
||||
if self.api_mode == "anthropic_messages":
|
||||
self._try_refresh_anthropic_client_credentials()
|
||||
return self._anthropic_client.messages.create(**api_kwargs)
|
||||
|
||||
def _interruptible_api_call(self, api_kwargs: dict):
|
||||
"""
|
||||
Run the API call in a background thread so the main conversation loop
|
||||
can detect interrupts without waiting for the full HTTP round-trip.
|
||||
|
||||
On interrupt, closes the HTTP client to cancel the in-flight request
|
||||
(stops token generation and avoids wasting money), then rebuilds the
|
||||
client for future calls.
|
||||
Each worker thread gets its own OpenAI client instance. Interrupts only
|
||||
close that worker-local client, so retries and other requests never
|
||||
inherit a closed transport.
|
||||
"""
|
||||
result = {"response": None, "error": None}
|
||||
request_client_holder = {"client": None}
|
||||
|
||||
def _call():
|
||||
try:
|
||||
if self.api_mode == "codex_responses":
|
||||
result["response"] = self._run_codex_stream(api_kwargs)
|
||||
request_client_holder["client"] = self._create_request_openai_client(reason="codex_stream_request")
|
||||
result["response"] = self._run_codex_stream(
|
||||
api_kwargs,
|
||||
client=request_client_holder["client"],
|
||||
)
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
result["response"] = self._anthropic_client.messages.create(**api_kwargs)
|
||||
result["response"] = self._anthropic_messages_create(api_kwargs)
|
||||
else:
|
||||
result["response"] = self.client.chat.completions.create(**api_kwargs)
|
||||
request_client_holder["client"] = self._create_request_openai_client(reason="chat_completion_request")
|
||||
result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs)
|
||||
except Exception as e:
|
||||
result["error"] = e
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
self._close_request_openai_client(request_client, reason="request_complete")
|
||||
|
||||
t = threading.Thread(target=_call, daemon=True)
|
||||
t.start()
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.3)
|
||||
if self._interrupt_requested:
|
||||
# Force-close the HTTP connection to stop token generation
|
||||
try:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
self._anthropic_client.close()
|
||||
else:
|
||||
self.client.close()
|
||||
except Exception:
|
||||
pass
|
||||
# Rebuild the client for future calls (cheap, no network)
|
||||
# Force-close the in-flight worker-local HTTP connection to stop
|
||||
# token generation without poisoning the shared client used to
|
||||
# seed future retries.
|
||||
try:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_client
|
||||
self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None))
|
||||
|
||||
self._anthropic_client.close()
|
||||
self._anthropic_client = build_anthropic_client(
|
||||
self._anthropic_api_key,
|
||||
getattr(self, "_anthropic_base_url", None),
|
||||
)
|
||||
else:
|
||||
self.client = OpenAI(**self._client_kwargs)
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
self._close_request_openai_client(request_client, reason="interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during API call")
|
||||
@@ -2673,11 +2815,15 @@ class AIAgent:
|
||||
core agent loop untouched for non-voice users.
|
||||
"""
|
||||
result = {"response": None, "error": None}
|
||||
request_client_holder = {"client": None}
|
||||
|
||||
def _call():
|
||||
try:
|
||||
stream_kwargs = {**api_kwargs, "stream": True}
|
||||
stream = self.client.chat.completions.create(**stream_kwargs)
|
||||
request_client_holder["client"] = self._create_request_openai_client(
|
||||
reason="chat_completion_stream_request"
|
||||
)
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts: list[str] = []
|
||||
tool_calls_acc: dict[int, dict] = {}
|
||||
@@ -2768,25 +2914,29 @@ class AIAgent:
|
||||
|
||||
except Exception as e:
|
||||
result["error"] = e
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
self._close_request_openai_client(request_client, reason="stream_request_complete")
|
||||
|
||||
t = threading.Thread(target=_call, daemon=True)
|
||||
t.start()
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.3)
|
||||
if self._interrupt_requested:
|
||||
try:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
self._anthropic_client.close()
|
||||
else:
|
||||
self.client.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_client
|
||||
self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None))
|
||||
|
||||
self._anthropic_client.close()
|
||||
self._anthropic_client = build_anthropic_client(
|
||||
self._anthropic_api_key,
|
||||
getattr(self, "_anthropic_base_url", None),
|
||||
)
|
||||
else:
|
||||
self.client = OpenAI(**self._client_kwargs)
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
self._close_request_openai_client(request_client, reason="stream_interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during API call")
|
||||
@@ -2884,13 +3034,156 @@ class AIAgent:
|
||||
|
||||
# ── End provider fallback ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _content_has_image_parts(content: Any) -> bool:
|
||||
if not isinstance(content, list):
|
||||
return False
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") in {"image_url", "input_image"}:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _materialize_data_url_for_vision(image_url: str) -> tuple[str, Optional[Path]]:
|
||||
header, _, data = str(image_url or "").partition(",")
|
||||
mime = "image/jpeg"
|
||||
if header.startswith("data:"):
|
||||
mime_part = header[len("data:"):].split(";", 1)[0].strip()
|
||||
if mime_part.startswith("image/"):
|
||||
mime = mime_part
|
||||
suffix = {
|
||||
"image/png": ".png",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/jpg": ".jpg",
|
||||
}.get(mime, ".jpg")
|
||||
tmp = tempfile.NamedTemporaryFile(prefix="anthropic_image_", suffix=suffix, delete=False)
|
||||
with tmp:
|
||||
tmp.write(base64.b64decode(data))
|
||||
path = Path(tmp.name)
|
||||
return str(path), path
|
||||
|
||||
def _describe_image_for_anthropic_fallback(self, image_url: str, role: str) -> str:
|
||||
cache_key = hashlib.sha256(str(image_url or "").encode("utf-8")).hexdigest()
|
||||
cached = self._anthropic_image_fallback_cache.get(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
role_label = {
|
||||
"assistant": "assistant",
|
||||
"tool": "tool result",
|
||||
}.get(role, "user")
|
||||
analysis_prompt = (
|
||||
"Describe everything visible in this image in thorough detail. "
|
||||
"Include any text, code, UI, data, objects, people, layout, colors, "
|
||||
"and any other notable visual information."
|
||||
)
|
||||
|
||||
vision_source = str(image_url or "")
|
||||
cleanup_path: Optional[Path] = None
|
||||
if vision_source.startswith("data:"):
|
||||
vision_source, cleanup_path = self._materialize_data_url_for_vision(vision_source)
|
||||
|
||||
description = ""
|
||||
try:
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
|
||||
result_json = asyncio.run(
|
||||
vision_analyze_tool(image_url=vision_source, user_prompt=analysis_prompt)
|
||||
)
|
||||
result = json.loads(result_json) if isinstance(result_json, str) else {}
|
||||
description = (result.get("analysis") or "").strip()
|
||||
except Exception as e:
|
||||
description = f"Image analysis failed: {e}"
|
||||
finally:
|
||||
if cleanup_path and cleanup_path.exists():
|
||||
try:
|
||||
cleanup_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if not description:
|
||||
description = "Image analysis failed."
|
||||
|
||||
note = f"[The {role_label} attached an image. Here's what it contains:\n{description}]"
|
||||
if vision_source and not str(image_url or "").startswith("data:"):
|
||||
note += (
|
||||
f"\n[If you need a closer look, use vision_analyze with image_url: {vision_source}]"
|
||||
)
|
||||
|
||||
self._anthropic_image_fallback_cache[cache_key] = note
|
||||
return note
|
||||
|
||||
def _preprocess_anthropic_content(self, content: Any, role: str) -> Any:
|
||||
if not self._content_has_image_parts(content):
|
||||
return content
|
||||
|
||||
text_parts: List[str] = []
|
||||
image_notes: List[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
if part.strip():
|
||||
text_parts.append(part.strip())
|
||||
continue
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
|
||||
ptype = part.get("type")
|
||||
if ptype in {"text", "input_text"}:
|
||||
text = str(part.get("text", "") or "").strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
continue
|
||||
|
||||
if ptype in {"image_url", "input_image"}:
|
||||
image_data = part.get("image_url", {})
|
||||
image_url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data or "")
|
||||
if image_url:
|
||||
image_notes.append(self._describe_image_for_anthropic_fallback(image_url, role))
|
||||
else:
|
||||
image_notes.append("[An image was attached but no image source was available.]")
|
||||
continue
|
||||
|
||||
text = str(part.get("text", "") or "").strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
|
||||
prefix = "\n\n".join(note for note in image_notes if note).strip()
|
||||
suffix = "\n".join(text for text in text_parts if text).strip()
|
||||
if prefix and suffix:
|
||||
return f"{prefix}\n\n{suffix}"
|
||||
if prefix:
|
||||
return prefix
|
||||
if suffix:
|
||||
return suffix
|
||||
return "[A multimodal message was converted to text for Anthropic compatibility.]"
|
||||
|
||||
def _prepare_anthropic_messages_for_api(self, api_messages: list) -> list:
|
||||
if not any(
|
||||
isinstance(msg, dict) and self._content_has_image_parts(msg.get("content"))
|
||||
for msg in api_messages
|
||||
):
|
||||
return api_messages
|
||||
|
||||
transformed = copy.deepcopy(api_messages)
|
||||
for msg in transformed:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
msg["content"] = self._preprocess_anthropic_content(
|
||||
msg.get("content"),
|
||||
str(msg.get("role", "user") or "user"),
|
||||
)
|
||||
return transformed
|
||||
|
||||
def _build_api_kwargs(self, api_messages: list) -> dict:
|
||||
"""Build the keyword arguments dict for the active API mode."""
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
anthropic_messages = self._prepare_anthropic_messages_for_api(api_messages)
|
||||
return build_anthropic_kwargs(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
messages=anthropic_messages,
|
||||
tools=self.tools,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_config=self.reasoning_config,
|
||||
@@ -3267,7 +3560,7 @@ class AIAgent:
|
||||
tools=[memory_tool_def], max_tokens=5120,
|
||||
reasoning_config=None,
|
||||
)
|
||||
response = self._anthropic_client.messages.create(**ant_kwargs)
|
||||
response = self._anthropic_messages_create(ant_kwargs)
|
||||
elif not _aux_available:
|
||||
api_kwargs = {
|
||||
"model": self.model,
|
||||
@@ -3276,7 +3569,7 @@ class AIAgent:
|
||||
"temperature": 0.3,
|
||||
**self._max_tokens_param(5120),
|
||||
}
|
||||
response = self.client.chat.completions.create(**api_kwargs, timeout=30.0)
|
||||
response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create(**api_kwargs, timeout=30.0)
|
||||
|
||||
# Extract tool calls from the response, handling all API formats
|
||||
tool_calls = []
|
||||
@@ -3804,7 +4097,7 @@ class AIAgent:
|
||||
'image_generate': '🎨', 'text_to_speech': '🔊',
|
||||
'vision_analyze': '👁️', 'mixture_of_agents': '🧠',
|
||||
'skills_list': '📚', 'skill_view': '📚',
|
||||
'schedule_cronjob': '⏰', 'list_cronjobs': '⏰', 'remove_cronjob': '⏰',
|
||||
'cronjob': '⏰',
|
||||
'send_message': '📨', 'todo': '📋', 'memory': '🧠', 'session_search': '🔍',
|
||||
'clarify': '❓', 'execute_code': '🐍', 'delegate_task': '🔀',
|
||||
}
|
||||
@@ -4018,11 +4311,11 @@ class AIAgent:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs as _bak, normalize_anthropic_response as _nar
|
||||
_ant_kw = _bak(model=self.model, messages=api_messages, tools=None,
|
||||
max_tokens=self.max_tokens, reasoning_config=self.reasoning_config)
|
||||
summary_response = self._anthropic_client.messages.create(**_ant_kw)
|
||||
summary_response = self._anthropic_messages_create(_ant_kw)
|
||||
_msg, _ = _nar(summary_response)
|
||||
final_response = (_msg.content or "").strip()
|
||||
else:
|
||||
summary_response = self.client.chat.completions.create(**summary_kwargs)
|
||||
summary_response = self._ensure_primary_openai_client(reason="iteration_limit_summary").chat.completions.create(**summary_kwargs)
|
||||
|
||||
if summary_response.choices and summary_response.choices[0].message.content:
|
||||
final_response = summary_response.choices[0].message.content
|
||||
@@ -4048,7 +4341,7 @@ class AIAgent:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs as _bak2, normalize_anthropic_response as _nar2
|
||||
_ant_kw2 = _bak2(model=self.model, messages=api_messages, tools=None,
|
||||
max_tokens=self.max_tokens, reasoning_config=self.reasoning_config)
|
||||
retry_response = self._anthropic_client.messages.create(**_ant_kw2)
|
||||
retry_response = self._anthropic_messages_create(_ant_kw2)
|
||||
_retry_msg, _ = _nar2(retry_response)
|
||||
final_response = (_retry_msg.content or "").strip()
|
||||
else:
|
||||
@@ -4061,7 +4354,7 @@ class AIAgent:
|
||||
if summary_extra_body:
|
||||
summary_kwargs["extra_body"] = summary_extra_body
|
||||
|
||||
summary_response = self.client.chat.completions.create(**summary_kwargs)
|
||||
summary_response = self._ensure_primary_openai_client(reason="iteration_limit_summary_retry").chat.completions.create(**summary_kwargs)
|
||||
|
||||
if summary_response.choices and summary_response.choices[0].message.content:
|
||||
final_response = summary_response.choices[0].message.content
|
||||
@@ -4822,12 +5115,8 @@ class AIAgent:
|
||||
and not anthropic_auth_retry_attempted
|
||||
):
|
||||
anthropic_auth_retry_attempted = True
|
||||
# Try re-reading Claude Code credentials (they may have been refreshed)
|
||||
from agent.anthropic_adapter import resolve_anthropic_token, build_anthropic_client, _is_oauth_token
|
||||
new_token = resolve_anthropic_token()
|
||||
if new_token and new_token != self._anthropic_api_key:
|
||||
self._anthropic_api_key = new_token
|
||||
self._anthropic_client = build_anthropic_client(new_token, getattr(self, "_anthropic_base_url", None))
|
||||
from agent.anthropic_adapter import _is_oauth_token
|
||||
if self._try_refresh_anthropic_client_credentials():
|
||||
print(f"{self.log_prefix}🔐 Anthropic credentials refreshed after 401. Retrying request...")
|
||||
continue
|
||||
# Credential refresh didn't help — show diagnostic info
|
||||
@@ -4850,7 +5139,15 @@ class AIAgent:
|
||||
# Enhanced error logging
|
||||
error_type = type(api_error).__name__
|
||||
error_msg = str(api_error).lower()
|
||||
|
||||
logger.warning(
|
||||
"API call failed (attempt %s/%s) error_type=%s %s error=%s",
|
||||
retry_count,
|
||||
max_retries,
|
||||
error_type,
|
||||
self._client_log_context(),
|
||||
api_error,
|
||||
)
|
||||
|
||||
self._vprint(f"{self.log_prefix}⚠️ API call failed (attempt {retry_count}/{max_retries}): {error_type}", force=True)
|
||||
self._vprint(f"{self.log_prefix} ⏱️ Time elapsed before failure: {elapsed_time:.2f}s")
|
||||
self._vprint(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}", force=True)
|
||||
@@ -5040,7 +5337,14 @@ class AIAgent:
|
||||
raise api_error
|
||||
|
||||
wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s
|
||||
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
|
||||
logger.warning(
|
||||
"Retrying API call in %ss (attempt %s/%s) %s error=%s",
|
||||
wait_time,
|
||||
retry_count,
|
||||
max_retries,
|
||||
self._client_log_context(),
|
||||
api_error,
|
||||
)
|
||||
if retry_count >= max_retries:
|
||||
self._vprint(f"{self.log_prefix}⚠️ API call failed after {retry_count} attempts: {str(api_error)[:100]}")
|
||||
self._vprint(f"{self.log_prefix}⏳ Final retry in {wait_time}s...")
|
||||
@@ -5278,6 +5582,12 @@ class AIAgent:
|
||||
invalid_json_args = []
|
||||
for tc in assistant_message.tool_calls:
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, (dict, list)):
|
||||
tc.function.arguments = json.dumps(args)
|
||||
continue
|
||||
if args is not None and not isinstance(args, str):
|
||||
tc.function.arguments = str(args)
|
||||
args = tc.function.arguments
|
||||
# Treat empty/whitespace strings as empty object
|
||||
if not args or not args.strip():
|
||||
tc.function.arguments = "{}"
|
||||
|
||||
Executable
+389
@@ -0,0 +1,389 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Discord Voice Doctor — diagnostic tool for voice channel support.
|
||||
|
||||
Checks all dependencies, configuration, and bot permissions needed
|
||||
for Discord voice mode to work correctly.
|
||||
|
||||
Usage:
|
||||
python scripts/discord-voice-doctor.py
|
||||
.venv/bin/python scripts/discord-voice-doctor.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# Resolve project root
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_ROOT = SCRIPT_DIR.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
ENV_FILE = HERMES_HOME / ".env"
|
||||
|
||||
OK = "\033[92m\u2713\033[0m"
|
||||
FAIL = "\033[91m\u2717\033[0m"
|
||||
WARN = "\033[93m!\033[0m"
|
||||
|
||||
# Track whether discord.py is available for later sections
|
||||
_discord_available = False
|
||||
|
||||
|
||||
def mask(value):
|
||||
"""Mask sensitive value: show only first 4 chars."""
|
||||
if not value or len(value) < 8:
|
||||
return "****"
|
||||
return f"{value[:4]}{'*' * (len(value) - 4)}"
|
||||
|
||||
|
||||
def check(label, ok, detail=""):
|
||||
symbol = OK if ok else FAIL
|
||||
msg = f" {symbol} {label}"
|
||||
if detail:
|
||||
msg += f" ({detail})"
|
||||
print(msg)
|
||||
return ok
|
||||
|
||||
|
||||
def warn(label, detail=""):
|
||||
msg = f" {WARN} {label}"
|
||||
if detail:
|
||||
msg += f" ({detail})"
|
||||
print(msg)
|
||||
|
||||
|
||||
def section(title):
|
||||
print(f"\n\033[1m{title}\033[0m")
|
||||
|
||||
|
||||
def check_packages():
|
||||
"""Check Python package dependencies. Returns True if all critical deps OK."""
|
||||
global _discord_available
|
||||
section("Python Packages")
|
||||
ok = True
|
||||
|
||||
# discord.py
|
||||
try:
|
||||
import discord
|
||||
_discord_available = True
|
||||
check("discord.py", True, f"v{discord.__version__}")
|
||||
except ImportError:
|
||||
check("discord.py", False, "pip install discord.py[voice]")
|
||||
ok = False
|
||||
|
||||
# PyNaCl
|
||||
try:
|
||||
import nacl
|
||||
ver = getattr(nacl, "__version__", "unknown")
|
||||
try:
|
||||
import nacl.secret
|
||||
nacl.secret.Aead(bytes(32))
|
||||
check("PyNaCl", True, f"v{ver}")
|
||||
except (AttributeError, Exception):
|
||||
check("PyNaCl (Aead)", False, f"v{ver} — need >=1.5.0")
|
||||
ok = False
|
||||
except ImportError:
|
||||
check("PyNaCl", False, "pip install PyNaCl>=1.5.0")
|
||||
ok = False
|
||||
|
||||
# davey (DAVE E2EE)
|
||||
try:
|
||||
import davey
|
||||
check("davey (DAVE E2EE)", True, f"v{getattr(davey, '__version__', '?')}")
|
||||
except ImportError:
|
||||
check("davey (DAVE E2EE)", False, "pip install davey")
|
||||
ok = False
|
||||
|
||||
# Optional: local STT
|
||||
try:
|
||||
import faster_whisper
|
||||
check("faster-whisper (local STT)", True)
|
||||
except ImportError:
|
||||
warn("faster-whisper (local STT)", "not installed — local STT unavailable")
|
||||
|
||||
# Optional: TTS providers
|
||||
try:
|
||||
import edge_tts
|
||||
check("edge-tts", True)
|
||||
except ImportError:
|
||||
warn("edge-tts", "not installed — edge TTS unavailable")
|
||||
|
||||
try:
|
||||
import elevenlabs
|
||||
check("elevenlabs SDK", True)
|
||||
except ImportError:
|
||||
warn("elevenlabs SDK", "not installed — premium TTS unavailable")
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def check_system_tools():
|
||||
"""Check system-level tools (opus, ffmpeg). Returns True if all OK."""
|
||||
section("System Tools")
|
||||
ok = True
|
||||
|
||||
# Opus codec
|
||||
if _discord_available:
|
||||
try:
|
||||
import discord
|
||||
opus_loaded = discord.opus.is_loaded()
|
||||
if not opus_loaded:
|
||||
import ctypes.util
|
||||
opus_path = ctypes.util.find_library("opus")
|
||||
if not opus_path:
|
||||
# Platform-specific fallback paths
|
||||
candidates = [
|
||||
"/opt/homebrew/lib/libopus.dylib", # macOS Apple Silicon
|
||||
"/usr/local/lib/libopus.dylib", # macOS Intel
|
||||
"/usr/lib/x86_64-linux-gnu/libopus.so.0", # Debian/Ubuntu x86
|
||||
"/usr/lib/aarch64-linux-gnu/libopus.so.0", # Debian/Ubuntu ARM
|
||||
"/usr/lib/libopus.so", # Arch Linux
|
||||
"/usr/lib64/libopus.so", # RHEL/Fedora
|
||||
]
|
||||
for p in candidates:
|
||||
if os.path.isfile(p):
|
||||
opus_path = p
|
||||
break
|
||||
if opus_path:
|
||||
discord.opus.load_opus(opus_path)
|
||||
opus_loaded = discord.opus.is_loaded()
|
||||
if opus_loaded:
|
||||
check("Opus codec", True)
|
||||
else:
|
||||
check("Opus codec", False, "brew install opus / apt install libopus0")
|
||||
ok = False
|
||||
except Exception as e:
|
||||
check("Opus codec", False, str(e))
|
||||
ok = False
|
||||
else:
|
||||
warn("Opus codec", "skipped — discord.py not installed")
|
||||
|
||||
# ffmpeg
|
||||
ffmpeg_path = shutil.which("ffmpeg")
|
||||
if ffmpeg_path:
|
||||
check("ffmpeg", True, ffmpeg_path)
|
||||
else:
|
||||
check("ffmpeg", False, "brew install ffmpeg / apt install ffmpeg")
|
||||
ok = False
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def check_env_vars():
|
||||
"""Check environment variables. Returns (ok, token, groq_key, eleven_key)."""
|
||||
section("Environment Variables")
|
||||
|
||||
# Load .env
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
if ENV_FILE.exists():
|
||||
load_dotenv(ENV_FILE)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
ok = True
|
||||
|
||||
token = os.getenv("DISCORD_BOT_TOKEN", "")
|
||||
if token:
|
||||
check("DISCORD_BOT_TOKEN", True, mask(token))
|
||||
else:
|
||||
check("DISCORD_BOT_TOKEN", False, "not set")
|
||||
ok = False
|
||||
|
||||
# Allowed users — resolve usernames if possible
|
||||
allowed = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed:
|
||||
users = [u.strip() for u in allowed.split(",") if u.strip()]
|
||||
user_labels = []
|
||||
for uid in users:
|
||||
label = mask(uid)
|
||||
if token and uid.isdigit():
|
||||
try:
|
||||
import requests
|
||||
r = requests.get(
|
||||
f"https://discord.com/api/v10/users/{uid}",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
timeout=3,
|
||||
)
|
||||
if r.status_code == 200:
|
||||
label = f"{r.json().get('username', '?')} ({mask(uid)})"
|
||||
except Exception:
|
||||
pass
|
||||
user_labels.append(label)
|
||||
check("DISCORD_ALLOWED_USERS", True, f"{len(users)} user(s): {', '.join(user_labels)}")
|
||||
else:
|
||||
warn("DISCORD_ALLOWED_USERS", "not set — all users can use voice")
|
||||
|
||||
groq_key = os.getenv("GROQ_API_KEY", "")
|
||||
eleven_key = os.getenv("ELEVENLABS_API_KEY", "")
|
||||
|
||||
if groq_key:
|
||||
check("GROQ_API_KEY (STT)", True, mask(groq_key))
|
||||
else:
|
||||
warn("GROQ_API_KEY", "not set — Groq STT unavailable")
|
||||
|
||||
if eleven_key:
|
||||
check("ELEVENLABS_API_KEY (TTS)", True, mask(eleven_key))
|
||||
else:
|
||||
warn("ELEVENLABS_API_KEY", "not set — ElevenLabs TTS unavailable")
|
||||
|
||||
return ok, token, groq_key, eleven_key
|
||||
|
||||
|
||||
def check_config(groq_key, eleven_key):
|
||||
"""Check hermes config.yaml."""
|
||||
section("Configuration")
|
||||
|
||||
config_path = HERMES_HOME / "config.yaml"
|
||||
if config_path.exists():
|
||||
try:
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
stt_provider = cfg.get("stt", {}).get("provider", "local")
|
||||
tts_provider = cfg.get("tts", {}).get("provider", "edge")
|
||||
check("STT provider", True, stt_provider)
|
||||
check("TTS provider", True, tts_provider)
|
||||
|
||||
if stt_provider == "groq" and not groq_key:
|
||||
warn("STT config says groq but GROQ_API_KEY is missing")
|
||||
if tts_provider == "elevenlabs" and not eleven_key:
|
||||
warn("TTS config says elevenlabs but ELEVENLABS_API_KEY is missing")
|
||||
except Exception as e:
|
||||
warn("config.yaml", f"parse error: {e}")
|
||||
else:
|
||||
warn("config.yaml", "not found — using defaults")
|
||||
|
||||
# Voice mode state
|
||||
voice_mode_path = HERMES_HOME / "gateway_voice_mode.json"
|
||||
if voice_mode_path.exists():
|
||||
try:
|
||||
import json
|
||||
modes = json.loads(voice_mode_path.read_text())
|
||||
off_count = sum(1 for v in modes.values() if v == "off")
|
||||
all_count = sum(1 for v in modes.values() if v == "all")
|
||||
check("Voice mode state", True, f"{all_count} on, {off_count} off, {len(modes)} total")
|
||||
except Exception:
|
||||
warn("Voice mode state", "parse error")
|
||||
else:
|
||||
check("Voice mode state", True, "no saved state (fresh)")
|
||||
|
||||
|
||||
def check_bot_permissions(token):
|
||||
"""Check bot permissions via Discord API. Returns True if all OK."""
|
||||
section("Bot Permissions")
|
||||
|
||||
if not token:
|
||||
warn("Bot permissions", "no token — skipping")
|
||||
return True
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
warn("Bot permissions", "requests not installed — skipping")
|
||||
return True
|
||||
|
||||
VOICE_PERMS = {
|
||||
"Priority Speaker": 8,
|
||||
"Stream": 9,
|
||||
"View Channel": 10,
|
||||
"Send Messages": 11,
|
||||
"Embed Links": 14,
|
||||
"Attach Files": 15,
|
||||
"Read Message History": 16,
|
||||
"Connect": 20,
|
||||
"Speak": 21,
|
||||
"Mute Members": 22,
|
||||
"Deafen Members": 23,
|
||||
"Move Members": 24,
|
||||
"Use VAD": 25,
|
||||
"Send Voice Messages": 46,
|
||||
}
|
||||
REQUIRED_PERMS = {"Connect", "Speak", "View Channel", "Send Messages"}
|
||||
ok = True
|
||||
|
||||
try:
|
||||
headers = {"Authorization": f"Bot {token}"}
|
||||
r = requests.get("https://discord.com/api/v10/users/@me", headers=headers, timeout=5)
|
||||
|
||||
if r.status_code == 401:
|
||||
check("Bot login", False, "invalid token (401)")
|
||||
return False
|
||||
if r.status_code != 200:
|
||||
check("Bot login", False, f"HTTP {r.status_code}")
|
||||
return False
|
||||
|
||||
bot = r.json()
|
||||
bot_name = bot.get("username", "?")
|
||||
check("Bot login", True, f"{bot_name[:3]}{'*' * (len(bot_name) - 3)}")
|
||||
|
||||
# Check guilds
|
||||
r2 = requests.get("https://discord.com/api/v10/users/@me/guilds", headers=headers, timeout=5)
|
||||
if r2.status_code != 200:
|
||||
warn("Guilds", f"HTTP {r2.status_code}")
|
||||
return ok
|
||||
|
||||
guilds = r2.json()
|
||||
check("Guilds", True, f"{len(guilds)} guild(s)")
|
||||
|
||||
for g in guilds[:5]:
|
||||
perms = int(g.get("permissions", 0))
|
||||
is_admin = bool(perms & (1 << 3))
|
||||
|
||||
if is_admin:
|
||||
print(f" {OK} {g['name']}: Administrator (all permissions)")
|
||||
continue
|
||||
|
||||
has = []
|
||||
missing = []
|
||||
for name, bit in sorted(VOICE_PERMS.items(), key=lambda x: x[1]):
|
||||
if perms & (1 << bit):
|
||||
has.append(name)
|
||||
elif name in REQUIRED_PERMS:
|
||||
missing.append(name)
|
||||
|
||||
if missing:
|
||||
print(f" {FAIL} {g['name']}: missing {', '.join(missing)}")
|
||||
ok = False
|
||||
else:
|
||||
print(f" {OK} {g['name']}: {', '.join(has)}")
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
warn("Bot permissions", "Discord API timeout")
|
||||
except requests.exceptions.ConnectionError:
|
||||
warn("Bot permissions", "cannot reach Discord API")
|
||||
except Exception as e:
|
||||
warn("Bot permissions", f"check failed: {e}")
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def main():
|
||||
print()
|
||||
print("\033[1m" + "=" * 50 + "\033[0m")
|
||||
print("\033[1m Discord Voice Doctor\033[0m")
|
||||
print("\033[1m" + "=" * 50 + "\033[0m")
|
||||
|
||||
all_ok = True
|
||||
|
||||
all_ok &= check_packages()
|
||||
all_ok &= check_system_tools()
|
||||
env_ok, token, groq_key, eleven_key = check_env_vars()
|
||||
all_ok &= env_ok
|
||||
check_config(groq_key, eleven_key)
|
||||
all_ok &= check_bot_permissions(token)
|
||||
|
||||
# Summary
|
||||
print()
|
||||
print("\033[1m" + "-" * 50 + "\033[0m")
|
||||
if all_ok:
|
||||
print(f" {OK} \033[92mAll checks passed — voice mode ready!\033[0m")
|
||||
else:
|
||||
print(f" {FAIL} \033[91mSome checks failed — fix issues above.\033[0m")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -155,7 +155,7 @@ terminal(command="hermes chat -q 'Summarize this codebase' --model google/gemini
|
||||
|
||||
## Gateway Cron Integration
|
||||
|
||||
For scheduled autonomous tasks, use the `schedule_cronjob` tool instead of spawning processes — cron jobs handle delivery, retry, and persistence automatically.
|
||||
For scheduled autonomous tasks, use the unified `cronjob` tool instead of spawning processes — cron jobs handle delivery, retry, and persistence automatically.
|
||||
|
||||
## Key Differences Between Modes
|
||||
|
||||
|
||||
@@ -3240,7 +3240,7 @@ Prompt Strategy for finetuning Llama2 chat models see also https://github.com/fa
|
||||
|
||||
This implementation is based on the Vicuna PR and the fastchat repo, see also: https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847
|
||||
|
||||
Use dataset type: “llama2_chat” in conig.yml to use this prompt style.
|
||||
Use dataset type: “llama2_chat” in config.yml to use this prompt style.
|
||||
|
||||
E.g. in the config.yml:
|
||||
|
||||
@@ -4991,7 +4991,7 @@ prompt_strategies.orcamini
|
||||
|
||||
Prompt Strategy for finetuning Orca Mini (v2) models see also https://huggingface.co/psmathur/orca_mini_v2_7b for more information
|
||||
|
||||
Use dataset type: orcamini in conig.yml to use this prompt style.
|
||||
Use dataset type: orcamini in config.yml to use this prompt style.
|
||||
|
||||
Compared to the alpaca_w_system.open_orca dataset type, this one specifies the system prompt with “### System:”.
|
||||
|
||||
|
||||
@@ -2290,7 +2290,7 @@ This call gives the AsyncStager the opportunity to ‘stage’ the state_dict. T
|
||||
|
||||
for serializing the state_dict and writing it to storage.
|
||||
|
||||
the serialization thread starts and before returning from dcp.async_save. If this is set to False, the assumption is the user has defined a custom synchronization point for the the purpose of further optimizing save latency in the training loop (for example, by overlapping staging with the forward/backward pass), and it is the respondsibility of the user to call AsyncStager.synchronize_staging at the appropriate time.
|
||||
the serialization thread starts and before returning from dcp.async_save. If this is set to False, the assumption is the user has defined a custom synchronization point for the purpose of further optimizing save latency in the training loop (for example, by overlapping staging with the forward/backward pass), and it is the respondsibility of the user to call AsyncStager.synchronize_staging at the appropriate time.
|
||||
|
||||
Clean up all resources used by the stager.
|
||||
|
||||
@@ -2430,7 +2430,7 @@ Read the checkpoint metadata.
|
||||
|
||||
The metadata object associated with the checkpoint being loaded.
|
||||
|
||||
Calls to indicates a brand new checkpoint read is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint read. The meaning of the checkpiont_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.
|
||||
Calls to indicates a brand new checkpoint read is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint read. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.
|
||||
|
||||
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is more like a key-value store. (Default: None)
|
||||
|
||||
@@ -2488,7 +2488,7 @@ plan (SavePlan) – The local plan from the SavePlanner in use.
|
||||
|
||||
A transformed SavePlan after storage local planning
|
||||
|
||||
Calls to indicates a brand new checkpoint write is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint write. The meaning of the checkpiont_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.
|
||||
Calls to indicates a brand new checkpoint write is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint write. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.
|
||||
|
||||
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)
|
||||
|
||||
@@ -2498,7 +2498,19 @@ is_coordinator (bool) – Whether this instance is responsible for coordinating
|
||||
|
||||
Return the storage-specific metadata. This is used to store additional information in a checkpoint that can be useful for providing request-level observability. StorageMeta is passed to the SavePlanner during save calls. Returns None by default.
|
||||
|
||||
TODO: provide an example
|
||||
Example:
|
||||
|
||||
```python
|
||||
from torch.distributed.checkpoint.storage import StorageMeta
|
||||
|
||||
class CustomStorageBackend:
|
||||
def get_storage_metadata(self):
|
||||
# Return storage-specific metadata that will be stored with the checkpoint
|
||||
return StorageMeta()
|
||||
```
|
||||
|
||||
This example shows how a storage backend can return `StorageMeta`
|
||||
to attach additional metadata to a checkpoint.
|
||||
|
||||
Optional[StorageMeta]
|
||||
|
||||
@@ -3441,7 +3453,7 @@ The target module does not have to be an FSDP module.
|
||||
|
||||
A StateDictSettings containing the state_dict_type and state_dict / optim_state_dict configs that are currently set.
|
||||
|
||||
AssertionError` if the StateDictSettings for differen –
|
||||
AssertionError` if the StateDictSettings for different –
|
||||
|
||||
FSDP submodules differ. –
|
||||
|
||||
@@ -3766,7 +3778,7 @@ The sharing is done as described by ZeRO.
|
||||
|
||||
The local optimizer instance in each rank is only responsible for updating approximately 1 / world_size parameters and hence only needs to keep 1 / world_size optimizer states. After parameters are updated locally, each rank will broadcast its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.
|
||||
|
||||
ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the the parameter registration or usage order.
|
||||
ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the parameter registration or usage order.
|
||||
|
||||
params (Iterable) – an Iterable of torch.Tensor s or dict s giving all parameters, which will be sharded across ranks.
|
||||
|
||||
|
||||
@@ -6348,7 +6348,7 @@ Our chat templates for the GGUF, our BnB and BF16 uploads and all versions are f
|
||||
|
||||
### :1234: Precision issues
|
||||
|
||||
We found multiple precision issues in Tesla T4 and float16 machines primarily since the model was trained using BF16, and so outliers and overflows existed. MXFP4 is not actually supported on Ampere and older GPUs, so Triton provides `tl.dot_scaled` for MXFP4 matrix multiplication. It upcasts the matrices to BF16 internaly on the fly.
|
||||
We found multiple precision issues in Tesla T4 and float16 machines primarily since the model was trained using BF16, and so outliers and overflows existed. MXFP4 is not actually supported on Ampere and older GPUs, so Triton provides `tl.dot_scaled` for MXFP4 matrix multiplication. It upcasts the matrices to BF16 internally on the fly.
|
||||
|
||||
We made a [MXFP4 inference notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/GPT_OSS_MXFP4_\(20B\)-Inference.ipynb) as well in Tesla T4 Colab!
|
||||
|
||||
@@ -14877,7 +14877,7 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
|
||||
|
||||
# Text-to-Speech (TTS) Fine-tuning
|
||||
|
||||
Learn how to to fine-tune TTS & STT voice models with Unsloth.
|
||||
Learn how to fine-tune TTS & STT voice models with Unsloth.
|
||||
|
||||
Fine-tuning TTS models allows them to adapt to your specific dataset, use case, or desired style and tone. The goal is to customize these models to clone voices, adapt speaking styles and tones, support new languages, handle specific tasks and more. We also support **Speech-to-Text (STT)** models like OpenAI's Whisper.
|
||||
|
||||
@@ -15306,7 +15306,7 @@ snapshot_download(
|
||||
)
|
||||
```
|
||||
|
||||
And and let's do inference!
|
||||
And let's do inference!
|
||||
|
||||
{% code overflow="wrap" %}
|
||||
|
||||
@@ -16036,7 +16036,7 @@ Then train the model as usual via `trainer.train() .`
|
||||
|
||||
Tips to solve issues, and frequently asked questions.
|
||||
|
||||
If you're still encountering any issues with versions or depencies, please use our [Docker image](https://docs.unsloth.ai/get-started/install-and-update/docker) which will have everything pre-installed.
|
||||
If you're still encountering any issues with versions or dependencies, please use our [Docker image](https://docs.unsloth.ai/get-started/install-and-update/docker) which will have everything pre-installed.
|
||||
|
||||
{% hint style="success" %}
|
||||
**Try always to update Unsloth if you find any issues.**
|
||||
|
||||
@@ -40,7 +40,7 @@ Read more on running Llama 4 here: <https://docs.unsloth.ai/basics/tutorial-how-
|
||||
|
||||
Example 1 (unknown):
|
||||
```unknown
|
||||
And and let's do inference!
|
||||
And let's do inference!
|
||||
|
||||
{% code overflow="wrap" %}
|
||||
```
|
||||
@@ -4272,7 +4272,7 @@ Read our full DeepSeek-R1 blogpost here: [unsloth.ai/blog/deepseekr1-dynamic](ht
|
||||
|
||||
Tips to solve issues, and frequently asked questions.
|
||||
|
||||
If you're still encountering any issues with versions or depencies, please use our [Docker image](https://docs.unsloth.ai/get-started/install-and-update/docker) which will have everything pre-installed.
|
||||
If you're still encountering any issues with versions or dependencies, please use our [Docker image](https://docs.unsloth.ai/get-started/install-and-update/docker) which will have everything pre-installed.
|
||||
|
||||
{% hint style="success" %}
|
||||
**Try always to update Unsloth if you find any issues.**
|
||||
@@ -6638,7 +6638,7 @@ Our chat templates for the GGUF, our BnB and BF16 uploads and all versions are f
|
||||
|
||||
### :1234: Precision issues
|
||||
|
||||
We found multiple precision issues in Tesla T4 and float16 machines primarily since the model was trained using BF16, and so outliers and overflows existed. MXFP4 is not actually supported on Ampere and older GPUs, so Triton provides `tl.dot_scaled` for MXFP4 matrix multiplication. It upcasts the matrices to BF16 internaly on the fly.
|
||||
We found multiple precision issues in Tesla T4 and float16 machines primarily since the model was trained using BF16, and so outliers and overflows existed. MXFP4 is not actually supported on Ampere and older GPUs, so Triton provides `tl.dot_scaled` for MXFP4 matrix multiplication. It upcasts the matrices to BF16 internally on the fly.
|
||||
|
||||
We made a [MXFP4 inference notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/GPT_OSS_MXFP4_\(20B\)-Inference.ipynb) as well in Tesla T4 Colab!
|
||||
|
||||
@@ -10259,7 +10259,7 @@ training_args = GRPOConfig(
|
||||
- Choosing and Loading a TTS Model
|
||||
- Preparing Your Dataset
|
||||
|
||||
Learn how to to fine-tune TTS & STT voice models with Unsloth.
|
||||
Learn how to fine-tune TTS & STT voice models with Unsloth.
|
||||
|
||||
Fine-tuning TTS models allows them to adapt to your specific dataset, use case, or desired style and tone. The goal is to customize these models to clone voices, adapt speaking styles and tones, support new languages, handle specific tasks and more. We also support **Speech-to-Text (STT)** models like OpenAI's Whisper.
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@
|
||||
- [Troubleshooting Inference](/basics/running-and-saving-models/troubleshooting-inference.md): If you're experiencing issues when running or saving your model.
|
||||
- [vLLM Engine Arguments](/basics/running-and-saving-models/vllm-engine-arguments.md)
|
||||
- [LoRA Hot Swapping Guide](/basics/running-and-saving-models/lora-hot-swapping-guide.md)
|
||||
- [Text-to-Speech (TTS) Fine-tuning](/basics/text-to-speech-tts-fine-tuning.md): Learn how to to fine-tune TTS & STT voice models with Unsloth.
|
||||
- [Text-to-Speech (TTS) Fine-tuning](/basics/text-to-speech-tts-fine-tuning.md): Learn how to fine-tune TTS & STT voice models with Unsloth.
|
||||
- [Unsloth Dynamic 2.0 GGUFs](/basics/unsloth-dynamic-2.0-ggufs.md): A big new upgrade to our Dynamic Quants!
|
||||
- [Vision Fine-tuning](/basics/vision-fine-tuning.md): Learn how to fine-tune vision/multimodal LLMs with Unsloth
|
||||
- [Fine-tuning LLMs with NVIDIA DGX Spark and Unsloth](/basics/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth.md): Tutorial on how to fine-tune and do reinforcement learning (RL) with OpenAI gpt-oss on NVIDIA DGX Spark.
|
||||
|
||||
@@ -102,7 +102,9 @@ This prints a URL. **Send the URL to the user** and tell them:
|
||||
### Step 4: Exchange the code
|
||||
|
||||
The user will paste back either a URL like `http://localhost:1/?code=4/0A...&scope=...`
|
||||
or just the code string. Either works:
|
||||
or just the code string. Either works. The `--auth-url` step stores a temporary
|
||||
pending OAuth session locally so `--auth-code` can complete the PKCE exchange
|
||||
later, even on headless systems:
|
||||
|
||||
```bash
|
||||
$GSETUP --auth-code "THE_URL_OR_CODE_THE_USER_PASTED"
|
||||
@@ -119,6 +121,7 @@ Should print `AUTHENTICATED`. Setup is complete — token refreshes automaticall
|
||||
### Notes
|
||||
|
||||
- Token is stored at `~/.hermes/google_token.json` and auto-refreshes.
|
||||
- Pending OAuth session state/verifier are stored temporarily at `~/.hermes/google_oauth_pending.json` until exchange completes.
|
||||
- To revoke: `$GSETUP --revoke`
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -31,6 +31,7 @@ from pathlib import Path
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
TOKEN_PATH = HERMES_HOME / "google_token.json"
|
||||
CLIENT_SECRET_PATH = HERMES_HOME / "google_client_secret.json"
|
||||
PENDING_AUTH_PATH = HERMES_HOME / "google_oauth_pending.json"
|
||||
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
@@ -141,6 +142,58 @@ def store_client_secret(path: str):
|
||||
print(f"OK: Client secret saved to {CLIENT_SECRET_PATH}")
|
||||
|
||||
|
||||
def _save_pending_auth(*, state: str, code_verifier: str):
|
||||
"""Persist the OAuth session bits needed for a later token exchange."""
|
||||
PENDING_AUTH_PATH.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"state": state,
|
||||
"code_verifier": code_verifier,
|
||||
"redirect_uri": REDIRECT_URI,
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _load_pending_auth() -> dict:
|
||||
"""Load the pending OAuth session created by get_auth_url()."""
|
||||
if not PENDING_AUTH_PATH.exists():
|
||||
print("ERROR: No pending OAuth session found. Run --auth-url first.")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
data = json.loads(PENDING_AUTH_PATH.read_text())
|
||||
except Exception as e:
|
||||
print(f"ERROR: Could not read pending OAuth session: {e}")
|
||||
print("Run --auth-url again to start a fresh OAuth session.")
|
||||
sys.exit(1)
|
||||
|
||||
if not data.get("state") or not data.get("code_verifier"):
|
||||
print("ERROR: Pending OAuth session is missing PKCE data.")
|
||||
print("Run --auth-url again to start a fresh OAuth session.")
|
||||
sys.exit(1)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _extract_code_and_state(code_or_url: str) -> tuple[str, str | None]:
|
||||
"""Accept either a raw auth code or the full redirect URL pasted by the user."""
|
||||
if not code_or_url.startswith("http"):
|
||||
return code_or_url, None
|
||||
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
parsed = urlparse(code_or_url)
|
||||
params = parse_qs(parsed.query)
|
||||
if "code" not in params:
|
||||
print("ERROR: No 'code' parameter found in URL.")
|
||||
sys.exit(1)
|
||||
|
||||
state = params.get("state", [None])[0]
|
||||
return params["code"][0], state
|
||||
|
||||
|
||||
def get_auth_url():
|
||||
"""Print the OAuth authorization URL. User visits this in a browser."""
|
||||
if not CLIENT_SECRET_PATH.exists():
|
||||
@@ -154,11 +207,13 @@ def get_auth_url():
|
||||
str(CLIENT_SECRET_PATH),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=REDIRECT_URI,
|
||||
autogenerate_code_verifier=True,
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(
|
||||
auth_url, state = flow.authorization_url(
|
||||
access_type="offline",
|
||||
prompt="consent",
|
||||
)
|
||||
_save_pending_auth(state=state, code_verifier=flow.code_verifier)
|
||||
# Print just the URL so the agent can extract it cleanly
|
||||
print(auth_url)
|
||||
|
||||
@@ -169,26 +224,23 @@ def exchange_auth_code(code: str):
|
||||
print("ERROR: No client secret stored. Run --client-secret first.")
|
||||
sys.exit(1)
|
||||
|
||||
pending_auth = _load_pending_auth()
|
||||
code, returned_state = _extract_code_and_state(code)
|
||||
if returned_state and returned_state != pending_auth["state"]:
|
||||
print("ERROR: OAuth state mismatch. Run --auth-url again to start a fresh session.")
|
||||
sys.exit(1)
|
||||
|
||||
_ensure_deps()
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
|
||||
flow = Flow.from_client_secrets_file(
|
||||
str(CLIENT_SECRET_PATH),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=REDIRECT_URI,
|
||||
redirect_uri=pending_auth.get("redirect_uri", REDIRECT_URI),
|
||||
state=pending_auth["state"],
|
||||
code_verifier=pending_auth["code_verifier"],
|
||||
)
|
||||
|
||||
# The code might come as a full redirect URL or just the code itself
|
||||
if code.startswith("http"):
|
||||
# Extract code from redirect URL: http://localhost:1/?code=CODE&scope=...
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
parsed = urlparse(code)
|
||||
params = parse_qs(parsed.query)
|
||||
if "code" not in params:
|
||||
print("ERROR: No 'code' parameter found in URL.")
|
||||
sys.exit(1)
|
||||
code = params["code"][0]
|
||||
|
||||
try:
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as e:
|
||||
@@ -198,6 +250,7 @@ def exchange_auth_code(code: str):
|
||||
|
||||
creds = flow.credentials
|
||||
TOKEN_PATH.write_text(creds.to_json())
|
||||
PENDING_AUTH_PATH.unlink(missing_ok=True)
|
||||
print(f"OK: Authenticated. Token saved to {TOKEN_PATH}")
|
||||
|
||||
|
||||
@@ -229,6 +282,7 @@ def revoke():
|
||||
print(f"Remote revocation failed (token may already be invalid): {e}")
|
||||
|
||||
TOKEN_PATH.unlink(missing_ok=True)
|
||||
PENDING_AUTH_PATH.unlink(missing_ok=True)
|
||||
print(f"Deleted {TOKEN_PATH}")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
---
|
||||
name: plan
|
||||
description: Plan mode for Hermes — inspect context, write a markdown plan into the active workspace's `.hermes/plans/` directory, and do not execute the work.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [planning, plan-mode, implementation, workflow]
|
||||
related_skills: [writing-plans, subagent-driven-development]
|
||||
---
|
||||
|
||||
# Plan Mode
|
||||
|
||||
Use this skill when the user wants a plan instead of execution.
|
||||
|
||||
## Core behavior
|
||||
|
||||
For this turn, you are planning only.
|
||||
|
||||
- Do not implement code.
|
||||
- Do not edit project files except the plan markdown file.
|
||||
- Do not run mutating terminal commands, commit, push, or perform external actions.
|
||||
- You may inspect the repo or other context with read-only commands/tools when needed.
|
||||
- Your deliverable is a markdown plan saved inside the active workspace under `.hermes/plans/`.
|
||||
|
||||
## Output requirements
|
||||
|
||||
Write a markdown plan that is concrete and actionable.
|
||||
|
||||
Include, when relevant:
|
||||
- Goal
|
||||
- Current context / assumptions
|
||||
- Proposed approach
|
||||
- Step-by-step plan
|
||||
- Files likely to change
|
||||
- Tests / validation
|
||||
- Risks, tradeoffs, and open questions
|
||||
|
||||
If the task is code-related, include exact file paths, likely test targets, and verification steps.
|
||||
|
||||
## Save location
|
||||
|
||||
Save the plan with `write_file` under:
|
||||
- `.hermes/plans/YYYY-MM-DD_HHMMSS-<slug>.md`
|
||||
|
||||
Treat that as relative to the active working directory / backend workspace. Hermes file tools are backend-aware, so using this relative path keeps the plan with the workspace on local, docker, ssh, modal, and daytona backends.
|
||||
|
||||
If the runtime provides a specific target path, use that exact path.
|
||||
If not, create a sensible timestamped filename yourself under `.hermes/plans/`.
|
||||
|
||||
## Interaction style
|
||||
|
||||
- If the request is clear enough, write the plan directly.
|
||||
- If no explicit instruction accompanies `/plan`, infer the task from the current conversation context.
|
||||
- If it is genuinely underspecified, ask a brief clarifying question instead of guessing.
|
||||
- After saving the plan, reply briefly with what you planned and the saved path.
|
||||
@@ -10,6 +10,8 @@ import pytest
|
||||
from agent.auxiliary_client import (
|
||||
get_text_auxiliary_client,
|
||||
get_vision_auxiliary_client,
|
||||
get_available_vision_backends,
|
||||
resolve_provider_client,
|
||||
auxiliary_max_tokens_param,
|
||||
_read_codex_access_token,
|
||||
_get_auxiliary_provider,
|
||||
@@ -24,9 +26,12 @@ def _clean_env(monkeypatch):
|
||||
for key in (
|
||||
"OPENROUTER_API_KEY", "OPENAI_BASE_URL", "OPENAI_API_KEY",
|
||||
"OPENAI_MODEL", "LLM_MODEL", "NOUS_INFERENCE_BASE_URL",
|
||||
# Per-task provider/model overrides
|
||||
"ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN",
|
||||
# Per-task provider/model/direct-endpoint overrides
|
||||
"AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL",
|
||||
"AUXILIARY_VISION_BASE_URL", "AUXILIARY_VISION_API_KEY",
|
||||
"AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL",
|
||||
"AUXILIARY_WEB_EXTRACT_BASE_URL", "AUXILIARY_WEB_EXTRACT_API_KEY",
|
||||
"CONTEXT_COMPRESSION_PROVIDER", "CONTEXT_COMPRESSION_MODEL",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
@@ -142,11 +147,55 @@ class TestGetTextAuxiliaryClient:
|
||||
call_kwargs = mock_openai.call_args
|
||||
assert call_kwargs.kwargs["base_url"] == "http://localhost:1234/v1"
|
||||
|
||||
def test_task_direct_endpoint_override(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_BASE_URL", "http://localhost:2345/v1")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_API_KEY", "task-key")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_MODEL", "task-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client("web_extract")
|
||||
assert model == "task-model"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:2345/v1"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "task-key"
|
||||
|
||||
def test_task_direct_endpoint_without_openai_key_does_not_fall_back(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_BASE_URL", "http://localhost:2345/v1")
|
||||
monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_MODEL", "task-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client("web_extract")
|
||||
assert client is None
|
||||
assert model is None
|
||||
mock_openai.assert_not_called()
|
||||
|
||||
def test_custom_endpoint_uses_config_saved_base_url(self, monkeypatch):
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "lm-studio-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client()
|
||||
|
||||
assert client is not None
|
||||
assert model == "my-local-model"
|
||||
call_kwargs = mock_openai.call_args
|
||||
assert call_kwargs.kwargs["base_url"] == "http://localhost:1234/v1"
|
||||
|
||||
def test_codex_fallback_when_nothing_else(self, codex_auth_dir):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client()
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
# Returns a CodexAuxiliaryClient wrapper, not a raw OpenAI client
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
@@ -164,14 +213,74 @@ class TestGetTextAuxiliaryClient:
|
||||
|
||||
|
||||
class TestVisionClientFallback:
|
||||
"""Vision client auto mode only tries OpenRouter + Nous (multimodal-capable)."""
|
||||
"""Vision client auto mode resolves known-good multimodal backends."""
|
||||
|
||||
def test_vision_returns_none_without_any_credentials(self):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._try_anthropic", return_value=(None, None)),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_auto_includes_anthropic_when_configured(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
):
|
||||
backends = get_available_vision_backends()
|
||||
|
||||
assert "anthropic" in backends
|
||||
|
||||
def test_resolve_provider_client_returns_native_anthropic_wrapper(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
):
|
||||
client, model = resolve_provider_client("anthropic")
|
||||
|
||||
assert client is not None
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
def test_vision_auto_uses_anthropic_when_no_higher_priority_backend(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
assert client is not None
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
def test_selected_anthropic_provider_is_preferred_for_vision_auto(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
|
||||
def fake_load_config():
|
||||
return {"model": {"provider": "anthropic", "default": "claude-sonnet-4-6"}}
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
patch("hermes_cli.config.load_config", fake_load_config),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
assert client is not None
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
def test_vision_auto_includes_codex(self, codex_auth_dir):
|
||||
"""Codex supports vision (gpt-5.3-codex), so auto mode should use it."""
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
@@ -179,7 +288,7 @@ class TestVisionClientFallback:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_vision_auto_falls_back_to_custom_endpoint(self, monkeypatch):
|
||||
"""Custom endpoint is used as fallback in vision auto mode.
|
||||
@@ -194,6 +303,27 @@ class TestVisionClientFallback:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is not None # Custom endpoint picked up as fallback
|
||||
|
||||
def test_vision_direct_endpoint_override(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_API_KEY", "vision-key")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert model == "vision-model"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:4567/v1"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "vision-key"
|
||||
|
||||
def test_vision_direct_endpoint_requires_openai_api_key(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1")
|
||||
monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
mock_openai.assert_not_called()
|
||||
|
||||
def test_vision_uses_openrouter_when_available(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
@@ -241,7 +371,7 @@ class TestVisionClientFallback:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
class TestGetAuxiliaryProvider:
|
||||
@@ -320,6 +450,27 @@ class TestResolveForcedProvider:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
assert model == "my-local-model"
|
||||
|
||||
def test_forced_main_uses_config_saved_custom_endpoint(self, monkeypatch):
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://local:8080/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
assert client is not None
|
||||
assert model == "my-local-model"
|
||||
call_kwargs = mock_openai.call_args
|
||||
assert call_kwargs.kwargs["base_url"] == "http://local:8080/v1"
|
||||
|
||||
def test_forced_main_skips_openrouter_nous(self, monkeypatch):
|
||||
"""Even if OpenRouter key is set, 'main' skips it."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
@@ -338,7 +489,7 @@ class TestResolveForcedProvider:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_forced_codex(self, codex_auth_dir, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
@@ -346,7 +497,7 @@ class TestResolveForcedProvider:
|
||||
client, model = _resolve_forced_provider("codex")
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_forced_codex_no_token(self, monkeypatch):
|
||||
with patch("agent.auxiliary_client._read_codex_access_token", return_value=None):
|
||||
@@ -390,6 +541,24 @@ class TestTaskSpecificOverrides:
|
||||
client, model = get_text_auxiliary_client("web_extract")
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_task_direct_endpoint_from_config(self, monkeypatch, tmp_path):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
"""auxiliary:
|
||||
web_extract:
|
||||
base_url: http://localhost:3456/v1
|
||||
api_key: config-key
|
||||
model: config-model
|
||||
"""
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_text_auxiliary_client("web_extract")
|
||||
assert model == "config-model"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:3456/v1"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "config-key"
|
||||
|
||||
def test_task_without_override_uses_auto(self, monkeypatch):
|
||||
"""A task with no provider env var falls through to auto chain."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
@@ -15,10 +15,30 @@ from agent.prompt_builder import (
|
||||
build_context_files_prompt,
|
||||
CONTEXT_FILE_MAX_CHARS,
|
||||
DEFAULT_AGENT_IDENTITY,
|
||||
MEMORY_GUIDANCE,
|
||||
SESSION_SEARCH_GUIDANCE,
|
||||
PLATFORM_HINTS,
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Guidance constants
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestGuidanceConstants:
|
||||
def test_memory_guidance_discourages_task_logs(self):
|
||||
assert "durable facts" in MEMORY_GUIDANCE
|
||||
assert "Do NOT save task progress" in MEMORY_GUIDANCE
|
||||
assert "session_search" in MEMORY_GUIDANCE
|
||||
assert "like a diary" not in MEMORY_GUIDANCE
|
||||
assert ">80%" not in MEMORY_GUIDANCE
|
||||
|
||||
def test_session_search_guidance_is_simple_cross_session_recall(self):
|
||||
assert "relevant cross-session context exists" in SESSION_SEARCH_GUIDANCE
|
||||
assert "recent turns of the current session" not in SESSION_SEARCH_GUIDANCE
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Context injection scanning
|
||||
# =========================================================================
|
||||
@@ -435,6 +455,7 @@ class TestPromptBuilderConstants:
|
||||
assert "whatsapp" in PLATFORM_HINTS
|
||||
assert "telegram" in PLATFORM_HINTS
|
||||
assert "discord" in PLATFORM_HINTS
|
||||
assert "cron" in PLATFORM_HINTS
|
||||
assert "cli" in PLATFORM_HINTS
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
"""Tests for agent/skill_commands.py — skill slash command scanning and platform filtering."""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import tools.skills_tool as skills_tool_module
|
||||
from agent.skill_commands import scan_skill_commands, build_skill_invocation_message
|
||||
from agent.skill_commands import (
|
||||
build_plan_path,
|
||||
build_preloaded_skills_prompt,
|
||||
build_skill_invocation_message,
|
||||
scan_skill_commands,
|
||||
)
|
||||
|
||||
|
||||
def _make_skill(
|
||||
@@ -79,6 +86,33 @@ class TestScanSkillCommands:
|
||||
assert "/generic-tool" in result
|
||||
|
||||
|
||||
class TestBuildPreloadedSkillsPrompt:
|
||||
def test_builds_prompt_for_multiple_named_skills(self, tmp_path):
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_skill(tmp_path, "first-skill")
|
||||
_make_skill(tmp_path, "second-skill")
|
||||
prompt, loaded, missing = build_preloaded_skills_prompt(
|
||||
["first-skill", "second-skill"]
|
||||
)
|
||||
|
||||
assert missing == []
|
||||
assert loaded == ["first-skill", "second-skill"]
|
||||
assert "first-skill" in prompt
|
||||
assert "second-skill" in prompt
|
||||
assert "preloaded" in prompt.lower()
|
||||
|
||||
def test_reports_missing_named_skills(self, tmp_path):
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_skill(tmp_path, "present-skill")
|
||||
prompt, loaded, missing = build_preloaded_skills_prompt(
|
||||
["present-skill", "missing-skill"]
|
||||
)
|
||||
|
||||
assert "present-skill" in prompt
|
||||
assert loaded == ["present-skill"]
|
||||
assert missing == ["missing-skill"]
|
||||
|
||||
|
||||
class TestBuildSkillInvocationMessage:
|
||||
def test_loads_skill_by_stored_path_when_frontmatter_name_differs(self, tmp_path):
|
||||
skill_dir = tmp_path / "mlops" / "audiocraft"
|
||||
@@ -241,3 +275,37 @@ Generate some audio.
|
||||
|
||||
assert msg is not None
|
||||
assert 'file_path="<path>"' in msg
|
||||
|
||||
|
||||
class TestPlanSkillHelpers:
|
||||
def test_build_plan_path_uses_workspace_relative_dir_and_slugifies_request(self):
|
||||
path = build_plan_path(
|
||||
"Implement OAuth login + refresh tokens!",
|
||||
now=datetime(2026, 3, 15, 9, 30, 45),
|
||||
)
|
||||
|
||||
assert path == Path(".hermes") / "plans" / "2026-03-15_093045-implement-oauth-login-refresh-tokens.md"
|
||||
|
||||
def test_plan_skill_message_can_include_runtime_save_path_note(self, tmp_path):
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_skill(
|
||||
tmp_path,
|
||||
"plan",
|
||||
body="Save plans under .hermes/plans in the active workspace and do not execute the work.",
|
||||
)
|
||||
scan_skill_commands()
|
||||
msg = build_skill_invocation_message(
|
||||
"/plan",
|
||||
"Add a /plan command",
|
||||
runtime_note=(
|
||||
"Save the markdown plan with write_file to this exact relative path inside "
|
||||
"the active workspace/backend cwd: .hermes/plans/plan.md"
|
||||
),
|
||||
)
|
||||
|
||||
assert msg is not None
|
||||
assert "Save plans under $HERMES_HOME/plans" not in msg
|
||||
assert ".hermes/plans" in msg
|
||||
assert "Add a /plan command" in msg
|
||||
assert ".hermes/plans/plan.md" in msg
|
||||
assert "Runtime note:" in msg
|
||||
|
||||
@@ -26,6 +26,12 @@ def _isolate_hermes_home(tmp_path, monkeypatch):
|
||||
(fake_home / "memories").mkdir()
|
||||
(fake_home / "skills").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(fake_home))
|
||||
# Tests should not inherit the agent's current gateway/messaging surface.
|
||||
# Individual tests that need gateway behavior set these explicitly.
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_GATEWAY_SESSION", raising=False)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
+27
-1
@@ -16,6 +16,8 @@ from cron.jobs import (
|
||||
get_job,
|
||||
list_jobs,
|
||||
update_job,
|
||||
pause_job,
|
||||
resume_job,
|
||||
remove_job,
|
||||
mark_job_run,
|
||||
get_due_jobs,
|
||||
@@ -233,14 +235,18 @@ class TestUpdateJob:
|
||||
job = create_job(prompt="Daily report", schedule="every 1h")
|
||||
assert job["schedule"]["kind"] == "interval"
|
||||
assert job["schedule"]["minutes"] == 60
|
||||
old_next_run = job["next_run_at"]
|
||||
new_schedule = parse_schedule("every 2h")
|
||||
updated = update_job(job["id"], {"schedule": new_schedule})
|
||||
updated = update_job(job["id"], {"schedule": new_schedule, "schedule_display": new_schedule["display"]})
|
||||
assert updated is not None
|
||||
assert updated["schedule"]["kind"] == "interval"
|
||||
assert updated["schedule"]["minutes"] == 120
|
||||
assert updated["schedule_display"] == "every 120m"
|
||||
assert updated["next_run_at"] != old_next_run
|
||||
# Verify persisted to disk
|
||||
fetched = get_job(job["id"])
|
||||
assert fetched["schedule"]["minutes"] == 120
|
||||
assert fetched["schedule_display"] == "every 120m"
|
||||
|
||||
def test_update_enable_disable(self, tmp_cron_dir):
|
||||
job = create_job(prompt="Toggle me", schedule="every 1h")
|
||||
@@ -255,6 +261,26 @@ class TestUpdateJob:
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestPauseResumeJob:
|
||||
def test_pause_sets_state(self, tmp_cron_dir):
|
||||
job = create_job(prompt="Pause me", schedule="every 1h")
|
||||
paused = pause_job(job["id"], reason="user paused")
|
||||
assert paused is not None
|
||||
assert paused["enabled"] is False
|
||||
assert paused["state"] == "paused"
|
||||
assert paused["paused_reason"] == "user paused"
|
||||
|
||||
def test_resume_reenables_job(self, tmp_cron_dir):
|
||||
job = create_job(prompt="Resume me", schedule="every 1h")
|
||||
pause_job(job["id"], reason="user paused")
|
||||
resumed = resume_job(job["id"])
|
||||
assert resumed is not None
|
||||
assert resumed["enabled"] is True
|
||||
assert resumed["state"] == "scheduled"
|
||||
assert resumed["paused_at"] is None
|
||||
assert resumed["paused_reason"] is None
|
||||
|
||||
|
||||
class TestMarkJobRun:
|
||||
def test_increments_completed(self, tmp_cron_dir):
|
||||
job = create_job(prompt="Test", schedule="every 1h")
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from unittest.mock import patch, MagicMock
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from cron.scheduler import _resolve_origin, _deliver_result, run_job
|
||||
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, run_job
|
||||
|
||||
|
||||
class TestResolveOrigin:
|
||||
@@ -44,6 +45,56 @@ class TestResolveOrigin:
|
||||
assert _resolve_origin(job) is None
|
||||
|
||||
|
||||
class TestResolveDeliveryTarget:
|
||||
def test_origin_delivery_preserves_thread_id(self):
|
||||
job = {
|
||||
"deliver": "origin",
|
||||
"origin": {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-1001",
|
||||
"thread_id": "17585",
|
||||
},
|
||||
}
|
||||
|
||||
assert _resolve_delivery_target(job) == {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-1001",
|
||||
"thread_id": "17585",
|
||||
}
|
||||
|
||||
def test_bare_platform_uses_matching_origin_chat(self):
|
||||
job = {
|
||||
"deliver": "telegram",
|
||||
"origin": {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-1001",
|
||||
"thread_id": "17585",
|
||||
},
|
||||
}
|
||||
|
||||
assert _resolve_delivery_target(job) == {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-1001",
|
||||
"thread_id": "17585",
|
||||
}
|
||||
|
||||
def test_bare_platform_falls_back_to_home_channel(self, monkeypatch):
|
||||
monkeypatch.setenv("TELEGRAM_HOME_CHANNEL", "-2002")
|
||||
job = {
|
||||
"deliver": "telegram",
|
||||
"origin": {
|
||||
"platform": "discord",
|
||||
"chat_id": "abc",
|
||||
},
|
||||
}
|
||||
|
||||
assert _resolve_delivery_target(job) == {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-2002",
|
||||
"thread_id": None,
|
||||
}
|
||||
|
||||
|
||||
class TestDeliverResultMirrorLogging:
|
||||
"""Verify that mirror_to_session failures are logged, not silently swallowed."""
|
||||
|
||||
@@ -57,7 +108,7 @@ class TestDeliverResultMirrorLogging:
|
||||
mock_cfg.platforms = {Platform.TELEGRAM: pconfig}
|
||||
|
||||
with patch("gateway.config.load_gateway_config", return_value=mock_cfg), \
|
||||
patch("asyncio.run", return_value=None), \
|
||||
patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})), \
|
||||
patch("gateway.mirror.mirror_to_session", side_effect=ConnectionError("network down")):
|
||||
job = {
|
||||
"id": "test-job",
|
||||
@@ -90,9 +141,8 @@ class TestDeliverResultMirrorLogging:
|
||||
}
|
||||
|
||||
with patch("gateway.config.load_gateway_config", return_value=mock_cfg), \
|
||||
patch("tools.send_message_tool._send_to_platform", return_value={"success": True}) as send_mock, \
|
||||
patch("gateway.mirror.mirror_to_session") as mirror_mock, \
|
||||
patch("asyncio.run", side_effect=lambda coro: None):
|
||||
patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})) as send_mock, \
|
||||
patch("gateway.mirror.mirror_to_session") as mirror_mock:
|
||||
_deliver_result(job, "hello")
|
||||
|
||||
send_mock.assert_called_once()
|
||||
@@ -146,6 +196,60 @@ class TestRunJobSessionPersistence:
|
||||
assert kwargs["session_id"].startswith("cron_test-job_")
|
||||
fake_db.close.assert_called_once()
|
||||
|
||||
def test_run_job_sets_auto_delivery_env_from_dotenv_home_channel(self, tmp_path, monkeypatch):
|
||||
job = {
|
||||
"id": "test-job",
|
||||
"name": "test",
|
||||
"prompt": "hello",
|
||||
"deliver": "telegram",
|
||||
}
|
||||
fake_db = MagicMock()
|
||||
seen = {}
|
||||
|
||||
(tmp_path / ".env").write_text("TELEGRAM_HOME_CHANNEL=-2002\n")
|
||||
monkeypatch.delenv("TELEGRAM_HOME_CHANNEL", raising=False)
|
||||
monkeypatch.delenv("HERMES_CRON_AUTO_DELIVER_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_CRON_AUTO_DELIVER_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID", raising=False)
|
||||
|
||||
class FakeAgent:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def run_conversation(self, *args, **kwargs):
|
||||
seen["platform"] = os.getenv("HERMES_CRON_AUTO_DELIVER_PLATFORM")
|
||||
seen["chat_id"] = os.getenv("HERMES_CRON_AUTO_DELIVER_CHAT_ID")
|
||||
seen["thread_id"] = os.getenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID")
|
||||
return {"final_response": "ok"}
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
patch("hermes_state.SessionDB", return_value=fake_db), \
|
||||
patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
return_value={
|
||||
"api_key": "***",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
},
|
||||
), \
|
||||
patch("run_agent.AIAgent", FakeAgent):
|
||||
success, output, final_response, error = run_job(job)
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert final_response == "ok"
|
||||
assert "ok" in output
|
||||
assert seen == {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-2002",
|
||||
"thread_id": None,
|
||||
}
|
||||
assert os.getenv("HERMES_CRON_AUTO_DELIVER_PLATFORM") is None
|
||||
assert os.getenv("HERMES_CRON_AUTO_DELIVER_CHAT_ID") is None
|
||||
assert os.getenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID") is None
|
||||
fake_db.close.assert_called_once()
|
||||
|
||||
|
||||
class TestRunJobConfigLogging:
|
||||
"""Verify that config.yaml parse failures are logged, not silently swallowed."""
|
||||
@@ -203,3 +307,145 @@ class TestRunJobConfigLogging:
|
||||
|
||||
assert any("failed to parse prefill messages" in r.message for r in caplog.records), \
|
||||
f"Expected 'failed to parse prefill messages' warning in logs, got: {[r.message for r in caplog.records]}"
|
||||
|
||||
|
||||
class TestRunJobPerJobOverrides:
|
||||
def test_job_level_model_provider_and_base_url_overrides_are_used(self, tmp_path):
|
||||
config_yaml = tmp_path / "config.yaml"
|
||||
config_yaml.write_text(
|
||||
"model:\n"
|
||||
" default: gpt-5.4\n"
|
||||
" provider: openai-codex\n"
|
||||
" base_url: https://chatgpt.com/backend-api/codex\n"
|
||||
)
|
||||
|
||||
job = {
|
||||
"id": "briefing-job",
|
||||
"name": "briefing",
|
||||
"prompt": "hello",
|
||||
"model": "perplexity/sonar-pro",
|
||||
"provider": "custom",
|
||||
"base_url": "http://127.0.0.1:4000/v1",
|
||||
}
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_runtime = {
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "http://127.0.0.1:4000/v1",
|
||||
"api_key": "***",
|
||||
}
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
patch("cron.scheduler._resolve_origin", return_value=None), \
|
||||
patch("dotenv.load_dotenv"), \
|
||||
patch("hermes_state.SessionDB", return_value=fake_db), \
|
||||
patch("hermes_cli.runtime_provider.resolve_runtime_provider", return_value=fake_runtime) as runtime_mock, \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run_conversation.return_value = {"final_response": "ok"}
|
||||
mock_agent_cls.return_value = mock_agent
|
||||
|
||||
success, output, final_response, error = run_job(job)
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert final_response == "ok"
|
||||
assert "ok" in output
|
||||
runtime_mock.assert_called_once_with(
|
||||
requested="custom",
|
||||
explicit_base_url="http://127.0.0.1:4000/v1",
|
||||
)
|
||||
assert mock_agent_cls.call_args.kwargs["model"] == "perplexity/sonar-pro"
|
||||
fake_db.close.assert_called_once()
|
||||
|
||||
|
||||
class TestRunJobSkillBacked:
|
||||
def test_run_job_loads_skill_and_disables_recursive_cron_tools(self, tmp_path):
|
||||
job = {
|
||||
"id": "skill-job",
|
||||
"name": "skill test",
|
||||
"prompt": "Check the feeds and summarize anything new.",
|
||||
"skill": "blogwatcher",
|
||||
}
|
||||
|
||||
fake_db = MagicMock()
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
patch("cron.scheduler._resolve_origin", return_value=None), \
|
||||
patch("dotenv.load_dotenv"), \
|
||||
patch("hermes_state.SessionDB", return_value=fake_db), \
|
||||
patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
return_value={
|
||||
"api_key": "***",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
},
|
||||
), \
|
||||
patch("tools.skills_tool.skill_view", return_value=json.dumps({"success": True, "content": "# Blogwatcher\nFollow this skill."})), \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run_conversation.return_value = {"final_response": "ok"}
|
||||
mock_agent_cls.return_value = mock_agent
|
||||
|
||||
success, output, final_response, error = run_job(job)
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert final_response == "ok"
|
||||
|
||||
kwargs = mock_agent_cls.call_args.kwargs
|
||||
assert "cronjob" in (kwargs["disabled_toolsets"] or [])
|
||||
|
||||
prompt_arg = mock_agent.run_conversation.call_args.args[0]
|
||||
assert "blogwatcher" in prompt_arg
|
||||
assert "Follow this skill" in prompt_arg
|
||||
assert "Check the feeds and summarize anything new." in prompt_arg
|
||||
|
||||
def test_run_job_loads_multiple_skills_in_order(self, tmp_path):
|
||||
job = {
|
||||
"id": "multi-skill-job",
|
||||
"name": "multi skill test",
|
||||
"prompt": "Combine the results.",
|
||||
"skills": ["blogwatcher", "find-nearby"],
|
||||
}
|
||||
|
||||
fake_db = MagicMock()
|
||||
|
||||
def _skill_view(name):
|
||||
return json.dumps({"success": True, "content": f"# {name}\nInstructions for {name}."})
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
patch("cron.scheduler._resolve_origin", return_value=None), \
|
||||
patch("dotenv.load_dotenv"), \
|
||||
patch("hermes_state.SessionDB", return_value=fake_db), \
|
||||
patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
return_value={
|
||||
"api_key": "***",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
},
|
||||
), \
|
||||
patch("tools.skills_tool.skill_view", side_effect=_skill_view) as skill_view_mock, \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run_conversation.return_value = {"final_response": "ok"}
|
||||
mock_agent_cls.return_value = mock_agent
|
||||
|
||||
success, output, final_response, error = run_job(job)
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert final_response == "ok"
|
||||
assert skill_view_mock.call_count == 2
|
||||
assert [call.args[0] for call in skill_view_mock.call_args_list] == ["blogwatcher", "find-nearby"]
|
||||
|
||||
prompt_arg = mock_agent.run_conversation.call_args.args[0]
|
||||
assert prompt_arg.index("blogwatcher") < prompt_arg.index("find-nearby")
|
||||
assert "Instructions for blogwatcher." in prompt_arg
|
||||
assert "Instructions for find-nearby." in prompt_arg
|
||||
assert "Combine the results." in prompt_arg
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for the delivery routing module."""
|
||||
|
||||
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
|
||||
from gateway.delivery import DeliveryTarget, parse_deliver_spec
|
||||
from gateway.delivery import DeliveryRouter, DeliveryTarget, parse_deliver_spec
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
@@ -85,3 +85,12 @@ class TestTargetToStringRoundtrip:
|
||||
reparsed = DeliveryTarget.parse(s)
|
||||
assert reparsed.platform == Platform.TELEGRAM
|
||||
assert reparsed.chat_id == "999"
|
||||
|
||||
|
||||
class TestDeliveryRouter:
|
||||
def test_resolve_targets_does_not_duplicate_local_when_explicit(self):
|
||||
router = DeliveryRouter(GatewayConfig(always_log_local=True))
|
||||
|
||||
targets = router.resolve_targets(["local"])
|
||||
|
||||
assert [target.platform for target in targets] == [Platform.LOCAL]
|
||||
|
||||
@@ -252,3 +252,109 @@ async def test_discord_dms_ignore_mention_requirement(adapter, monkeypatch):
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "dm without mention"
|
||||
assert event.source.chat_type == "dm"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_enabled_by_default(adapter, monkeypatch):
|
||||
"""Auto-threading should be enabled by default (DISCORD_AUTO_THREAD defaults to 'true')."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
# Patch _auto_create_thread to return a fake thread
|
||||
fake_thread = FakeThread(channel_id=999, name="auto-thread")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=fake_thread)
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="hello")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_awaited_once()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.chat_type == "thread"
|
||||
assert event.source.thread_id == "999"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_can_be_disabled(adapter, monkeypatch):
|
||||
"""Setting auto_thread to false skips thread creation."""
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="hello")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch):
|
||||
"""Messages in a thread the bot has participated in should not require @mention."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
# Simulate bot having previously participated in thread 456
|
||||
adapter._bot_participated_threads.add("456")
|
||||
|
||||
thread = FakeThread(channel_id=456, name="existing thread")
|
||||
message = make_message(channel=thread, content="follow-up without mention")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "follow-up without mention"
|
||||
assert event.source.chat_type == "thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_unknown_thread_still_requires_mention(adapter, monkeypatch):
|
||||
"""Messages in a thread the bot hasn't participated in should still require @mention."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
# Bot has NOT participated in thread 789
|
||||
thread = FakeThread(channel_id=789, name="some thread")
|
||||
message = make_message(channel=thread, content="hello from unknown thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
|
||||
"""Auto-created threads should be tracked for future mention-free replies."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
fake_thread = FakeThread(channel_id=555, name="auto-thread")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=fake_thread)
|
||||
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123), content="start a thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "555" in adapter._bot_participated_threads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeypatch):
|
||||
"""When the bot processes a message in a thread, it tracks participation."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
thread = FakeThread(channel_id=777, name="manually created thread")
|
||||
message = make_message(channel=thread, content="hello in thread")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "777" in adapter._bot_participated_threads
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_without_reference_when_reply_target_is_system_message():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
ref_msg = SimpleNamespace(id=99)
|
||||
sent_msg = SimpleNamespace(id=1234)
|
||||
send_calls = []
|
||||
|
||||
async def fake_send(*, content, reference=None):
|
||||
send_calls.append({"content": content, "reference": reference})
|
||||
if len(send_calls) == 1:
|
||||
raise RuntimeError(
|
||||
"400 Bad Request (error code: 50035): Invalid Form Body\n"
|
||||
"In message_reference: Cannot reply to a system message"
|
||||
)
|
||||
return sent_msg
|
||||
|
||||
channel = SimpleNamespace(
|
||||
fetch_message=AsyncMock(return_value=ref_msg),
|
||||
send=AsyncMock(side_effect=fake_send),
|
||||
)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
result = await adapter.send("555", "hello", reply_to="99")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "1234"
|
||||
assert channel.fetch_message.await_count == 1
|
||||
assert channel.send.await_count == 2
|
||||
assert send_calls[0]["reference"] is ref_msg
|
||||
assert send_calls[1]["reference"] is None
|
||||
@@ -363,11 +363,37 @@ async def test_auto_thread_creates_thread_and_redirects(adapter, monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_disabled_by_default(adapter, monkeypatch):
|
||||
"""Without DISCORD_AUTO_THREAD, messages stay in the channel."""
|
||||
async def test_auto_thread_enabled_by_default_slash_commands(adapter, monkeypatch):
|
||||
"""Without DISCORD_AUTO_THREAD env var, auto-threading is enabled (default: true)."""
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
fake_thread = _FakeThreadChannel(channel_id=999, name="auto-thread")
|
||||
adapter._auto_create_thread = AsyncMock(return_value=fake_thread)
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def capture_handle(event):
|
||||
captured_events.append(event)
|
||||
|
||||
adapter.handle_message = capture_handle
|
||||
|
||||
msg = _fake_message(_FakeTextChannel())
|
||||
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
adapter._auto_create_thread.assert_awaited_once()
|
||||
assert len(captured_events) == 1
|
||||
assert captured_events[0].source.chat_id == "999" # redirected to thread
|
||||
assert captured_events[0].source.chat_type == "thread"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_can_be_disabled(adapter, monkeypatch):
|
||||
"""Setting DISCORD_AUTO_THREAD=false keeps messages in the channel."""
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
captured_events = []
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _source(chat_id="123456", chat_type="dm"):
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||
adapter = StubAdapter()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
await release.wait()
|
||||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
session_key = build_session_key(event.source)
|
||||
assert session_key in adapter._active_sessions
|
||||
assert adapter._background_tasks
|
||||
|
||||
await adapter.cancel_background_tasks()
|
||||
|
||||
assert adapter._background_tasks == set()
|
||||
assert adapter._active_sessions == {}
|
||||
assert adapter._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._pending_messages = {"session": "pending text"}
|
||||
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
|
||||
adapter = StubAdapter()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
await release.wait()
|
||||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
disconnect_mock = AsyncMock()
|
||||
adapter.disconnect = disconnect_mock
|
||||
|
||||
session_key = build_session_key(event.source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents = {session_key: running_agent}
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
|
||||
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
|
||||
disconnect_mock.assert_awaited_once()
|
||||
assert runner.adapters == {}
|
||||
assert runner._running_agents == {}
|
||||
assert runner._pending_messages == {}
|
||||
assert runner._pending_approvals == {}
|
||||
assert runner._shutdown_event.is_set() is True
|
||||
@@ -0,0 +1,25 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_enrichment_uses_athabasca_upload_guidance_without_stale_r2_warning():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
||||
with patch(
|
||||
"tools.vision_tools.vision_analyze_tool",
|
||||
return_value='{"success": true, "analysis": "A painted serpent warrior."}',
|
||||
):
|
||||
enriched = await runner._enrich_message_with_vision(
|
||||
"caption",
|
||||
["/tmp/test.jpg"],
|
||||
)
|
||||
|
||||
assert "R2 not configured" not in enriched
|
||||
assert "Gateway media URL available for reference" not in enriched
|
||||
assert "POST /api/uploads" in enriched
|
||||
assert "Do not store the local cache path" in enriched
|
||||
assert "caption" in enriched
|
||||
@@ -11,7 +11,7 @@ import asyncio
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
@@ -50,11 +50,11 @@ class TestInterruptKeyConsistency:
|
||||
"""Ensure adapter interrupt methods are queried with session_key, not chat_id."""
|
||||
|
||||
def test_session_key_differs_from_chat_id_for_dm(self):
|
||||
"""Session key for a DM is NOT the same as chat_id."""
|
||||
"""Session key for a DM is namespaced and includes the DM chat_id."""
|
||||
source = _source("123456", "dm")
|
||||
session_key = build_session_key(source)
|
||||
assert session_key != source.chat_id
|
||||
assert session_key == "agent:main:telegram:dm"
|
||||
assert session_key == "agent:main:telegram:dm:123456"
|
||||
|
||||
def test_session_key_differs_from_chat_id_for_group(self):
|
||||
"""Session key for a group chat includes prefix, unlike raw chat_id."""
|
||||
@@ -122,3 +122,29 @@ class TestInterruptKeyConsistency:
|
||||
|
||||
# Interrupt event was set
|
||||
assert adapter._active_sessions[session_key].is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_photo_followup_is_queued_without_interrupt(self):
|
||||
"""Photo follow-ups should queue behind the active run instead of interrupting it."""
|
||||
adapter = StubAdapter()
|
||||
adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
|
||||
|
||||
source = _source("-1001234", "group")
|
||||
session_key = build_session_key(source)
|
||||
interrupt_event = asyncio.Event()
|
||||
adapter._active_sessions[session_key] = interrupt_event
|
||||
|
||||
event = MessageEvent(
|
||||
text="caption",
|
||||
source=source,
|
||||
message_type=MessageType.PHOTO,
|
||||
message_id="2",
|
||||
media_urls=["/tmp/photo-a.jpg"],
|
||||
media_types=["image/jpeg"],
|
||||
)
|
||||
await adapter.handle_message(event)
|
||||
|
||||
queued = adapter._pending_messages[session_key]
|
||||
assert queued is event
|
||||
assert queued.media_urls == ["/tmp/photo-a.jpg"]
|
||||
assert interrupt_event.is_set() is False
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Tests for the /plan gateway slash command."""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.skill_commands import scan_skill_commands
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = SessionEntry(
|
||||
session_key="agent:main:telegram:dm:c1:u1",
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner.session_store.has_any_sessions.return_value = True
|
||||
runner.session_store.append_to_transcript = MagicMock()
|
||||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._show_reasoning = False
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
runner._run_agent = AsyncMock(
|
||||
return_value={
|
||||
"final_response": "planned",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"history_offset": 0,
|
||||
"last_prompt_tokens": 0,
|
||||
}
|
||||
)
|
||||
return runner
|
||||
|
||||
|
||||
def _make_event(text="/plan"):
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
),
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
def _make_plan_skill(skills_dir):
|
||||
skill_dir = skills_dir / "plan"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"""---
|
||||
name: plan
|
||||
description: Plan mode skill.
|
||||
---
|
||||
|
||||
# Plan
|
||||
|
||||
Use the current conversation context when no explicit instruction is provided.
|
||||
Save plans under the active workspace's .hermes/plans directory.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class TestGatewayPlanCommand:
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_command_loads_skill_and_runs_agent(self, monkeypatch, tmp_path):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
runner = _make_runner()
|
||||
event = _make_event("/plan Add OAuth login")
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
monkeypatch.setattr(
|
||||
"agent.model_metadata.get_model_context_length",
|
||||
lambda *_args, **_kwargs: 100_000,
|
||||
)
|
||||
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_plan_skill(tmp_path)
|
||||
scan_skill_commands()
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result == "planned"
|
||||
forwarded = runner._run_agent.call_args.kwargs["message"]
|
||||
assert "Plan mode skill" in forwarded
|
||||
assert "Add OAuth login" in forwarded
|
||||
assert ".hermes/plans" in forwarded
|
||||
assert str(tmp_path / "plans") not in forwarded
|
||||
assert "active workspace/backend cwd" in forwarded
|
||||
assert "Runtime note:" in forwarded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_command_appears_in_help_output_via_skill_listing(self, tmp_path):
|
||||
runner = _make_runner()
|
||||
event = _make_event("/help")
|
||||
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_plan_skill(tmp_path)
|
||||
scan_skill_commands()
|
||||
result = await runner._handle_help_command(event)
|
||||
|
||||
assert "/plan" in result
|
||||
@@ -0,0 +1,97 @@
|
||||
"""Regression tests for /retry replacement semantics."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionStore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_retry_replaces_last_user_turn_in_transcript(tmp_path):
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._db = None
|
||||
store._loaded = True
|
||||
|
||||
session_id = "retry_session"
|
||||
for msg in [
|
||||
{"role": "session_meta", "tools": []},
|
||||
{"role": "user", "content": "first question"},
|
||||
{"role": "assistant", "content": "first answer"},
|
||||
{"role": "user", "content": "retry me"},
|
||||
{"role": "assistant", "content": "old answer"},
|
||||
]:
|
||||
store.append_to_transcript(session_id, msg)
|
||||
|
||||
gw = GatewayRunner.__new__(GatewayRunner)
|
||||
gw.config = config
|
||||
gw.session_store = store
|
||||
|
||||
session_entry = MagicMock(session_id=session_id)
|
||||
session_entry.last_prompt_tokens = 111
|
||||
gw.session_store.get_or_create_session = MagicMock(return_value=session_entry)
|
||||
|
||||
async def fake_handle_message(event):
|
||||
assert event.text == "retry me"
|
||||
transcript_before = store.load_transcript(session_id)
|
||||
assert [m.get("content") for m in transcript_before if m.get("role") == "user"] == [
|
||||
"first question"
|
||||
]
|
||||
store.append_to_transcript(session_id, {"role": "user", "content": event.text})
|
||||
store.append_to_transcript(session_id, {"role": "assistant", "content": "new answer"})
|
||||
return "new answer"
|
||||
|
||||
gw._handle_message = AsyncMock(side_effect=fake_handle_message)
|
||||
|
||||
result = await gw._handle_retry_command(
|
||||
MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
|
||||
)
|
||||
|
||||
assert result == "new answer"
|
||||
transcript_after = store.load_transcript(session_id)
|
||||
assert [m.get("content") for m in transcript_after if m.get("role") == "user"] == [
|
||||
"first question",
|
||||
"retry me",
|
||||
]
|
||||
assert [m.get("content") for m in transcript_after if m.get("role") == "assistant"] == [
|
||||
"first answer",
|
||||
"new answer",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_retry_replays_original_text_not_retry_command(tmp_path):
|
||||
config = MagicMock()
|
||||
config.sessions_dir = tmp_path
|
||||
config.max_context_messages = 20
|
||||
gw = GatewayRunner.__new__(GatewayRunner)
|
||||
gw.config = config
|
||||
gw.session_store = MagicMock()
|
||||
|
||||
session_entry = MagicMock(session_id="test-session")
|
||||
session_entry.last_prompt_tokens = 55
|
||||
gw.session_store.get_or_create_session.return_value = session_entry
|
||||
gw.session_store.load_transcript.return_value = [
|
||||
{"role": "user", "content": "real message"},
|
||||
{"role": "assistant", "content": "answer"},
|
||||
]
|
||||
gw.session_store.rewrite_transcript = MagicMock()
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_handle_message(event):
|
||||
captured["text"] = event.text
|
||||
return "ok"
|
||||
|
||||
gw._handle_message = AsyncMock(side_effect=fake_handle_message)
|
||||
|
||||
await gw._handle_retry_command(
|
||||
MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock())
|
||||
)
|
||||
|
||||
assert captured["text"] == "real message"
|
||||
@@ -0,0 +1,46 @@
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
class _FatalAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="token"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
self._set_fatal_error(
|
||||
"telegram_token_lock",
|
||||
"Another local Hermes gateway is already using this Telegram bot token.",
|
||||
retryable=False,
|
||||
)
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_requests_clean_exit_for_nonretryable_startup_conflict(monkeypatch, tmp_path):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="token")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _FatalAdapter())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
assert runner.should_exit_cleanly is True
|
||||
assert "already using this Telegram bot token" in runner.exit_reason
|
||||
@@ -199,6 +199,57 @@ class TestDiscordSendImageFile:
|
||||
assert result.message_id == "99"
|
||||
mock_channel.send.assert_awaited_once()
|
||||
|
||||
def test_send_document_uploads_file_attachment(self, adapter, tmp_path):
|
||||
"""send_document should upload a native Discord attachment."""
|
||||
pdf = tmp_path / "sample.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n")
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = 100
|
||||
mock_channel.send = AsyncMock(return_value=mock_msg)
|
||||
adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
||||
with patch.object(discord_mod_ref, "File", MagicMock()) as file_cls:
|
||||
result = _run(
|
||||
adapter.send_document(
|
||||
chat_id="67890",
|
||||
file_path=str(pdf),
|
||||
file_name="renamed.pdf",
|
||||
metadata={"thread_id": "123"},
|
||||
)
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.message_id == "100"
|
||||
assert "file" in mock_channel.send.call_args.kwargs
|
||||
assert file_cls.call_args.kwargs["filename"] == "renamed.pdf"
|
||||
|
||||
def test_send_video_uploads_file_attachment(self, adapter, tmp_path):
|
||||
"""send_video should upload a native Discord attachment."""
|
||||
video = tmp_path / "clip.mp4"
|
||||
video.write_bytes(b"\x00\x00\x00\x18ftypmp42" + b"\x00" * 50)
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = 101
|
||||
mock_channel.send = AsyncMock(return_value=mock_msg)
|
||||
adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
||||
with patch.object(discord_mod_ref, "File", MagicMock()) as file_cls:
|
||||
result = _run(
|
||||
adapter.send_video(
|
||||
chat_id="67890",
|
||||
video_path=str(video),
|
||||
metadata={"thread_id": "123"},
|
||||
)
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.message_id == "101"
|
||||
assert "file" in mock_channel.send.call_args.kwargs
|
||||
assert file_cls.call_args.kwargs["filename"] == "clip.mp4"
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
result = _run(
|
||||
adapter.send_image_file(chat_id="67890", image_path="/nonexistent.png")
|
||||
|
||||
@@ -338,7 +338,7 @@ class TestSessionStoreRewriteTranscript:
|
||||
|
||||
class TestWhatsAppDMSessionKeyConsistency:
|
||||
"""Regression: all session-key construction must go through build_session_key
|
||||
so WhatsApp DMs include chat_id while other DMs do not."""
|
||||
so DMs are isolated by chat_id across platforms."""
|
||||
|
||||
@pytest.fixture()
|
||||
def store(self, tmp_path):
|
||||
@@ -369,15 +369,24 @@ class TestWhatsAppDMSessionKeyConsistency:
|
||||
)
|
||||
assert store._generate_session_key(source) == build_session_key(source)
|
||||
|
||||
def test_telegram_dm_omits_chat_id(self):
|
||||
"""Non-WhatsApp DMs should still omit chat_id (single owner DM)."""
|
||||
def test_telegram_dm_includes_chat_id(self):
|
||||
"""Non-WhatsApp DMs should also include chat_id to separate users."""
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="99",
|
||||
chat_type="dm",
|
||||
)
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:telegram:dm"
|
||||
assert key == "agent:main:telegram:dm:99"
|
||||
|
||||
def test_distinct_dm_chat_ids_get_distinct_session_keys(self):
|
||||
"""Different DM chats must not collapse into one shared session."""
|
||||
first = SessionSource(platform=Platform.TELEGRAM, chat_id="99", chat_type="dm")
|
||||
second = SessionSource(platform=Platform.TELEGRAM, chat_id="100", chat_type="dm")
|
||||
|
||||
assert build_session_key(first) == "agent:main:telegram:dm:99"
|
||||
assert build_session_key(second) == "agent:main:telegram:dm:100"
|
||||
assert build_session_key(first) != build_session_key(second)
|
||||
|
||||
def test_discord_group_includes_chat_id(self):
|
||||
"""Group/channel keys include chat_type and chat_id."""
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionContext, SessionSource
|
||||
|
||||
|
||||
def test_set_session_env_includes_thread_id(monkeypatch):
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
runner._set_session_env(context)
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
|
||||
|
||||
def test_clear_session_env_removes_thread_id(monkeypatch):
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
||||
monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "-1001")
|
||||
monkeypatch.setenv("HERMES_SESSION_CHAT_NAME", "Group")
|
||||
monkeypatch.setenv("HERMES_SESSION_THREAD_ID", "17585")
|
||||
|
||||
runner._clear_session_env()
|
||||
|
||||
assert os.getenv("HERMES_SESSION_PLATFORM") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_ID") is None
|
||||
assert os.getenv("HERMES_SESSION_CHAT_NAME") is None
|
||||
assert os.getenv("HERMES_SESSION_THREAD_ID") is None
|
||||
@@ -25,3 +25,77 @@ class TestGatewayPidState:
|
||||
|
||||
assert status.get_running_pid() is None
|
||||
assert not pid_path.exists()
|
||||
|
||||
|
||||
class TestGatewayRuntimeStatus:
|
||||
def test_write_runtime_status_records_platform_failure(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="startup_failed",
|
||||
exit_reason="telegram conflict",
|
||||
platform="telegram",
|
||||
platform_state="fatal",
|
||||
error_code="telegram_polling_conflict",
|
||||
error_message="another poller is active",
|
||||
)
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert payload["gateway_state"] == "startup_failed"
|
||||
assert payload["exit_reason"] == "telegram conflict"
|
||||
assert payload["platforms"]["telegram"]["state"] == "fatal"
|
||||
assert payload["platforms"]["telegram"]["error_code"] == "telegram_polling_conflict"
|
||||
assert payload["platforms"]["telegram"]["error_message"] == "another poller is active"
|
||||
|
||||
|
||||
class TestScopedLocks:
|
||||
def test_acquire_scoped_lock_rejects_live_other_process(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
|
||||
lock_path = tmp_path / "locks" / "telegram-bot-token-2bb80d537b1da3e3.lock"
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
lock_path.write_text(json.dumps({
|
||||
"pid": 99999,
|
||||
"start_time": 123,
|
||||
"kind": "hermes-gateway",
|
||||
}))
|
||||
|
||||
monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
|
||||
|
||||
acquired, existing = status.acquire_scoped_lock("telegram-bot-token", "secret", metadata={"platform": "telegram"})
|
||||
|
||||
assert acquired is False
|
||||
assert existing["pid"] == 99999
|
||||
|
||||
def test_acquire_scoped_lock_replaces_stale_record(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
|
||||
lock_path = tmp_path / "locks" / "telegram-bot-token-2bb80d537b1da3e3.lock"
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
lock_path.write_text(json.dumps({
|
||||
"pid": 99999,
|
||||
"start_time": 123,
|
||||
"kind": "hermes-gateway",
|
||||
}))
|
||||
|
||||
def fake_kill(pid, sig):
|
||||
raise ProcessLookupError
|
||||
|
||||
monkeypatch.setattr(status.os, "kill", fake_kill)
|
||||
|
||||
acquired, existing = status.acquire_scoped_lock("telegram-bot-token", "secret", metadata={"platform": "telegram"})
|
||||
|
||||
assert acquired is True
|
||||
payload = json.loads(lock_path.read_text())
|
||||
assert payload["pid"] == os.getpid()
|
||||
assert payload["metadata"]["platform"] == "telegram"
|
||||
|
||||
def test_release_scoped_lock_only_removes_current_owner(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_GATEWAY_LOCK_DIR", str(tmp_path / "locks"))
|
||||
|
||||
acquired, _ = status.acquire_scoped_lock("telegram-bot-token", "secret", metadata={"platform": "telegram"})
|
||||
assert acquired is True
|
||||
lock_path = tmp_path / "locks" / "telegram-bot-token-2bb80d537b1da3e3.lock"
|
||||
assert lock_path.exists()
|
||||
|
||||
status.release_scoped_lock("telegram-bot-token", "secret")
|
||||
assert not lock_path.exists()
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Gateway STT config tests — honor stt.enabled: false from config.yaml."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.config import GatewayConfig, load_gateway_config
|
||||
|
||||
|
||||
def test_gateway_config_stt_disabled_from_dict_nested():
|
||||
config = GatewayConfig.from_dict({"stt": {"enabled": False}})
|
||||
assert config.stt_enabled is False
|
||||
|
||||
|
||||
def test_load_gateway_config_bridges_stt_enabled_from_config_yaml(tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
yaml.dump({"stt": {"enabled": False}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config.stt_enabled is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_message_with_transcription_skips_when_stt_disabled():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=False)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
side_effect=AssertionError("transcribe_audio should not be called when STT is disabled"),
|
||||
), patch(
|
||||
"tools.transcription_tools.get_stt_model_from_config",
|
||||
return_value=None,
|
||||
):
|
||||
result = await runner._enrich_message_with_transcription(
|
||||
"caption",
|
||||
["/tmp/voice.ogg"],
|
||||
)
|
||||
|
||||
assert "transcription is disabled" in result.lower()
|
||||
assert "caption" in result
|
||||
@@ -0,0 +1,124 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
telegram_mod.constants.ChatType.GROUP = "group"
|
||||
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_rejects_same_host_token_lock(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="secret-token"))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.acquire_scoped_lock",
|
||||
lambda scope, identity, metadata=None: (False, {"pid": 4242}),
|
||||
)
|
||||
|
||||
ok = await adapter.connect()
|
||||
|
||||
assert ok is False
|
||||
assert adapter.fatal_error_code == "telegram_token_lock"
|
||||
assert adapter.has_fatal_error is True
|
||||
assert "already using this Telegram bot token" in adapter.fatal_error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_polling_conflict_stops_polling_and_notifies_handler(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="secret-token"))
|
||||
fatal_handler = AsyncMock()
|
||||
adapter.set_fatal_error_handler(fatal_handler)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.acquire_scoped_lock",
|
||||
lambda scope, identity, metadata=None: (True, None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.release_scoped_lock",
|
||||
lambda scope, identity: None,
|
||||
)
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_start_polling(**kwargs):
|
||||
captured["error_callback"] = kwargs["error_callback"]
|
||||
|
||||
updater = SimpleNamespace(
|
||||
start_polling=AsyncMock(side_effect=fake_start_polling),
|
||||
stop=AsyncMock(),
|
||||
)
|
||||
bot = SimpleNamespace(set_my_commands=AsyncMock())
|
||||
app = SimpleNamespace(
|
||||
bot=bot,
|
||||
updater=updater,
|
||||
add_handler=MagicMock(),
|
||||
initialize=AsyncMock(),
|
||||
start=AsyncMock(),
|
||||
)
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
ok = await adapter.connect()
|
||||
|
||||
assert ok is True
|
||||
assert callable(captured["error_callback"])
|
||||
|
||||
conflict = type("Conflict", (Exception,), {})
|
||||
captured["error_callback"](conflict("Conflict: terminated by other getUpdates request; make sure that only one bot instance is running"))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert adapter.fatal_error_code == "telegram_polling_conflict"
|
||||
assert adapter.has_fatal_error is True
|
||||
updater.stop.assert_awaited()
|
||||
fatal_handler.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_skips_inactive_updater_and_app(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
updater = SimpleNamespace(running=False, stop=AsyncMock())
|
||||
app = SimpleNamespace(
|
||||
updater=updater,
|
||||
running=False,
|
||||
stop=AsyncMock(),
|
||||
shutdown=AsyncMock(),
|
||||
)
|
||||
adapter._app = app
|
||||
|
||||
warning = MagicMock()
|
||||
monkeypatch.setattr("gateway.platforms.telegram.logger.warning", warning)
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
updater.stop.assert_not_awaited()
|
||||
app.stop.assert_not_awaited()
|
||||
app.shutdown.assert_awaited_once()
|
||||
warning.assert_not_called()
|
||||
@@ -12,6 +12,7 @@ import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -81,20 +82,21 @@ def _make_document(
|
||||
return doc
|
||||
|
||||
|
||||
def _make_message(document=None, caption=None):
|
||||
"""Build a mock Telegram Message with the given document."""
|
||||
def _make_message(document=None, caption=None, media_group_id=None, photo=None):
|
||||
"""Build a mock Telegram Message with the given document/photo."""
|
||||
msg = MagicMock()
|
||||
msg.message_id = 42
|
||||
msg.text = caption or ""
|
||||
msg.caption = caption
|
||||
msg.date = None
|
||||
# Media flags — all None except document
|
||||
msg.photo = None
|
||||
# Media flags — all None except explicit payload
|
||||
msg.photo = photo
|
||||
msg.video = None
|
||||
msg.audio = None
|
||||
msg.voice = None
|
||||
msg.sticker = None
|
||||
msg.document = document
|
||||
msg.media_group_id = media_group_id
|
||||
# Chat / user
|
||||
msg.chat = MagicMock()
|
||||
msg.chat.id = 100
|
||||
@@ -165,6 +167,12 @@ class TestDocumentTypeDetection:
|
||||
# TestDocumentDownloadBlock
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_photo(file_obj=None):
|
||||
photo = MagicMock()
|
||||
photo.get_file = AsyncMock(return_value=file_obj or _make_file_obj(b"photo-bytes"))
|
||||
return photo
|
||||
|
||||
|
||||
class TestDocumentDownloadBlock:
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_pdf_is_cached(self, adapter):
|
||||
@@ -339,6 +347,70 @@ class TestDocumentDownloadBlock:
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMediaGroups — media group (album) buffering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMediaGroups:
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_album_photo_burst_is_buffered_and_combined(self, adapter):
|
||||
first_photo = _make_photo(_make_file_obj(b"first"))
|
||||
second_photo = _make_photo(_make_file_obj(b"second"))
|
||||
|
||||
msg1 = _make_message(caption="two images", photo=[first_photo])
|
||||
msg2 = _make_message(photo=[second_photo])
|
||||
|
||||
with patch("gateway.platforms.telegram.cache_image_from_bytes", side_effect=["/tmp/burst-one.jpg", "/tmp/burst-two.jpg"]):
|
||||
await adapter._handle_media_message(_make_update(msg1), MagicMock())
|
||||
await adapter._handle_media_message(_make_update(msg2), MagicMock())
|
||||
assert adapter.handle_message.await_count == 0
|
||||
await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "two images"
|
||||
assert event.media_urls == ["/tmp/burst-one.jpg", "/tmp/burst-two.jpg"]
|
||||
assert len(event.media_types) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_photo_album_is_buffered_and_combined(self, adapter):
|
||||
first_photo = _make_photo(_make_file_obj(b"first"))
|
||||
second_photo = _make_photo(_make_file_obj(b"second"))
|
||||
|
||||
msg1 = _make_message(caption="two images", media_group_id="album-1", photo=[first_photo])
|
||||
msg2 = _make_message(media_group_id="album-1", photo=[second_photo])
|
||||
|
||||
with patch("gateway.platforms.telegram.cache_image_from_bytes", side_effect=["/tmp/one.jpg", "/tmp/two.jpg"]):
|
||||
await adapter._handle_media_message(_make_update(msg1), MagicMock())
|
||||
await adapter._handle_media_message(_make_update(msg2), MagicMock())
|
||||
assert adapter.handle_message.await_count == 0
|
||||
await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.text == "two images"
|
||||
assert event.media_urls == ["/tmp/one.jpg", "/tmp/two.jpg"]
|
||||
assert len(event.media_types) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cancels_pending_media_group_flush(self, adapter):
|
||||
first_photo = _make_photo(_make_file_obj(b"first"))
|
||||
msg = _make_message(caption="two images", media_group_id="album-2", photo=[first_photo])
|
||||
|
||||
with patch("gateway.platforms.telegram.cache_image_from_bytes", return_value="/tmp/one.jpg"):
|
||||
await adapter._handle_media_message(_make_update(msg), MagicMock())
|
||||
|
||||
assert "album-2" in adapter._media_group_events
|
||||
assert "album-2" in adapter._media_group_tasks
|
||||
|
||||
await adapter.disconnect()
|
||||
await asyncio.sleep(adapter.MEDIA_GROUP_WAIT_SECONDS + 0.05)
|
||||
|
||||
assert adapter._media_group_events == {}
|
||||
assert adapter._media_group_tasks == {}
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendDocument — outbound file attachment delivery
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -486,6 +558,51 @@ class TestSendDocument:
|
||||
assert call_kwargs["reply_to_message_id"] == 50
|
||||
|
||||
|
||||
class TestTelegramPhotoBatching:
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_photo_batch_does_not_drop_newer_scheduled_task(self, adapter):
|
||||
old_task = MagicMock()
|
||||
new_task = MagicMock()
|
||||
batch_key = "session:photo-burst"
|
||||
adapter._pending_photo_batch_tasks[batch_key] = new_task
|
||||
adapter._pending_photo_batches[batch_key] = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.PHOTO,
|
||||
source=SimpleNamespace(channel_id="chat-1"),
|
||||
media_urls=["/tmp/a.jpg"],
|
||||
media_types=["image/jpeg"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("gateway.platforms.telegram.asyncio.current_task", return_value=old_task),
|
||||
patch("gateway.platforms.telegram.asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
await adapter._flush_photo_batch(batch_key)
|
||||
|
||||
assert adapter._pending_photo_batch_tasks[batch_key] is new_task
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cancels_pending_photo_batch_tasks(self, adapter):
|
||||
task = MagicMock()
|
||||
task.done.return_value = False
|
||||
adapter._pending_photo_batch_tasks["session:photo-burst"] = task
|
||||
adapter._pending_photo_batches["session:photo-burst"] = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.PHOTO,
|
||||
source=SimpleNamespace(channel_id="chat-1"),
|
||||
)
|
||||
adapter._app = MagicMock()
|
||||
adapter._app.updater.stop = AsyncMock()
|
||||
adapter._app.stop = AsyncMock()
|
||||
adapter._app.shutdown = AsyncMock()
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
task.cancel.assert_called_once()
|
||||
assert adapter._pending_photo_batch_tasks == {}
|
||||
assert adapter._pending_photo_batches == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendVideo — outbound video delivery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
class _PendingAdapter:
|
||||
def __init__(self):
|
||||
self._pending_messages = {}
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
||||
runner.adapters = {Platform.TELEGRAM: _PendingAdapter()}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_does_not_priority_interrupt_photo_followup():
|
||||
runner = _make_runner()
|
||||
source = SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm")
|
||||
session_key = build_session_key(source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents[session_key] = running_agent
|
||||
|
||||
event = MessageEvent(
|
||||
text="caption",
|
||||
message_type=MessageType.PHOTO,
|
||||
source=source,
|
||||
media_urls=["/tmp/photo-a.jpg"],
|
||||
media_types=["image/jpeg"],
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result is None
|
||||
running_agent.interrupt.assert_not_called()
|
||||
assert runner.adapters[Platform.TELEGRAM]._pending_messages[session_key] is event
|
||||
@@ -88,7 +88,7 @@ class TestHandleUpdateCommand:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_hermes_binary(self, tmp_path):
|
||||
"""Returns error when hermes is not on PATH."""
|
||||
"""Returns error when hermes is not on PATH and hermes_cli is not importable."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
@@ -102,10 +102,77 @@ class TestHandleUpdateCommand:
|
||||
|
||||
with patch("gateway.run._hermes_home", tmp_path), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", return_value=None):
|
||||
patch("shutil.which", return_value=None), \
|
||||
patch("importlib.util.find_spec", return_value=None):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
assert "not found on PATH" in result
|
||||
assert "Could not locate" in result
|
||||
assert "hermes update" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_to_sys_executable(self, tmp_path):
|
||||
"""Falls back to sys.executable -m hermes_cli.main when hermes not on PATH."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
(fake_root / ".git").mkdir()
|
||||
(fake_root / "gateway").mkdir()
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
mock_popen = MagicMock()
|
||||
fake_spec = MagicMock()
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", return_value=None), \
|
||||
patch("importlib.util.find_spec", return_value=fake_spec), \
|
||||
patch("subprocess.Popen", mock_popen):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
assert "Starting Hermes update" in result
|
||||
call_args = mock_popen.call_args[0][0]
|
||||
# The update_cmd uses sys.executable -m hermes_cli.main
|
||||
joined = " ".join(call_args) if isinstance(call_args, list) else call_args
|
||||
assert "hermes_cli.main" in joined or "bash" in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_hermes_bin_prefers_which(self, tmp_path):
|
||||
"""_resolve_hermes_bin returns argv parts from shutil.which when available."""
|
||||
from gateway.run import _resolve_hermes_bin
|
||||
|
||||
with patch("shutil.which", return_value="/custom/path/hermes"):
|
||||
result = _resolve_hermes_bin()
|
||||
|
||||
assert result == ["/custom/path/hermes"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_hermes_bin_fallback(self):
|
||||
"""_resolve_hermes_bin falls back to sys.executable argv when which fails."""
|
||||
import sys
|
||||
from gateway.run import _resolve_hermes_bin
|
||||
|
||||
fake_spec = MagicMock()
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("importlib.util.find_spec", return_value=fake_spec):
|
||||
result = _resolve_hermes_bin()
|
||||
|
||||
assert result == [sys.executable, "-m", "hermes_cli.main"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_hermes_bin_returns_none_when_both_fail(self):
|
||||
"""_resolve_hermes_bin returns None when both strategies fail."""
|
||||
from gateway.run import _resolve_hermes_bin
|
||||
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("importlib.util.find_spec", return_value=None):
|
||||
result = _resolve_hermes_bin()
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_writes_pending_marker(self, tmp_path):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for the /voice command and auto voice reply in the gateway."""
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
@@ -206,9 +207,11 @@ class TestAutoVoiceReply:
|
||||
2. gateway _send_voice_reply: fires based on voice_mode setting
|
||||
|
||||
To prevent double audio, _send_voice_reply is skipped when voice input
|
||||
already triggered base adapter auto-TTS (skip_double = is_voice_input).
|
||||
Exception: Discord voice channel — both auto-TTS and Discord play_tts
|
||||
override skip, so the runner must handle it via play_in_voice_channel.
|
||||
already triggered base adapter auto-TTS.
|
||||
|
||||
For Discord voice channels, the base adapter now routes play_tts directly
|
||||
into VC playback, so the runner should still skip voice-input follow-ups to
|
||||
avoid double playback.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
@@ -292,14 +295,14 @@ class TestAutoVoiceReply:
|
||||
|
||||
# -- Discord VC exception: runner must handle --------------------------
|
||||
|
||||
def test_discord_vc_voice_input_runner_fires(self, runner):
|
||||
"""Discord VC + voice input: base play_tts skips (VC override),
|
||||
so runner must handle via play_in_voice_channel."""
|
||||
assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is True
|
||||
def test_discord_vc_voice_input_base_handles(self, runner):
|
||||
"""Discord VC + voice input: base adapter play_tts plays in VC,
|
||||
so runner skips to avoid double playback."""
|
||||
assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is False
|
||||
|
||||
def test_discord_vc_voice_only_runner_fires(self, runner):
|
||||
"""Discord VC + voice_only + voice: runner must handle."""
|
||||
assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is True
|
||||
def test_discord_vc_voice_only_base_handles(self, runner):
|
||||
"""Discord VC + voice_only + voice: base adapter handles."""
|
||||
assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is False
|
||||
|
||||
# -- Edge cases --------------------------------------------------------
|
||||
|
||||
@@ -422,17 +425,23 @@ class TestDiscordPlayTtsSkip:
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_skipped_when_in_vc(self):
|
||||
async def test_play_tts_plays_in_vc_when_connected(self):
|
||||
adapter = self._make_discord_adapter()
|
||||
# Simulate bot in voice channel for guild 111, text channel 123
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
mock_vc.is_playing.return_value = False
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
|
||||
# Mock play_in_voice_channel to avoid actual ffmpeg call
|
||||
async def fake_play(gid, path):
|
||||
return True
|
||||
adapter.play_in_voice_channel = fake_play
|
||||
|
||||
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/test.ogg")
|
||||
# play_tts now plays in VC instead of being a no-op
|
||||
assert result.success is True
|
||||
# send_voice should NOT have been called (no client, would fail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_not_skipped_when_not_in_vc(self):
|
||||
@@ -728,6 +737,24 @@ class TestVoiceChannelCommands:
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "failed" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_missing_voice_dependencies(self, runner):
|
||||
"""Missing PyNaCl/davey should return a user-actionable install hint."""
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.name = "General"
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.join_voice_channel = AsyncMock(
|
||||
side_effect=RuntimeError("PyNaCl library needed in order to use voice")
|
||||
)
|
||||
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
|
||||
event = self._make_discord_event()
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
|
||||
assert "voice dependencies are missing" in result.lower()
|
||||
assert "hermes-agent[messaging]" in result
|
||||
|
||||
# -- _handle_voice_channel_leave --
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -2031,3 +2058,534 @@ class TestDisconnectVoiceCleanup:
|
||||
assert len(adapter._voice_receivers) == 0
|
||||
assert len(adapter._voice_listen_tasks) == 0
|
||||
assert len(adapter._voice_timeout_tasks) == 0
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Discord Voice Channel Flow Tests
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("nacl") is None,
|
||||
reason="PyNaCl not installed",
|
||||
)
|
||||
class TestVoiceReception:
|
||||
"""Audio reception: SSRC mapping, DAVE passthrough, buffer lifecycle."""
|
||||
|
||||
@staticmethod
|
||||
def _make_receiver(allowed_ids=None, members=None, dave=False, bot_id=9999):
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = MagicMock() if dave else None
|
||||
vc._connection.ssrc = bot_id
|
||||
vc._connection.add_socket_listener = MagicMock()
|
||||
vc._connection.remove_socket_listener = MagicMock()
|
||||
vc._connection.hook = None
|
||||
vc.user = SimpleNamespace(id=bot_id)
|
||||
vc.channel = MagicMock()
|
||||
vc.channel.members = members or []
|
||||
receiver = VoiceReceiver(vc, allowed_user_ids=allowed_ids)
|
||||
return receiver
|
||||
|
||||
@staticmethod
|
||||
def _fill_buffer(receiver, ssrc, duration_s=1.0, age_s=3.0):
|
||||
"""Add PCM data to buffer. 48kHz stereo 16-bit = 192000 bytes/sec."""
|
||||
size = int(192000 * duration_s)
|
||||
receiver._buffers[ssrc] = bytearray(b"\x00" * size)
|
||||
receiver._last_packet_time[ssrc] = time.monotonic() - age_s
|
||||
|
||||
# -- Known SSRC (normal flow) --
|
||||
|
||||
def test_known_ssrc_returns_completed(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.map_ssrc(100, 42)
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
assert len(receiver._buffers[100]) == 0 # cleared
|
||||
|
||||
def test_known_ssrc_short_buffer_ignored(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.map_ssrc(100, 42)
|
||||
self._fill_buffer(receiver, 100, duration_s=0.1) # too short
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_known_ssrc_recent_audio_waits(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.map_ssrc(100, 42)
|
||||
self._fill_buffer(receiver, 100, age_s=0.0) # just arrived
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
# -- Unknown SSRC + DAVE passthrough --
|
||||
|
||||
def test_unknown_ssrc_no_automap_no_completed(self):
|
||||
"""Unknown SSRC, no members to infer — buffer cleared, not returned."""
|
||||
receiver = self._make_receiver(dave=True, members=[])
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
assert len(receiver._buffers[100]) == 0
|
||||
|
||||
def test_unknown_ssrc_late_speaking_event(self):
|
||||
"""Audio buffered before SPEAKING → SPEAKING maps → next check returns it."""
|
||||
receiver = self._make_receiver(dave=True)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100, age_s=0.0) # still receiving
|
||||
# No user yet
|
||||
assert receiver.check_silence() == []
|
||||
# SPEAKING event arrives
|
||||
receiver.map_ssrc(100, 42)
|
||||
# Silence kicks in
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
# -- SSRC auto-mapping --
|
||||
|
||||
def test_automap_single_allowed_user(self):
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids={"42"}, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
assert receiver._ssrc_to_user[100] == 42
|
||||
|
||||
def test_automap_multiple_allowed_users_no_map(self):
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
SimpleNamespace(id=43, name="Bob"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids={"42", "43"}, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_automap_no_allowlist_single_member(self):
|
||||
"""No allowed_user_ids → sole non-bot member inferred."""
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids=None, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
def test_automap_unallowed_user_rejected(self):
|
||||
"""User in channel but not in allowed list — not mapped."""
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids={"99"}, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_automap_only_bot_in_channel(self):
|
||||
"""Only bot in channel — no one to map to."""
|
||||
members = [SimpleNamespace(id=9999, name="Bot")]
|
||||
receiver = self._make_receiver(allowed_ids=None, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_automap_persists_across_calls(self):
|
||||
"""Auto-mapped SSRC stays mapped for subsequent checks."""
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = self._make_receiver(allowed_ids={"42"}, members=members)
|
||||
receiver.start()
|
||||
self._fill_buffer(receiver, 100)
|
||||
receiver.check_silence()
|
||||
assert receiver._ssrc_to_user[100] == 42
|
||||
# Second utterance — should use cached mapping
|
||||
self._fill_buffer(receiver, 100)
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
# -- Stale buffer cleanup --
|
||||
|
||||
def test_stale_unknown_buffer_discarded(self):
|
||||
"""Buffer with no user and very old timestamp is discarded."""
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver._buffers[200] = bytearray(b"\x00" * 100)
|
||||
receiver._last_packet_time[200] = time.monotonic() - 10.0
|
||||
receiver.check_silence()
|
||||
assert 200 not in receiver._buffers
|
||||
|
||||
# -- Pause / resume (echo prevention) --
|
||||
|
||||
def test_paused_receiver_ignores_packets(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.pause()
|
||||
receiver._on_packet(b"\x00" * 100)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_resumed_receiver_accepts_packets(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.pause()
|
||||
receiver.resume()
|
||||
assert receiver._paused is False
|
||||
|
||||
# -- _on_packet DAVE passthrough behavior --
|
||||
|
||||
def _make_receiver_with_nacl(self, dave_session=None, mapped_ssrcs=None):
|
||||
"""Create a receiver that can process _on_packet with mocked NaCl + Opus."""
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = dave_session
|
||||
vc._connection.ssrc = 9999
|
||||
vc._connection.add_socket_listener = MagicMock()
|
||||
vc._connection.remove_socket_listener = MagicMock()
|
||||
vc._connection.hook = None
|
||||
vc.user = SimpleNamespace(id=9999)
|
||||
vc.channel = MagicMock()
|
||||
vc.channel.members = []
|
||||
receiver = VoiceReceiver(vc)
|
||||
receiver.start()
|
||||
# Pre-map SSRCs if provided
|
||||
if mapped_ssrcs:
|
||||
for ssrc, uid in mapped_ssrcs.items():
|
||||
receiver.map_ssrc(ssrc, uid)
|
||||
return receiver
|
||||
|
||||
@staticmethod
|
||||
def _build_rtp_packet(ssrc=100, seq=1, timestamp=960):
|
||||
"""Build a minimal valid RTP packet for _on_packet.
|
||||
|
||||
We need: RTP header (12 bytes) + encrypted payload + 4-byte nonce.
|
||||
NaCl decrypt is mocked so payload content doesn't matter.
|
||||
"""
|
||||
import struct
|
||||
# RTP header: version=2, payload_type=0x78, no extension, no CSRC
|
||||
header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc)
|
||||
# Fake encrypted payload (NaCl will be mocked) + 4 byte nonce
|
||||
payload = b"\x00" * 20 + b"\x00\x00\x00\x01"
|
||||
return header + payload
|
||||
|
||||
def _inject_mock_decoder(self, receiver, ssrc):
|
||||
"""Pre-inject a mock Opus decoder for the given SSRC."""
|
||||
mock_decoder = MagicMock()
|
||||
mock_decoder.decode.return_value = b"\x00" * 3840
|
||||
receiver._decoders[ssrc] = mock_decoder
|
||||
return mock_decoder
|
||||
|
||||
def test_on_packet_dave_known_user_decrypt_ok(self):
|
||||
"""Known SSRC + DAVE decrypt success → audio buffered."""
|
||||
dave = MagicMock()
|
||||
dave.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver = self._make_receiver_with_nacl(
|
||||
dave_session=dave, mapped_ssrcs={100: 42}
|
||||
)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
dave.decrypt.assert_called_once()
|
||||
|
||||
def test_on_packet_dave_unknown_ssrc_passthrough(self):
|
||||
"""Unknown SSRC + DAVE → skip DAVE, attempt Opus decode (passthrough)."""
|
||||
dave = MagicMock()
|
||||
receiver = self._make_receiver_with_nacl(dave_session=dave)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
dave.decrypt.assert_not_called()
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_on_packet_dave_unencrypted_error_passthrough(self):
|
||||
"""DAVE decrypt 'Unencrypted' error → use data as-is, don't drop."""
|
||||
dave = MagicMock()
|
||||
dave.decrypt.side_effect = Exception(
|
||||
"Failed to decrypt: DecryptionFailed(UnencryptedWhenPassthroughDisabled)"
|
||||
)
|
||||
receiver = self._make_receiver_with_nacl(
|
||||
dave_session=dave, mapped_ssrcs={100: 42}
|
||||
)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_on_packet_dave_other_error_drops(self):
|
||||
"""DAVE decrypt non-Unencrypted error → packet dropped."""
|
||||
dave = MagicMock()
|
||||
dave.decrypt.side_effect = Exception("KeyRotationFailed")
|
||||
receiver = self._make_receiver_with_nacl(
|
||||
dave_session=dave, mapped_ssrcs={100: 42}
|
||||
)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
def test_on_packet_no_dave_direct_decode(self):
|
||||
"""No DAVE session → decode directly."""
|
||||
receiver = self._make_receiver_with_nacl(dave_session=None)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_on_packet_bot_own_ssrc_ignored(self):
|
||||
"""Bot's own SSRC → dropped (echo prevention)."""
|
||||
receiver = self._make_receiver_with_nacl()
|
||||
with patch("nacl.secret.Aead"):
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=9999))
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_on_packet_multiple_ssrcs_separate_buffers(self):
|
||||
"""Different SSRCs → separate buffers."""
|
||||
receiver = self._make_receiver_with_nacl(dave_session=None)
|
||||
self._inject_mock_decoder(receiver, 100)
|
||||
self._inject_mock_decoder(receiver, 200)
|
||||
|
||||
with patch("nacl.secret.Aead") as mock_aead:
|
||||
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||
receiver._on_packet(self._build_rtp_packet(ssrc=200))
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert 200 in receiver._buffers
|
||||
|
||||
|
||||
class TestVoiceTTSPlayback:
|
||||
"""TTS playback: play_tts in VC, dedup, fallback."""
|
||||
|
||||
@staticmethod
|
||||
def _make_discord_adapter():
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_receivers = {}
|
||||
return adapter
|
||||
|
||||
# -- play_tts behavior --
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_plays_in_vc(self):
|
||||
"""play_tts calls play_in_voice_channel when bot is in VC."""
|
||||
adapter = self._make_discord_adapter()
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
|
||||
played = []
|
||||
async def fake_play(gid, path):
|
||||
played.append((gid, path))
|
||||
return True
|
||||
adapter.play_in_voice_channel = fake_play
|
||||
|
||||
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg")
|
||||
assert result.success is True
|
||||
assert played == [(111, "/tmp/tts.ogg")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_fallback_when_not_in_vc(self):
|
||||
"""play_tts sends as file attachment when bot is not in VC."""
|
||||
adapter = self._make_discord_adapter()
|
||||
from gateway.platforms.base import SendResult
|
||||
adapter.send_voice = AsyncMock(return_value=SendResult(success=False, error="no client"))
|
||||
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg")
|
||||
assert result.success is False
|
||||
adapter.send_voice.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_wrong_channel_no_match(self):
|
||||
"""play_tts doesn't match if chat_id is for a different channel."""
|
||||
adapter = self._make_discord_adapter()
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
|
||||
from gateway.platforms.base import SendResult
|
||||
adapter.send_voice = AsyncMock(return_value=SendResult(success=True))
|
||||
# Different chat_id — shouldn't match VC
|
||||
result = await adapter.play_tts(chat_id="999", audio_path="/tmp/tts.ogg")
|
||||
adapter.send_voice.assert_called_once()
|
||||
|
||||
# -- Runner dedup --
|
||||
|
||||
@staticmethod
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._voice_mode = {}
|
||||
runner.adapters = {}
|
||||
return runner
|
||||
|
||||
def _call_should_reply(self, runner, voice_mode, msg_type, response="Hello", agent_msgs=None):
|
||||
from gateway.platforms.base import MessageType, MessageEvent, SessionSource
|
||||
from gateway.config import Platform
|
||||
runner._voice_mode["ch1"] = voice_mode
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD, chat_id="ch1",
|
||||
user_id="1", user_name="test", chat_type="channel",
|
||||
)
|
||||
event = MessageEvent(source=source, text="test", message_type=msg_type)
|
||||
return runner._should_send_voice_reply(event, response, agent_msgs or [])
|
||||
|
||||
def test_voice_input_runner_skips(self):
|
||||
"""Voice input: runner skips — base adapter handles via play_tts."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "all", MessageType.VOICE) is False
|
||||
|
||||
def test_text_input_voice_all_runner_fires(self):
|
||||
"""Text input + voice_mode=all: runner generates TTS."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "all", MessageType.TEXT) is True
|
||||
|
||||
def test_text_input_voice_off_no_tts(self):
|
||||
"""Text input + voice_mode=off: no TTS."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "off", MessageType.TEXT) is False
|
||||
|
||||
def test_text_input_voice_only_no_tts(self):
|
||||
"""Text input + voice_mode=voice_only: no TTS for text."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "voice_only", MessageType.TEXT) is False
|
||||
|
||||
def test_error_response_no_tts(self):
|
||||
"""Error response: no TTS regardless of voice_mode."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "all", MessageType.TEXT, response="Error: boom") is False
|
||||
|
||||
def test_empty_response_no_tts(self):
|
||||
"""Empty response: no TTS."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
assert self._call_should_reply(runner, "all", MessageType.TEXT, response="") is False
|
||||
|
||||
def test_agent_tts_tool_dedup(self):
|
||||
"""Agent already called text_to_speech tool: runner skips."""
|
||||
from gateway.platforms.base import MessageType
|
||||
runner = self._make_runner()
|
||||
agent_msgs = [{"role": "assistant", "tool_calls": [
|
||||
{"id": "1", "type": "function", "function": {"name": "text_to_speech", "arguments": "{}"}}
|
||||
]}]
|
||||
assert self._call_should_reply(runner, "all", MessageType.TEXT, agent_msgs=agent_msgs) is False
|
||||
|
||||
|
||||
class TestUDPKeepalive:
|
||||
"""UDP keepalive prevents Discord from dropping the voice session."""
|
||||
|
||||
def test_keepalive_interval_is_reasonable(self):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
interval = DiscordAdapter._KEEPALIVE_INTERVAL
|
||||
assert 5 <= interval <= 30, f"Keepalive interval {interval}s should be between 5-30s"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keepalive_sends_silence_frame(self):
|
||||
"""Listen loop sends silence frame via send_packet after interval."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake"
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
|
||||
# Mock VC and receiver
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
mock_conn = MagicMock()
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
mock_vc._connection = mock_conn
|
||||
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
mock_receiver_vc = MagicMock()
|
||||
mock_receiver_vc._connection.secret_key = [0] * 32
|
||||
mock_receiver_vc._connection.dave_session = None
|
||||
mock_receiver_vc._connection.ssrc = 9999
|
||||
mock_receiver_vc._connection.add_socket_listener = MagicMock()
|
||||
mock_receiver_vc._connection.remove_socket_listener = MagicMock()
|
||||
mock_receiver_vc._connection.hook = None
|
||||
receiver = VoiceReceiver(mock_receiver_vc)
|
||||
receiver.start()
|
||||
adapter._voice_receivers[111] = receiver
|
||||
|
||||
# Set keepalive interval very short for test
|
||||
original_interval = DiscordAdapter._KEEPALIVE_INTERVAL
|
||||
DiscordAdapter._KEEPALIVE_INTERVAL = 0.1
|
||||
|
||||
try:
|
||||
# Run listen loop briefly
|
||||
import asyncio
|
||||
loop_task = asyncio.create_task(adapter._voice_listen_loop(111))
|
||||
await asyncio.sleep(0.3)
|
||||
receiver._running = False # stop loop
|
||||
await asyncio.sleep(0.1)
|
||||
loop_task.cancel()
|
||||
try:
|
||||
await loop_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# send_packet should have been called with silence frame
|
||||
mock_conn.send_packet.assert_called_with(b'\xf8\xff\xfe')
|
||||
finally:
|
||||
DiscordAdapter._KEEPALIVE_INTERVAL = original_interval
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import sys
|
||||
|
||||
|
||||
def test_top_level_skills_flag_defaults_to_chat(monkeypatch):
|
||||
import hermes_cli.main as main_mod
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_cmd_chat(args):
|
||||
captured["skills"] = args.skills
|
||||
captured["command"] = args.command
|
||||
|
||||
monkeypatch.setattr(main_mod, "cmd_chat", fake_cmd_chat)
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
["hermes", "-s", "hermes-agent-dev,github-auth"],
|
||||
)
|
||||
|
||||
main_mod.main()
|
||||
|
||||
assert captured == {
|
||||
"skills": ["hermes-agent-dev,github-auth"],
|
||||
"command": None,
|
||||
}
|
||||
|
||||
|
||||
def test_chat_subcommand_accepts_skills_flag(monkeypatch):
|
||||
import hermes_cli.main as main_mod
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_cmd_chat(args):
|
||||
captured["skills"] = args.skills
|
||||
captured["query"] = args.query
|
||||
|
||||
monkeypatch.setattr(main_mod, "cmd_chat", fake_cmd_chat)
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
["hermes", "chat", "-s", "github-auth", "-q", "hello"],
|
||||
)
|
||||
|
||||
main_mod.main()
|
||||
|
||||
assert captured == {
|
||||
"skills": ["github-auth"],
|
||||
"query": "hello",
|
||||
}
|
||||
|
||||
|
||||
def test_continue_worktree_and_skills_flags_work_together(monkeypatch):
|
||||
import hermes_cli.main as main_mod
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_cmd_chat(args):
|
||||
captured["continue_last"] = args.continue_last
|
||||
captured["worktree"] = args.worktree
|
||||
captured["skills"] = args.skills
|
||||
captured["command"] = args.command
|
||||
|
||||
monkeypatch.setattr(main_mod, "cmd_chat", fake_cmd_chat)
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
["hermes", "-c", "-w", "-s", "hermes-agent-dev"],
|
||||
)
|
||||
|
||||
main_mod.main()
|
||||
|
||||
assert captured == {
|
||||
"continue_last": True,
|
||||
"worktree": True,
|
||||
"skills": ["hermes-agent-dev"],
|
||||
"command": "chat",
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Tests for cmd_update — branch fallback when remote branch doesn't exist."""
|
||||
|
||||
import subprocess
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.main import cmd_update, PROJECT_ROOT
|
||||
|
||||
|
||||
def _make_run_side_effect(branch="main", verify_ok=True, commit_count="0"):
|
||||
"""Build a side_effect function for subprocess.run that simulates git commands."""
|
||||
|
||||
def side_effect(cmd, **kwargs):
|
||||
joined = " ".join(str(c) for c in cmd)
|
||||
|
||||
# git rev-parse --abbrev-ref HEAD (get current branch)
|
||||
if "rev-parse" in joined and "--abbrev-ref" in joined:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout=f"{branch}\n", stderr="")
|
||||
|
||||
# git rev-parse --verify origin/{branch} (check remote branch exists)
|
||||
if "rev-parse" in joined and "--verify" in joined:
|
||||
rc = 0 if verify_ok else 128
|
||||
return subprocess.CompletedProcess(cmd, rc, stdout="", stderr="")
|
||||
|
||||
# git rev-list HEAD..origin/{branch} --count
|
||||
if "rev-list" in joined:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout=f"{commit_count}\n", stderr="")
|
||||
|
||||
# Fallback: return a successful CompletedProcess with empty stdout
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="")
|
||||
|
||||
return side_effect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_args():
|
||||
return SimpleNamespace()
|
||||
|
||||
|
||||
class TestCmdUpdateBranchFallback:
|
||||
"""cmd_update falls back to main when current branch has no remote counterpart."""
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_falls_back_to_main_when_branch_not_on_remote(
|
||||
self, mock_run, _mock_which, mock_args, capsys
|
||||
):
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
branch="fix/stoicneko", verify_ok=False, commit_count="3"
|
||||
)
|
||||
|
||||
cmd_update(mock_args)
|
||||
|
||||
commands = [" ".join(str(a) for a in c.args[0]) for c in mock_run.call_args_list]
|
||||
|
||||
# rev-list should use origin/main, not origin/fix/stoicneko
|
||||
rev_list_cmds = [c for c in commands if "rev-list" in c]
|
||||
assert len(rev_list_cmds) == 1
|
||||
assert "origin/main" in rev_list_cmds[0]
|
||||
assert "origin/fix/stoicneko" not in rev_list_cmds[0]
|
||||
|
||||
# pull should use main, not fix/stoicneko
|
||||
pull_cmds = [c for c in commands if "pull" in c]
|
||||
assert len(pull_cmds) == 1
|
||||
assert "main" in pull_cmds[0]
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_uses_current_branch_when_on_remote(
|
||||
self, mock_run, _mock_which, mock_args, capsys
|
||||
):
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
branch="main", verify_ok=True, commit_count="2"
|
||||
)
|
||||
|
||||
cmd_update(mock_args)
|
||||
|
||||
commands = [" ".join(str(a) for a in c.args[0]) for c in mock_run.call_args_list]
|
||||
|
||||
rev_list_cmds = [c for c in commands if "rev-list" in c]
|
||||
assert len(rev_list_cmds) == 1
|
||||
assert "origin/main" in rev_list_cmds[0]
|
||||
|
||||
pull_cmds = [c for c in commands if "pull" in c]
|
||||
assert len(pull_cmds) == 1
|
||||
assert "main" in pull_cmds[0]
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_already_up_to_date(
|
||||
self, mock_run, _mock_which, mock_args, capsys
|
||||
):
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
branch="main", verify_ok=True, commit_count="0"
|
||||
)
|
||||
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Already up to date!" in captured.out
|
||||
|
||||
# Should NOT have called pull
|
||||
commands = [" ".join(str(a) for a in c.args[0]) for c in mock_run.call_args_list]
|
||||
pull_cmds = [c for c in commands if "pull" in c]
|
||||
assert len(pull_cmds) == 0
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Tests for hermes_cli.cron command handling."""
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
import pytest
|
||||
|
||||
from cron.jobs import create_job, get_job, list_jobs
|
||||
from hermes_cli.cron import cron_command
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_cron_dir(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("cron.jobs.CRON_DIR", tmp_path / "cron")
|
||||
monkeypatch.setattr("cron.jobs.JOBS_FILE", tmp_path / "cron" / "jobs.json")
|
||||
monkeypatch.setattr("cron.jobs.OUTPUT_DIR", tmp_path / "cron" / "output")
|
||||
return tmp_path
|
||||
|
||||
|
||||
class TestCronCommandLifecycle:
|
||||
def test_pause_resume_run(self, tmp_cron_dir, capsys):
|
||||
job = create_job(prompt="Check server status", schedule="every 1h")
|
||||
|
||||
cron_command(Namespace(cron_command="pause", job_id=job["id"]))
|
||||
paused = get_job(job["id"])
|
||||
assert paused["state"] == "paused"
|
||||
|
||||
cron_command(Namespace(cron_command="resume", job_id=job["id"]))
|
||||
resumed = get_job(job["id"])
|
||||
assert resumed["state"] == "scheduled"
|
||||
|
||||
cron_command(Namespace(cron_command="run", job_id=job["id"]))
|
||||
triggered = get_job(job["id"])
|
||||
assert triggered["state"] == "scheduled"
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Paused job" in out
|
||||
assert "Resumed job" in out
|
||||
assert "Triggered job" in out
|
||||
|
||||
def test_edit_can_replace_and_clear_skills(self, tmp_cron_dir, capsys):
|
||||
job = create_job(
|
||||
prompt="Combine skill outputs",
|
||||
schedule="every 1h",
|
||||
skill="blogwatcher",
|
||||
)
|
||||
|
||||
cron_command(
|
||||
Namespace(
|
||||
cron_command="edit",
|
||||
job_id=job["id"],
|
||||
schedule="every 2h",
|
||||
prompt="Revised prompt",
|
||||
name="Edited Job",
|
||||
deliver=None,
|
||||
repeat=None,
|
||||
skill=None,
|
||||
skills=["find-nearby", "blogwatcher"],
|
||||
clear_skills=False,
|
||||
)
|
||||
)
|
||||
updated = get_job(job["id"])
|
||||
assert updated["skills"] == ["find-nearby", "blogwatcher"]
|
||||
assert updated["name"] == "Edited Job"
|
||||
assert updated["prompt"] == "Revised prompt"
|
||||
assert updated["schedule_display"] == "every 120m"
|
||||
|
||||
cron_command(
|
||||
Namespace(
|
||||
cron_command="edit",
|
||||
job_id=job["id"],
|
||||
schedule=None,
|
||||
prompt=None,
|
||||
name=None,
|
||||
deliver=None,
|
||||
repeat=None,
|
||||
skill=None,
|
||||
skills=None,
|
||||
clear_skills=True,
|
||||
)
|
||||
)
|
||||
cleared = get_job(job["id"])
|
||||
assert cleared["skills"] == []
|
||||
assert cleared["skill"] is None
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Updated job" in out
|
||||
|
||||
def test_create_with_multiple_skills(self, tmp_cron_dir, capsys):
|
||||
cron_command(
|
||||
Namespace(
|
||||
cron_command="create",
|
||||
schedule="every 1h",
|
||||
prompt="Use both skills",
|
||||
name="Skill combo",
|
||||
deliver=None,
|
||||
repeat=None,
|
||||
skill=None,
|
||||
skills=["blogwatcher", "find-nearby"],
|
||||
)
|
||||
)
|
||||
out = capsys.readouterr().out
|
||||
assert "Created job" in out
|
||||
|
||||
jobs = list_jobs()
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0]["skills"] == ["blogwatcher", "find-nearby"]
|
||||
assert jobs[0]["name"] == "Skill combo"
|
||||
@@ -0,0 +1,70 @@
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
|
||||
|
||||
def test_user_env_overrides_stale_shell_values(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
env_file = home / ".env"
|
||||
env_file.write_text("OPENAI_BASE_URL=https://new.example/v1\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://old.example/v1")
|
||||
|
||||
loaded = load_hermes_dotenv(hermes_home=home)
|
||||
|
||||
assert loaded == [env_file]
|
||||
assert os.getenv("OPENAI_BASE_URL") == "https://new.example/v1"
|
||||
|
||||
|
||||
def test_project_env_overrides_stale_shell_values_when_user_env_missing(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
project_env = tmp_path / ".env"
|
||||
project_env.write_text("OPENAI_BASE_URL=https://project.example/v1\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://old.example/v1")
|
||||
|
||||
loaded = load_hermes_dotenv(hermes_home=home, project_env=project_env)
|
||||
|
||||
assert loaded == [project_env]
|
||||
assert os.getenv("OPENAI_BASE_URL") == "https://project.example/v1"
|
||||
|
||||
|
||||
def test_user_env_takes_precedence_over_project_env(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
user_env = home / ".env"
|
||||
project_env = tmp_path / ".env"
|
||||
user_env.write_text("OPENAI_BASE_URL=https://user.example/v1\n", encoding="utf-8")
|
||||
project_env.write_text("OPENAI_BASE_URL=https://project.example/v1\nOPENAI_API_KEY=project-key\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://old.example/v1")
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
|
||||
loaded = load_hermes_dotenv(hermes_home=home, project_env=project_env)
|
||||
|
||||
assert loaded == [user_env, project_env]
|
||||
assert os.getenv("OPENAI_BASE_URL") == "https://user.example/v1"
|
||||
assert os.getenv("OPENAI_API_KEY") == "project-key"
|
||||
|
||||
|
||||
def test_main_import_applies_user_env_over_shell_values(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
(home / ".env").write_text(
|
||||
"OPENAI_BASE_URL=https://new.example/v1\nHERMES_INFERENCE_PROVIDER=custom\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://old.example/v1")
|
||||
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openrouter")
|
||||
|
||||
sys.modules.pop("hermes_cli.main", None)
|
||||
importlib.import_module("hermes_cli.main")
|
||||
|
||||
assert os.getenv("OPENAI_BASE_URL") == "https://new.example/v1"
|
||||
assert os.getenv("HERMES_INFERENCE_PROVIDER") == "custom"
|
||||
@@ -35,7 +35,7 @@ def test_systemd_status_warns_when_linger_disabled(monkeypatch, tmp_path, capsys
|
||||
unit_path = tmp_path / "hermes-gateway.service"
|
||||
unit_path.write_text("[Unit]\n")
|
||||
|
||||
monkeypatch.setattr(gateway, "get_systemd_unit_path", lambda: unit_path)
|
||||
monkeypatch.setattr(gateway, "get_systemd_unit_path", lambda system=False: unit_path)
|
||||
monkeypatch.setattr(gateway, "get_systemd_linger_status", lambda: (False, ""))
|
||||
|
||||
def fake_run(cmd, capture_output=False, text=False, check=False):
|
||||
@@ -50,7 +50,7 @@ def test_systemd_status_warns_when_linger_disabled(monkeypatch, tmp_path, capsys
|
||||
gateway.systemd_status(deep=False)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Gateway service is running" in out
|
||||
assert "gateway service is running" in out
|
||||
assert "Systemd linger is disabled" in out
|
||||
assert "loginctl enable-linger" in out
|
||||
|
||||
@@ -58,16 +58,17 @@ def test_systemd_status_warns_when_linger_disabled(monkeypatch, tmp_path, capsys
|
||||
def test_systemd_install_checks_linger_status(monkeypatch, tmp_path, capsys):
|
||||
unit_path = tmp_path / "systemd" / "user" / "hermes-gateway.service"
|
||||
|
||||
monkeypatch.setattr(gateway, "get_systemd_unit_path", lambda: unit_path)
|
||||
monkeypatch.setattr(gateway, "get_systemd_linger_status", lambda: (False, ""))
|
||||
monkeypatch.setattr(gateway, "get_systemd_unit_path", lambda system=False: unit_path)
|
||||
|
||||
calls = []
|
||||
helper_calls = []
|
||||
|
||||
def fake_run(cmd, check=False, **kwargs):
|
||||
calls.append((cmd, check))
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(gateway.subprocess, "run", fake_run)
|
||||
monkeypatch.setattr(gateway, "_ensure_linger_enabled", lambda: helper_calls.append(True))
|
||||
|
||||
gateway.systemd_install(force=False)
|
||||
|
||||
@@ -77,6 +78,94 @@ def test_systemd_install_checks_linger_status(monkeypatch, tmp_path, capsys):
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "enable", gateway.SERVICE_NAME],
|
||||
]
|
||||
assert "Service installed and enabled" in out
|
||||
assert "Systemd linger is disabled" in out
|
||||
assert "loginctl enable-linger" in out
|
||||
assert helper_calls == [True]
|
||||
assert "User service installed and enabled" in out
|
||||
|
||||
|
||||
def test_systemd_install_system_scope_skips_linger_and_uses_systemctl(monkeypatch, tmp_path, capsys):
|
||||
unit_path = tmp_path / "etc" / "systemd" / "system" / "hermes-gateway.service"
|
||||
|
||||
monkeypatch.setattr(gateway, "get_systemd_unit_path", lambda system=False: unit_path)
|
||||
monkeypatch.setattr(
|
||||
gateway,
|
||||
"generate_systemd_unit",
|
||||
lambda system=False, run_as_user=None: f"scope={system} user={run_as_user}\n",
|
||||
)
|
||||
monkeypatch.setattr(gateway, "_require_root_for_system_service", lambda action: None)
|
||||
|
||||
calls = []
|
||||
helper_calls = []
|
||||
|
||||
def fake_run(cmd, check=False, **kwargs):
|
||||
calls.append((cmd, check))
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(gateway.subprocess, "run", fake_run)
|
||||
monkeypatch.setattr(gateway, "_ensure_linger_enabled", lambda: helper_calls.append(True))
|
||||
|
||||
gateway.systemd_install(force=False, system=True, run_as_user="alice")
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert unit_path.exists()
|
||||
assert unit_path.read_text(encoding="utf-8") == "scope=True user=alice\n"
|
||||
assert [cmd for cmd, _ in calls] == [
|
||||
["systemctl", "daemon-reload"],
|
||||
["systemctl", "enable", gateway.SERVICE_NAME],
|
||||
]
|
||||
assert helper_calls == []
|
||||
assert "Configured to run as: alice" not in out # generated test unit has no User= line
|
||||
assert "System service installed and enabled" in out
|
||||
|
||||
|
||||
def test_conflicting_systemd_units_warning(monkeypatch, tmp_path, capsys):
|
||||
user_unit = tmp_path / "user" / "hermes-gateway.service"
|
||||
system_unit = tmp_path / "system" / "hermes-gateway.service"
|
||||
user_unit.parent.mkdir(parents=True)
|
||||
system_unit.parent.mkdir(parents=True)
|
||||
user_unit.write_text("[Unit]\n", encoding="utf-8")
|
||||
system_unit.write_text("[Unit]\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway,
|
||||
"get_systemd_unit_path",
|
||||
lambda system=False: system_unit if system else user_unit,
|
||||
)
|
||||
|
||||
gateway.print_systemd_scope_conflict_warning()
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Both user and system gateway services are installed" in out
|
||||
assert "hermes gateway uninstall" in out
|
||||
assert "--system" in out
|
||||
|
||||
|
||||
def test_install_linux_gateway_from_setup_system_choice_without_root_prints_followup(monkeypatch, capsys):
|
||||
monkeypatch.setattr(gateway, "prompt_linux_gateway_install_scope", lambda: "system")
|
||||
monkeypatch.setattr(gateway.os, "geteuid", lambda: 1000)
|
||||
monkeypatch.setattr(gateway, "_default_system_service_user", lambda: "alice")
|
||||
monkeypatch.setattr(gateway, "systemd_install", lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("should not install")))
|
||||
|
||||
scope, did_install = gateway.install_linux_gateway_from_setup(force=False)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert (scope, did_install) == ("system", False)
|
||||
assert "sudo hermes gateway install --system --run-as-user alice" in out
|
||||
assert "sudo hermes gateway start --system" in out
|
||||
|
||||
|
||||
def test_install_linux_gateway_from_setup_system_choice_as_root_installs(monkeypatch):
|
||||
monkeypatch.setattr(gateway, "prompt_linux_gateway_install_scope", lambda: "system")
|
||||
monkeypatch.setattr(gateway.os, "geteuid", lambda: 0)
|
||||
monkeypatch.setattr(gateway, "_default_system_service_user", lambda: "alice")
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
gateway,
|
||||
"systemd_install",
|
||||
lambda force=False, system=False, run_as_user=None: calls.append((force, system, run_as_user)),
|
||||
)
|
||||
|
||||
scope, did_install = gateway.install_linux_gateway_from_setup(force=True)
|
||||
|
||||
assert (scope, did_install) == ("system", True)
|
||||
assert calls == [(True, True, "alice")]
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Tests for gateway linger auto-enable behavior on headless Linux installs."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import hermes_cli.gateway as gateway
|
||||
|
||||
|
||||
class TestEnsureLingerEnabled:
|
||||
def test_linger_already_enabled_via_file(self, monkeypatch, capsys):
|
||||
monkeypatch.setattr(gateway, "is_linux", lambda: True)
|
||||
monkeypatch.setattr("getpass.getuser", lambda: "testuser")
|
||||
monkeypatch.setattr(gateway, "Path", lambda _path: SimpleNamespace(exists=lambda: True))
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(gateway.subprocess, "run", lambda *args, **kwargs: calls.append((args, kwargs)))
|
||||
|
||||
gateway._ensure_linger_enabled()
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Systemd linger is enabled" in out
|
||||
assert calls == []
|
||||
|
||||
def test_status_enabled_skips_enable(self, monkeypatch, capsys):
|
||||
monkeypatch.setattr(gateway, "is_linux", lambda: True)
|
||||
monkeypatch.setattr("getpass.getuser", lambda: "testuser")
|
||||
monkeypatch.setattr(gateway, "Path", lambda _path: SimpleNamespace(exists=lambda: False))
|
||||
monkeypatch.setattr(gateway, "get_systemd_linger_status", lambda: (True, ""))
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(gateway.subprocess, "run", lambda *args, **kwargs: calls.append((args, kwargs)))
|
||||
|
||||
gateway._ensure_linger_enabled()
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Systemd linger is enabled" in out
|
||||
assert calls == []
|
||||
|
||||
def test_loginctl_success_enables_linger(self, monkeypatch, capsys):
|
||||
monkeypatch.setattr(gateway, "is_linux", lambda: True)
|
||||
monkeypatch.setattr("getpass.getuser", lambda: "testuser")
|
||||
monkeypatch.setattr(gateway, "Path", lambda _path: SimpleNamespace(exists=lambda: False))
|
||||
monkeypatch.setattr(gateway, "get_systemd_linger_status", lambda: (False, ""))
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/loginctl")
|
||||
|
||||
run_calls = []
|
||||
|
||||
def fake_run(cmd, capture_output=False, text=False, check=False):
|
||||
run_calls.append((cmd, capture_output, text, check))
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(gateway.subprocess, "run", fake_run)
|
||||
|
||||
gateway._ensure_linger_enabled()
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Enabling linger" in out
|
||||
assert "Linger enabled" in out
|
||||
assert run_calls == [(["loginctl", "enable-linger", "testuser"], True, True, False)]
|
||||
|
||||
def test_missing_loginctl_shows_manual_guidance(self, monkeypatch, capsys):
|
||||
monkeypatch.setattr(gateway, "is_linux", lambda: True)
|
||||
monkeypatch.setattr("getpass.getuser", lambda: "testuser")
|
||||
monkeypatch.setattr(gateway, "Path", lambda _path: SimpleNamespace(exists=lambda: False))
|
||||
monkeypatch.setattr(gateway, "get_systemd_linger_status", lambda: (None, "loginctl not found"))
|
||||
monkeypatch.setattr("shutil.which", lambda name: None)
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(gateway.subprocess, "run", lambda *args, **kwargs: calls.append((args, kwargs)))
|
||||
|
||||
gateway._ensure_linger_enabled()
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "sudo loginctl enable-linger testuser" in out
|
||||
assert "loginctl not found" in out
|
||||
assert calls == []
|
||||
|
||||
def test_loginctl_failure_shows_manual_guidance(self, monkeypatch, capsys):
|
||||
monkeypatch.setattr(gateway, "is_linux", lambda: True)
|
||||
monkeypatch.setattr("getpass.getuser", lambda: "testuser")
|
||||
monkeypatch.setattr(gateway, "Path", lambda _path: SimpleNamespace(exists=lambda: False))
|
||||
monkeypatch.setattr(gateway, "get_systemd_linger_status", lambda: (False, ""))
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/loginctl")
|
||||
monkeypatch.setattr(
|
||||
gateway.subprocess,
|
||||
"run",
|
||||
lambda *args, **kwargs: SimpleNamespace(returncode=1, stdout="", stderr="Permission denied"),
|
||||
)
|
||||
|
||||
gateway._ensure_linger_enabled()
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "sudo loginctl enable-linger testuser" in out
|
||||
assert "Permission denied" in out
|
||||
|
||||
|
||||
def test_systemd_install_calls_linger_helper(monkeypatch, tmp_path, capsys):
|
||||
unit_path = tmp_path / "systemd" / "user" / "hermes-gateway.service"
|
||||
|
||||
monkeypatch.setattr(gateway, "get_systemd_unit_path", lambda system=False: unit_path)
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, check=False, **kwargs):
|
||||
calls.append((cmd, check))
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
helper_calls = []
|
||||
monkeypatch.setattr(gateway.subprocess, "run", fake_run)
|
||||
monkeypatch.setattr(gateway, "_ensure_linger_enabled", lambda: helper_calls.append(True))
|
||||
|
||||
gateway.systemd_install(force=False)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert unit_path.exists()
|
||||
assert [cmd for cmd, _ in calls] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "enable", gateway.SERVICE_NAME],
|
||||
]
|
||||
assert helper_calls == [True]
|
||||
assert "User service installed and enabled" in out
|
||||
@@ -0,0 +1,22 @@
|
||||
from hermes_cli.gateway import _runtime_health_lines
|
||||
|
||||
|
||||
def test_runtime_health_lines_include_fatal_platform_and_startup_reason(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.read_runtime_status",
|
||||
lambda: {
|
||||
"gateway_state": "startup_failed",
|
||||
"exit_reason": "telegram conflict",
|
||||
"platforms": {
|
||||
"telegram": {
|
||||
"state": "fatal",
|
||||
"error_message": "another poller is active",
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
lines = _runtime_health_lines()
|
||||
|
||||
assert "⚠ telegram: another poller is active" in lines
|
||||
assert "⚠ Last startup issue: telegram conflict" in lines
|
||||
@@ -10,8 +10,8 @@ class TestSystemdServiceRefresh:
|
||||
unit_path = tmp_path / "hermes-gateway.service"
|
||||
unit_path.write_text("old unit\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "get_systemd_unit_path", lambda: unit_path)
|
||||
monkeypatch.setattr(gateway_cli, "generate_systemd_unit", lambda: "new unit\n")
|
||||
monkeypatch.setattr(gateway_cli, "get_systemd_unit_path", lambda system=False: unit_path)
|
||||
monkeypatch.setattr(gateway_cli, "generate_systemd_unit", lambda system=False, run_as_user=None: "new unit\n")
|
||||
|
||||
calls = []
|
||||
|
||||
@@ -33,8 +33,8 @@ class TestSystemdServiceRefresh:
|
||||
unit_path = tmp_path / "hermes-gateway.service"
|
||||
unit_path.write_text("old unit\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "get_systemd_unit_path", lambda: unit_path)
|
||||
monkeypatch.setattr(gateway_cli, "generate_systemd_unit", lambda: "new unit\n")
|
||||
monkeypatch.setattr(gateway_cli, "get_systemd_unit_path", lambda system=False: unit_path)
|
||||
monkeypatch.setattr(gateway_cli, "generate_systemd_unit", lambda system=False, run_as_user=None: "new unit\n")
|
||||
|
||||
calls = []
|
||||
|
||||
@@ -60,12 +60,12 @@ class TestGatewayStopCleanup:
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "get_systemd_unit_path", lambda: unit_path)
|
||||
monkeypatch.setattr(gateway_cli, "get_systemd_unit_path", lambda system=False: unit_path)
|
||||
|
||||
service_calls = []
|
||||
kill_calls = []
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "systemd_stop", lambda: service_calls.append("stop"))
|
||||
monkeypatch.setattr(gateway_cli, "systemd_stop", lambda system=False: service_calls.append("stop"))
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"kill_gateway_processes",
|
||||
@@ -76,3 +76,66 @@ class TestGatewayStopCleanup:
|
||||
|
||||
assert service_calls == ["stop"]
|
||||
assert kill_calls == [False]
|
||||
|
||||
|
||||
class TestGatewayServiceDetection:
|
||||
def test_is_service_running_checks_system_scope_when_user_scope_is_inactive(self, monkeypatch):
|
||||
user_unit = SimpleNamespace(exists=lambda: True)
|
||||
system_unit = SimpleNamespace(exists=lambda: True)
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"get_systemd_unit_path",
|
||||
lambda system=False: system_unit if system else user_unit,
|
||||
)
|
||||
|
||||
def fake_run(cmd, capture_output=True, text=True, **kwargs):
|
||||
if cmd == ["systemctl", "--user", "is-active", gateway_cli.SERVICE_NAME]:
|
||||
return SimpleNamespace(returncode=0, stdout="inactive\n", stderr="")
|
||||
if cmd == ["systemctl", "is-active", gateway_cli.SERVICE_NAME]:
|
||||
return SimpleNamespace(returncode=0, stdout="active\n", stderr="")
|
||||
raise AssertionError(f"Unexpected command: {cmd}")
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
|
||||
assert gateway_cli._is_service_running() is True
|
||||
|
||||
|
||||
class TestGatewaySystemServiceRouting:
|
||||
def test_gateway_install_passes_system_flags(self, monkeypatch):
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"systemd_install",
|
||||
lambda force=False, system=False, run_as_user=None: calls.append((force, system, run_as_user)),
|
||||
)
|
||||
|
||||
gateway_cli.gateway_command(
|
||||
SimpleNamespace(gateway_command="install", force=True, system=True, run_as_user="alice")
|
||||
)
|
||||
|
||||
assert calls == [(True, True, "alice")]
|
||||
|
||||
def test_gateway_status_prefers_system_service_when_only_system_unit_exists(self, monkeypatch):
|
||||
user_unit = SimpleNamespace(exists=lambda: False)
|
||||
system_unit = SimpleNamespace(exists=lambda: True)
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"get_systemd_unit_path",
|
||||
lambda system=False: system_unit if system else user_unit,
|
||||
)
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(gateway_cli, "systemd_status", lambda deep=False, system=False: calls.append((deep, system)))
|
||||
|
||||
gateway_cli.gateway_command(SimpleNamespace(gateway_command="status", deep=False, system=False))
|
||||
|
||||
assert calls == [(False, False)]
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Tests for CLI placeholder text in config/setup output."""
|
||||
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.config import config_command, show_config
|
||||
from hermes_cli.setup import _print_setup_summary
|
||||
|
||||
|
||||
def test_config_set_usage_marks_placeholders(capsys):
|
||||
args = Namespace(config_command="set", key=None, value=None)
|
||||
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
config_command(args)
|
||||
|
||||
assert exc.value.code == 1
|
||||
out = capsys.readouterr().out
|
||||
assert "Usage: hermes config set <key> <value>" in out
|
||||
|
||||
|
||||
def test_config_unknown_command_help_marks_placeholders(capsys):
|
||||
args = Namespace(config_command="wat")
|
||||
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
config_command(args)
|
||||
|
||||
assert exc.value.code == 1
|
||||
out = capsys.readouterr().out
|
||||
assert "hermes config set <key> <value> Set a config value" in out
|
||||
|
||||
|
||||
def test_show_config_marks_placeholders(tmp_path, capsys):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
show_config()
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "hermes config set <key> <value>" in out
|
||||
|
||||
|
||||
def test_setup_summary_marks_placeholders(tmp_path, capsys):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
_print_setup_summary({"tts": {"provider": "edge"}}, tmp_path)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "hermes config set <key> <value>" in out
|
||||
@@ -0,0 +1,64 @@
|
||||
import sys
|
||||
|
||||
|
||||
def test_sessions_delete_accepts_unique_id_prefix(monkeypatch, capsys):
|
||||
import hermes_cli.main as main_mod
|
||||
import hermes_state
|
||||
|
||||
captured = {}
|
||||
|
||||
class FakeDB:
|
||||
def resolve_session_id(self, session_id):
|
||||
captured["resolved_from"] = session_id
|
||||
return "20260315_092437_c9a6ff"
|
||||
|
||||
def delete_session(self, session_id):
|
||||
captured["deleted"] = session_id
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
captured["closed"] = True
|
||||
|
||||
monkeypatch.setattr(hermes_state, "SessionDB", lambda: FakeDB())
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
["hermes", "sessions", "delete", "20260315_092437_c9a6", "--yes"],
|
||||
)
|
||||
|
||||
main_mod.main()
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert captured == {
|
||||
"resolved_from": "20260315_092437_c9a6",
|
||||
"deleted": "20260315_092437_c9a6ff",
|
||||
"closed": True,
|
||||
}
|
||||
assert "Deleted session '20260315_092437_c9a6ff'." in output
|
||||
|
||||
|
||||
def test_sessions_delete_reports_not_found_when_prefix_is_unknown(monkeypatch, capsys):
|
||||
import hermes_cli.main as main_mod
|
||||
import hermes_state
|
||||
|
||||
class FakeDB:
|
||||
def resolve_session_id(self, session_id):
|
||||
return None
|
||||
|
||||
def delete_session(self, session_id):
|
||||
raise AssertionError("delete_session should not be called when resolution fails")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(hermes_state, "SessionDB", lambda: FakeDB())
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
["hermes", "sessions", "delete", "missing-prefix", "--yes"],
|
||||
)
|
||||
|
||||
main_mod.main()
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "Session 'missing-prefix' not found." in output
|
||||
@@ -25,7 +25,11 @@ def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider(
|
||||
|
||||
config = load_config()
|
||||
|
||||
prompt_choices = iter([0, 2])
|
||||
# Provider selection always comes first. Depending on available vision
|
||||
# backends, setup may either skip the optional vision step or prompt for
|
||||
# it before the default-model choice. Provide enough selections for both
|
||||
# paths while still ending on "keep current model".
|
||||
prompt_choices = iter([0, 2, 2])
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.setup.prompt_choice",
|
||||
lambda *args, **kwargs: next(prompt_choices),
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from hermes_cli.config import load_config, save_config, save_env_value
|
||||
from hermes_cli.setup import setup_model_provider
|
||||
from hermes_cli.setup import _print_setup_summary, setup_model_provider
|
||||
|
||||
|
||||
def _read_env(home):
|
||||
@@ -39,6 +39,8 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
|
||||
"""Keep-current custom should not fall through to the generic model menu."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
save_env_value("OPENAI_BASE_URL", "https://example.invalid/v1")
|
||||
save_env_value("OPENAI_API_KEY", "custom-key")
|
||||
|
||||
config = load_config()
|
||||
config["model"] = {
|
||||
@@ -50,7 +52,7 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
|
||||
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_prompt_choice(_question, choices, default=0):
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
assert choices[-1] == "Keep current (Custom: https://example.invalid/v1)"
|
||||
@@ -88,13 +90,17 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm
|
||||
captured = {"provider_choices": None, "model_choices": None}
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_prompt_choice(_question, choices, default=0):
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
captured["provider_choices"] = list(choices)
|
||||
assert choices[-1] == "Keep current (Anthropic)"
|
||||
return len(choices) - 1
|
||||
if calls["count"] == 2:
|
||||
assert question == "Configure vision:"
|
||||
assert choices[-1] == "Skip for now"
|
||||
return len(choices) - 1
|
||||
if calls["count"] == 3:
|
||||
captured["model_choices"] = list(choices)
|
||||
return len(choices) - 1 # keep current model
|
||||
raise AssertionError("Unexpected extra prompt_choice call")
|
||||
@@ -105,6 +111,7 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.models.provider_model_ids", lambda provider: [])
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
@@ -113,7 +120,44 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm
|
||||
assert captured["model_choices"] is not None
|
||||
assert captured["model_choices"][0] == "claude-opus-4-6"
|
||||
assert "anthropic/claude-opus-4.6 (recommended)" not in captured["model_choices"]
|
||||
assert calls["count"] == 2
|
||||
assert calls["count"] == 3
|
||||
|
||||
|
||||
def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
config["model"] = {
|
||||
"default": "claude-opus-4-6",
|
||||
"provider": "anthropic",
|
||||
}
|
||||
save_config(config)
|
||||
|
||||
picks = iter([
|
||||
9, # keep current provider
|
||||
1, # configure vision with OpenAI
|
||||
5, # use default gpt-4o-mini vision model
|
||||
4, # keep current Anthropic model
|
||||
])
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: next(picks))
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.setup.prompt",
|
||||
lambda message, *args, **kwargs: "sk-openai" if "OpenAI API key" in message else "",
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.models.provider_model_ids", lambda provider: [])
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
env = _read_env(tmp_path)
|
||||
|
||||
assert env.get("OPENAI_API_KEY") == "sk-openai"
|
||||
assert env.get("OPENAI_BASE_URL") == "https://api.openai.com/v1"
|
||||
assert env.get("AUXILIARY_VISION_MODEL") == "gpt-4o-mini"
|
||||
|
||||
|
||||
def test_setup_switch_custom_to_codex_clears_custom_endpoint_and_updates_config(tmp_path, monkeypatch):
|
||||
@@ -144,7 +188,7 @@ def test_setup_switch_custom_to_codex_clears_custom_endpoint_and_updates_config(
|
||||
"hermes_cli.auth.resolve_codex_runtime_credentials",
|
||||
lambda *args, **kwargs: {
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "codex-access-token",
|
||||
"api_key": "codex-...oken",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
@@ -163,3 +207,36 @@ def test_setup_switch_custom_to_codex_clears_custom_endpoint_and_updates_config(
|
||||
assert reloaded["model"]["provider"] == "openai-codex"
|
||||
assert reloaded["model"]["default"] == "openai/gpt-5.3-codex"
|
||||
assert reloaded["model"]["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
def test_setup_summary_marks_codex_auth_as_vision_available(tmp_path, monkeypatch, capsys):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
(tmp_path / "auth.json").write_text(
|
||||
'{"active_provider":"openai-codex","providers":{"openai-codex":{"tokens":{"access_token": "***", "refresh_token": "***"}}}}'
|
||||
)
|
||||
|
||||
monkeypatch.setattr("shutil.which", lambda _name: None)
|
||||
|
||||
_print_setup_summary(load_config(), tmp_path)
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Vision (image analysis)" in output
|
||||
assert "missing run 'hermes setup' to configure" not in output
|
||||
assert "Mixture of Agents" in output
|
||||
assert "missing OPENROUTER_API_KEY" in output
|
||||
|
||||
|
||||
def test_setup_summary_marks_anthropic_auth_as_vision_available(tmp_path, monkeypatch, capsys):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
monkeypatch.setattr("shutil.which", lambda _name: None)
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: ["anthropic"])
|
||||
|
||||
_print_setup_summary(load_config(), tmp_path)
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Vision (image analysis)" in output
|
||||
assert "missing run 'hermes setup' to configure" not in output
|
||||
|
||||
@@ -3,7 +3,7 @@ from io import StringIO
|
||||
import pytest
|
||||
from rich.console import Console
|
||||
|
||||
from hermes_cli.skills_hub import do_check, do_list, do_update
|
||||
from hermes_cli.skills_hub import do_check, do_list, do_update, handle_skills_slash
|
||||
|
||||
|
||||
class _DummyLockFile:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
def test_cli_skills_install_accepts_yes_alias(monkeypatch):
|
||||
from hermes_cli.main import main
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_skills_command(args):
|
||||
captured["identifier"] = args.identifier
|
||||
captured["force"] = args.force
|
||||
|
||||
monkeypatch.setattr("hermes_cli.skills_hub.skills_command", fake_skills_command)
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
["hermes", "skills", "install", "official/email/agentmail", "--yes"],
|
||||
)
|
||||
|
||||
main()
|
||||
|
||||
assert captured == {
|
||||
"identifier": "official/email/agentmail",
|
||||
"force": True,
|
||||
}
|
||||
@@ -1,6 +1,13 @@
|
||||
"""Tests for hermes_cli.tools_config platform tool persistence."""
|
||||
|
||||
from hermes_cli.tools_config import _get_platform_tools, _platform_toolset_summary
|
||||
from unittest.mock import patch
|
||||
|
||||
from hermes_cli.tools_config import (
|
||||
_get_platform_tools,
|
||||
_platform_toolset_summary,
|
||||
_save_platform_tools,
|
||||
_toolset_has_keys,
|
||||
)
|
||||
|
||||
|
||||
def test_get_platform_tools_uses_default_when_platform_not_configured():
|
||||
@@ -26,3 +33,70 @@ def test_platform_toolset_summary_uses_explicit_platform_list():
|
||||
|
||||
assert set(summary.keys()) == {"cli"}
|
||||
assert summary["cli"] == _get_platform_tools(config, "cli")
|
||||
|
||||
|
||||
def test_toolset_has_keys_for_vision_accepts_codex_auth(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "auth.json").write_text(
|
||||
'{"active_provider":"openai-codex","providers":{"openai-codex":{"tokens":{"access_token": "codex-...oken","refresh_token": "codex-...oken"}}}}'
|
||||
)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("AUXILIARY_VISION_PROVIDER", raising=False)
|
||||
monkeypatch.delenv("CONTEXT_VISION_PROVIDER", raising=False)
|
||||
|
||||
assert _toolset_has_keys("vision") is True
|
||||
|
||||
|
||||
def test_save_platform_tools_preserves_mcp_server_names():
|
||||
"""Ensure MCP server names are preserved when saving platform tools.
|
||||
|
||||
Regression test for https://github.com/NousResearch/hermes-agent/issues/1247
|
||||
"""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"cli": ["web", "terminal", "time", "github", "custom-mcp-server"]
|
||||
}
|
||||
}
|
||||
|
||||
new_selection = {"web", "browser"}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "cli", new_selection)
|
||||
|
||||
saved_toolsets = config["platform_toolsets"]["cli"]
|
||||
|
||||
assert "time" in saved_toolsets
|
||||
assert "github" in saved_toolsets
|
||||
assert "custom-mcp-server" in saved_toolsets
|
||||
assert "web" in saved_toolsets
|
||||
assert "browser" in saved_toolsets
|
||||
assert "terminal" not in saved_toolsets
|
||||
|
||||
|
||||
def test_save_platform_tools_handles_empty_existing_config():
|
||||
"""Saving platform tools works when no existing config exists."""
|
||||
config = {}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "telegram", {"web", "terminal"})
|
||||
|
||||
saved_toolsets = config["platform_toolsets"]["telegram"]
|
||||
assert "web" in saved_toolsets
|
||||
assert "terminal" in saved_toolsets
|
||||
|
||||
|
||||
def test_save_platform_tools_handles_invalid_existing_config():
|
||||
"""Saving platform tools works when existing config is not a list."""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"cli": "invalid-string-value"
|
||||
}
|
||||
}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "cli", {"web"})
|
||||
|
||||
saved_toolsets = config["platform_toolsets"]["cli"]
|
||||
assert "web" in saved_toolsets
|
||||
|
||||
@@ -46,6 +46,20 @@ def test_stash_local_changes_if_needed_returns_specific_stash_commit(monkeypatch
|
||||
assert calls[2][0][-3:] == ["rev-parse", "--verify", "refs/stash"]
|
||||
|
||||
|
||||
def test_resolve_stash_selector_returns_matching_entry(monkeypatch, tmp_path):
|
||||
def fake_run(cmd, **kwargs):
|
||||
assert cmd == ["git", "stash", "list", "--format=%gd %H"]
|
||||
return SimpleNamespace(
|
||||
stdout="stash@{0} def456\nstash@{1} abc123\n",
|
||||
returncode=0,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
|
||||
assert hermes_main._resolve_stash_selector(["git"], tmp_path, "abc123") == "stash@{1}"
|
||||
|
||||
|
||||
|
||||
def test_restore_stashed_changes_prompts_before_applying(monkeypatch, tmp_path, capsys):
|
||||
calls = []
|
||||
|
||||
@@ -53,6 +67,8 @@ def test_restore_stashed_changes_prompts_before_applying(monkeypatch, tmp_path,
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="applied\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "list"]:
|
||||
return SimpleNamespace(stdout="stash@{1} abc123\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "drop"]:
|
||||
return SimpleNamespace(stdout="dropped\n", stderr="", returncode=0)
|
||||
raise AssertionError(f"unexpected command: {cmd}")
|
||||
@@ -64,7 +80,8 @@ def test_restore_stashed_changes_prompts_before_applying(monkeypatch, tmp_path,
|
||||
|
||||
assert restored is True
|
||||
assert calls[0][0] == ["git", "stash", "apply", "abc123"]
|
||||
assert calls[1][0] == ["git", "stash", "drop", "abc123"]
|
||||
assert calls[1][0] == ["git", "stash", "list", "--format=%gd %H"]
|
||||
assert calls[2][0] == ["git", "stash", "drop", "stash@{1}"]
|
||||
out = capsys.readouterr().out
|
||||
assert "Restore local changes now? [Y/n]" in out
|
||||
assert "restored on top of the updated codebase" in out
|
||||
@@ -99,6 +116,8 @@ def test_restore_stashed_changes_applies_without_prompt_when_disabled(monkeypatc
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="applied\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "list"]:
|
||||
return SimpleNamespace(stdout="stash@{0} abc123\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "drop"]:
|
||||
return SimpleNamespace(stdout="dropped\n", stderr="", returncode=0)
|
||||
raise AssertionError(f"unexpected command: {cmd}")
|
||||
@@ -109,9 +128,78 @@ def test_restore_stashed_changes_applies_without_prompt_when_disabled(monkeypatc
|
||||
|
||||
assert restored is True
|
||||
assert calls[0][0] == ["git", "stash", "apply", "abc123"]
|
||||
assert calls[1][0] == ["git", "stash", "list", "--format=%gd %H"]
|
||||
assert calls[2][0] == ["git", "stash", "drop", "stash@{0}"]
|
||||
assert "Restore local changes now?" not in capsys.readouterr().out
|
||||
|
||||
|
||||
|
||||
def test_print_stash_cleanup_guidance_with_selector(capsys):
|
||||
hermes_main._print_stash_cleanup_guidance("abc123", "stash@{2}")
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Check `git status` first" in out
|
||||
assert "git stash list --format='%gd %H %s'" in out
|
||||
assert "git stash drop stash@{2}" in out
|
||||
|
||||
|
||||
|
||||
def test_restore_stashed_changes_keeps_going_when_stash_entry_cannot_be_resolved(monkeypatch, tmp_path, capsys):
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="applied\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "list"]:
|
||||
return SimpleNamespace(stdout="stash@{0} def456\n", stderr="", returncode=0)
|
||||
raise AssertionError(f"unexpected command: {cmd}")
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
|
||||
restored = hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=False)
|
||||
|
||||
assert restored is True
|
||||
assert calls == [
|
||||
(["git", "stash", "apply", "abc123"], {"cwd": tmp_path, "capture_output": True, "text": True}),
|
||||
(["git", "stash", "list", "--format=%gd %H"], {"cwd": tmp_path, "capture_output": True, "text": True, "check": True}),
|
||||
]
|
||||
out = capsys.readouterr().out
|
||||
assert "couldn't find the stash entry to drop" in out
|
||||
assert "stash was left in place" in out
|
||||
assert "Check `git status` first" in out
|
||||
assert "git stash list --format='%gd %H %s'" in out
|
||||
assert "Look for commit abc123" in out
|
||||
|
||||
|
||||
|
||||
def test_restore_stashed_changes_keeps_going_when_drop_fails(monkeypatch, tmp_path, capsys):
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="applied\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "list"]:
|
||||
return SimpleNamespace(stdout="stash@{0} abc123\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "drop"]:
|
||||
return SimpleNamespace(stdout="", stderr="drop failed\n", returncode=1)
|
||||
raise AssertionError(f"unexpected command: {cmd}")
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
|
||||
restored = hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=False)
|
||||
|
||||
assert restored is True
|
||||
assert calls[2][0] == ["git", "stash", "drop", "stash@{0}"]
|
||||
out = capsys.readouterr().out
|
||||
assert "couldn't drop the saved stash entry" in out
|
||||
assert "drop failed" in out
|
||||
assert "Check `git status` first" in out
|
||||
assert "git stash list --format='%gd %H %s'" in out
|
||||
assert "git stash drop stash@{0}" in out
|
||||
|
||||
|
||||
def test_restore_stashed_changes_exits_cleanly_when_apply_fails(monkeypatch, tmp_path, capsys):
|
||||
calls = []
|
||||
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Tests for the update check mechanism in hermes_cli.banner."""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_version_string_no_v_prefix():
|
||||
"""__version__ should be bare semver without a 'v' prefix."""
|
||||
from hermes_cli import __version__
|
||||
assert not __version__.startswith("v"), f"__version__ should not start with 'v', got {__version__!r}"
|
||||
|
||||
|
||||
def test_check_for_updates_uses_cache(tmp_path):
|
||||
"""When cache is fresh, check_for_updates should return cached value without calling git."""
|
||||
from hermes_cli.banner import check_for_updates
|
||||
|
||||
# Create a fake git repo and fresh cache
|
||||
repo_dir = tmp_path / "hermes-agent"
|
||||
repo_dir.mkdir()
|
||||
(repo_dir / ".git").mkdir()
|
||||
|
||||
cache_file = tmp_path / ".update_check"
|
||||
cache_file.write_text(json.dumps({"ts": time.time(), "behind": 3}))
|
||||
|
||||
with patch("hermes_cli.banner.os.getenv", return_value=str(tmp_path)):
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
result = check_for_updates()
|
||||
|
||||
assert result == 3
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
def test_check_for_updates_expired_cache(tmp_path):
|
||||
"""When cache is expired, check_for_updates should call git fetch."""
|
||||
from hermes_cli.banner import check_for_updates
|
||||
|
||||
repo_dir = tmp_path / "hermes-agent"
|
||||
repo_dir.mkdir()
|
||||
(repo_dir / ".git").mkdir()
|
||||
|
||||
# Write an expired cache (timestamp far in the past)
|
||||
cache_file = tmp_path / ".update_check"
|
||||
cache_file.write_text(json.dumps({"ts": 0, "behind": 1}))
|
||||
|
||||
mock_result = MagicMock(returncode=0, stdout="5\n")
|
||||
|
||||
with patch("hermes_cli.banner.os.getenv", return_value=str(tmp_path)):
|
||||
with patch("hermes_cli.banner.subprocess.run", return_value=mock_result) as mock_run:
|
||||
result = check_for_updates()
|
||||
|
||||
assert result == 5
|
||||
assert mock_run.call_count == 2 # git fetch + git rev-list
|
||||
|
||||
|
||||
def test_check_for_updates_no_git_dir(tmp_path):
|
||||
"""Returns None when .git directory doesn't exist anywhere."""
|
||||
import hermes_cli.banner as banner
|
||||
|
||||
# Create a fake banner.py so the fallback path also has no .git
|
||||
fake_banner = tmp_path / "hermes_cli" / "banner.py"
|
||||
fake_banner.parent.mkdir(parents=True, exist_ok=True)
|
||||
fake_banner.touch()
|
||||
|
||||
original = banner.__file__
|
||||
try:
|
||||
banner.__file__ = str(fake_banner)
|
||||
with patch("hermes_cli.banner.os.getenv", return_value=str(tmp_path)):
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
result = banner.check_for_updates()
|
||||
assert result is None
|
||||
mock_run.assert_not_called()
|
||||
finally:
|
||||
banner.__file__ = original
|
||||
|
||||
|
||||
def test_check_for_updates_fallback_to_project_root():
|
||||
"""Dev install: falls back to Path(__file__).parent.parent when HERMES_HOME has no git repo."""
|
||||
import hermes_cli.banner as banner
|
||||
|
||||
project_root = Path(banner.__file__).parent.parent.resolve()
|
||||
if not (project_root / ".git").exists():
|
||||
pytest.skip("Not running from a git checkout")
|
||||
|
||||
# Point HERMES_HOME at a temp dir with no hermes-agent/.git
|
||||
import tempfile
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
with patch("hermes_cli.banner.os.getenv", return_value=td):
|
||||
with patch("hermes_cli.banner.subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="0\n")
|
||||
result = banner.check_for_updates()
|
||||
# Should have fallen back to project root and run git commands
|
||||
assert mock_run.call_count >= 1
|
||||
|
||||
|
||||
def test_prefetch_non_blocking():
|
||||
"""prefetch_update_check() should return immediately without blocking."""
|
||||
import hermes_cli.banner as banner
|
||||
|
||||
# Reset module state
|
||||
banner._update_result = None
|
||||
banner._update_check_done = threading.Event()
|
||||
|
||||
with patch.object(banner, "check_for_updates", return_value=5):
|
||||
start = time.monotonic()
|
||||
banner.prefetch_update_check()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should return almost immediately (well under 1 second)
|
||||
assert elapsed < 1.0
|
||||
|
||||
# Wait for the background thread to finish
|
||||
banner._update_check_done.wait(timeout=5)
|
||||
assert banner._update_result == 5
|
||||
|
||||
|
||||
def test_get_update_result_timeout():
|
||||
"""get_update_result() returns None when check hasn't completed within timeout."""
|
||||
import hermes_cli.banner as banner
|
||||
|
||||
# Reset module state — don't set the event
|
||||
banner._update_result = None
|
||||
banner._update_check_done = threading.Event()
|
||||
|
||||
start = time.monotonic()
|
||||
result = banner.get_update_result(timeout=0.1)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should have waited ~0.1s and returned None
|
||||
assert result is None
|
||||
assert elapsed < 0.5
|
||||
@@ -0,0 +1,611 @@
|
||||
"""Integration tests for Discord voice channel audio flow.
|
||||
|
||||
Uses real NaCl encryption and Opus codec (no mocks for crypto/codec).
|
||||
Does NOT require a Discord connection — tests the VoiceReceiver
|
||||
packet processing pipeline end-to-end.
|
||||
|
||||
Requires: PyNaCl>=1.5.0, discord.py[voice] (opus codec)
|
||||
"""
|
||||
|
||||
import struct
|
||||
import time
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
# Skip entire module if voice deps are missing
|
||||
pytest.importorskip("nacl.secret", reason="PyNaCl required for voice integration tests")
|
||||
discord = pytest.importorskip("discord", reason="discord.py required for voice integration tests")
|
||||
|
||||
import nacl.secret
|
||||
|
||||
try:
|
||||
if not discord.opus.is_loaded():
|
||||
import ctypes.util
|
||||
opus_path = ctypes.util.find_library("opus")
|
||||
if not opus_path:
|
||||
import sys
|
||||
for p in ("/opt/homebrew/lib/libopus.dylib", "/usr/local/lib/libopus.dylib"):
|
||||
import os
|
||||
if os.path.isfile(p):
|
||||
opus_path = p
|
||||
break
|
||||
if opus_path:
|
||||
discord.opus.load_opus(opus_path)
|
||||
OPUS_AVAILABLE = discord.opus.is_loaded()
|
||||
except Exception:
|
||||
OPUS_AVAILABLE = False
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_secret_key():
|
||||
"""Generate a random 32-byte key."""
|
||||
import os
|
||||
return os.urandom(32)
|
||||
|
||||
|
||||
def _build_encrypted_rtp_packet(secret_key, opus_payload, ssrc=100, seq=1, timestamp=960):
|
||||
"""Build a real NaCl-encrypted RTP packet matching Discord's format.
|
||||
|
||||
Format: RTP header (12 bytes) + encrypted(opus) + 4-byte nonce
|
||||
Encryption: aead_xchacha20_poly1305 with RTP header as AAD.
|
||||
"""
|
||||
# RTP header: version=2, payload_type=0x78, no extension, no CSRC
|
||||
header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc)
|
||||
|
||||
# Encrypt with NaCl AEAD
|
||||
box = nacl.secret.Aead(secret_key)
|
||||
nonce_counter = struct.pack(">I", seq) # 4-byte counter as nonce seed
|
||||
# Full 24-byte nonce: counter in first 4 bytes, rest zeros
|
||||
full_nonce = nonce_counter + b'\x00' * 20
|
||||
|
||||
enc_msg = box.encrypt(opus_payload, header, full_nonce)
|
||||
ciphertext = enc_msg.ciphertext # without nonce prefix
|
||||
|
||||
# Discord format: header + ciphertext + 4-byte nonce
|
||||
return header + ciphertext + nonce_counter
|
||||
|
||||
|
||||
def _make_voice_receiver(secret_key, dave_session=None, bot_ssrc=9999,
|
||||
allowed_user_ids=None, members=None):
|
||||
"""Create a VoiceReceiver with real secret key."""
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = list(secret_key)
|
||||
vc._connection.dave_session = dave_session
|
||||
vc._connection.ssrc = bot_ssrc
|
||||
vc._connection.add_socket_listener = MagicMock()
|
||||
vc._connection.remove_socket_listener = MagicMock()
|
||||
vc._connection.hook = None
|
||||
vc.user = SimpleNamespace(id=bot_ssrc)
|
||||
vc.channel = MagicMock()
|
||||
vc.channel.members = members or []
|
||||
receiver = VoiceReceiver(vc, allowed_user_ids=allowed_user_ids)
|
||||
receiver.start()
|
||||
return receiver
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRealNaClDecrypt:
|
||||
"""End-to-end: real NaCl encrypt → _on_packet decrypt → buffer."""
|
||||
|
||||
def test_valid_encrypted_packet_buffered(self):
|
||||
"""Real NaCl encrypted packet → decrypted → buffered."""
|
||||
key = _make_secret_key()
|
||||
opus_silence = b'\xf8\xff\xfe'
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, opus_silence, ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_wrong_key_packet_dropped(self):
|
||||
"""Packet encrypted with wrong key → NaCl fails → not buffered."""
|
||||
real_key = _make_secret_key()
|
||||
wrong_key = _make_secret_key()
|
||||
opus_silence = b'\xf8\xff\xfe'
|
||||
receiver = _make_voice_receiver(real_key)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(wrong_key, opus_silence, ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
def test_bot_ssrc_ignored(self):
|
||||
"""Packet from bot's own SSRC → ignored."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key, bot_ssrc=9999)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=9999)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_multiple_packets_accumulate(self):
|
||||
"""Multiple valid packets → buffer grows."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
for seq in range(1, 6):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert 100 in receiver._buffers
|
||||
buf_size = len(receiver._buffers[100])
|
||||
assert buf_size > 0, "Multiple packets should accumulate in buffer"
|
||||
|
||||
def test_different_ssrcs_separate_buffers(self):
|
||||
"""Packets from different SSRCs → separate buffers."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
for ssrc in [100, 200, 300]:
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=ssrc)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers) == 3
|
||||
for ssrc in [100, 200, 300]:
|
||||
assert ssrc in receiver._buffers
|
||||
|
||||
|
||||
class TestRealNaClWithDAVE:
|
||||
"""NaCl decrypt + DAVE passthrough scenarios with real crypto."""
|
||||
|
||||
def test_dave_unknown_ssrc_passthrough(self):
|
||||
"""DAVE enabled but SSRC unknown → skip DAVE, buffer audio."""
|
||||
key = _make_secret_key()
|
||||
dave = MagicMock() # DAVE session present but SSRC not mapped
|
||||
receiver = _make_voice_receiver(key, dave_session=dave)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
# DAVE decrypt not called (SSRC unknown)
|
||||
dave.decrypt.assert_not_called()
|
||||
# Audio still buffered via passthrough
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_dave_unencrypted_error_passthrough(self):
|
||||
"""DAVE raises 'Unencrypted' → use NaCl-decrypted data as-is."""
|
||||
key = _make_secret_key()
|
||||
dave = MagicMock()
|
||||
dave.decrypt.side_effect = Exception(
|
||||
"DecryptionFailed(UnencryptedWhenPassthroughDisabled)"
|
||||
)
|
||||
receiver = _make_voice_receiver(key, dave_session=dave)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
# DAVE was called but failed → passthrough
|
||||
dave.decrypt.assert_called_once()
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_dave_real_error_drops(self):
|
||||
"""DAVE raises non-Unencrypted error → packet dropped."""
|
||||
key = _make_secret_key()
|
||||
dave = MagicMock()
|
||||
dave.decrypt.side_effect = Exception("KeyRotationFailed")
|
||||
receiver = _make_voice_receiver(key, dave_session=dave)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
|
||||
class TestFullVoiceFlow:
|
||||
"""End-to-end: encrypt → receive → buffer → silence detect → complete."""
|
||||
|
||||
def test_single_utterance_flow(self):
|
||||
"""Encrypt packets → buffer → silence → check_silence returns utterance."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
# Send enough packets to exceed MIN_SPEECH_DURATION (0.5s)
|
||||
# At 48kHz stereo 16-bit, each Opus silence frame decodes to ~3840 bytes
|
||||
# Need 96000 bytes = ~25 frames
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
# Simulate silence by setting last_packet_time in the past
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
user_id, pcm_data = completed[0]
|
||||
assert user_id == 42
|
||||
assert len(pcm_data) > 0
|
||||
|
||||
def test_utterance_with_ssrc_automap(self):
|
||||
"""No SPEAKING event → auto-map sole allowed user → utterance processed."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = _make_voice_receiver(
|
||||
key, allowed_user_ids={"42"}, members=members
|
||||
)
|
||||
# No map_ssrc call — simulating missing SPEAKING event
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42 # auto-mapped to sole allowed user
|
||||
|
||||
def test_pause_blocks_during_playback(self):
|
||||
"""Pause receiver → packets ignored → resume → packets accepted."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
# Pause (echo prevention during TTS playback)
|
||||
receiver.pause()
|
||||
packet = _build_encrypted_rtp_packet(key, b'\xf8\xff\xfe', ssrc=100)
|
||||
receiver._on_packet(packet)
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
# Resume
|
||||
receiver.resume()
|
||||
receiver._on_packet(packet)
|
||||
assert 100 in receiver._buffers
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
def test_corrupted_packet_ignored(self):
|
||||
"""Corrupted/truncated packet → silently ignored."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
|
||||
# Too short
|
||||
receiver._on_packet(b"\x00" * 5)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
# Wrong RTP version
|
||||
bad_header = struct.pack(">BBHII", 0x00, 0x78, 1, 960, 100)
|
||||
receiver._on_packet(bad_header + b"\x00" * 20)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
# Wrong payload type
|
||||
bad_pt = struct.pack(">BBHII", 0x80, 0x00, 1, 960, 100)
|
||||
receiver._on_packet(bad_pt + b"\x00" * 20)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_stop_cleans_everything(self):
|
||||
"""stop() clears all state cleanly."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
for seq in range(1, 10):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
|
||||
receiver.stop()
|
||||
assert receiver._running is False
|
||||
assert len(receiver._buffers) == 0
|
||||
assert len(receiver._ssrc_to_user) == 0
|
||||
assert len(receiver._decoders) == 0
|
||||
|
||||
|
||||
class TestSPEAKINGHook:
|
||||
"""SPEAKING event hook correctly maps SSRC to user_id."""
|
||||
|
||||
def test_speaking_hook_installed(self):
|
||||
"""start() installs speaking hook on connection."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
conn = receiver._vc._connection
|
||||
# hook should be set (wrapped)
|
||||
assert conn.hook is not None
|
||||
|
||||
def test_map_ssrc_via_speaking(self):
|
||||
"""SPEAKING op 5 event maps SSRC to user_id."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(500, 12345)
|
||||
assert receiver._ssrc_to_user[500] == 12345
|
||||
|
||||
def test_map_ssrc_overwrites(self):
|
||||
"""New SPEAKING event for same SSRC overwrites old mapping."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(500, 111)
|
||||
receiver.map_ssrc(500, 222)
|
||||
assert receiver._ssrc_to_user[500] == 222
|
||||
|
||||
def test_speaking_mapped_audio_processed(self):
|
||||
"""After SSRC is mapped, audio from that SSRC gets correct user_id."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
|
||||
class TestAuthFiltering:
|
||||
"""Only allowed users' audio should be processed."""
|
||||
|
||||
def test_allowed_user_audio_processed(self):
|
||||
"""Allowed user's utterance is returned by check_silence."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = _make_voice_receiver(
|
||||
key, allowed_user_ids={"42"}, members=members,
|
||||
)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
def test_automap_rejects_unallowed_user(self):
|
||||
"""Auto-map refuses to map SSRC to user not in allowed list."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = _make_voice_receiver(
|
||||
key, allowed_user_ids={"99"}, # Alice not allowed
|
||||
members=members,
|
||||
)
|
||||
# No map_ssrc — SSRC unknown, auto-map should reject
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_empty_allowlist_allows_all(self):
|
||||
"""Empty allowed_user_ids means no restriction."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
receiver = _make_voice_receiver(
|
||||
key, allowed_user_ids=None, members=members,
|
||||
)
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
# Auto-mapped to sole non-bot member
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
|
||||
class TestRejoinFlow:
|
||||
"""Leave and rejoin: state cleanup and fresh receiver."""
|
||||
|
||||
def test_stop_then_new_receiver_clean_state(self):
|
||||
"""After stop(), a new receiver starts with empty state."""
|
||||
key = _make_secret_key()
|
||||
receiver1 = _make_voice_receiver(key)
|
||||
receiver1.map_ssrc(100, 42)
|
||||
|
||||
for seq in range(1, 10):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver1._on_packet(packet)
|
||||
|
||||
assert len(receiver1._buffers[100]) > 0
|
||||
receiver1.stop()
|
||||
|
||||
# New receiver (simulates rejoin)
|
||||
receiver2 = _make_voice_receiver(key)
|
||||
assert len(receiver2._buffers) == 0
|
||||
assert len(receiver2._ssrc_to_user) == 0
|
||||
assert len(receiver2._decoders) == 0
|
||||
|
||||
def test_rejoin_new_ssrc_works(self):
|
||||
"""After rejoin, user may get new SSRC — still works."""
|
||||
key = _make_secret_key()
|
||||
receiver1 = _make_voice_receiver(key)
|
||||
receiver1.map_ssrc(100, 42) # old SSRC
|
||||
receiver1.stop()
|
||||
|
||||
receiver2 = _make_voice_receiver(key)
|
||||
receiver2.map_ssrc(200, 42) # new SSRC after rejoin
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=200, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver2._on_packet(packet)
|
||||
|
||||
receiver2._last_packet_time[200] = time.monotonic() - 3.0
|
||||
completed = receiver2.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
def test_rejoin_without_speaking_event_automap(self):
|
||||
"""Rejoin without SPEAKING event — auto-map sole allowed user."""
|
||||
key = _make_secret_key()
|
||||
members = [
|
||||
SimpleNamespace(id=9999, name="Bot"),
|
||||
SimpleNamespace(id=42, name="Alice"),
|
||||
]
|
||||
|
||||
# First session
|
||||
receiver1 = _make_voice_receiver(
|
||||
key, allowed_user_ids={"42"}, members=members,
|
||||
)
|
||||
receiver1.stop()
|
||||
|
||||
# Rejoin — new key (Discord may assign new secret_key)
|
||||
new_key = _make_secret_key()
|
||||
receiver2 = _make_voice_receiver(
|
||||
new_key, allowed_user_ids={"42"}, members=members,
|
||||
)
|
||||
# No map_ssrc — simulating missing SPEAKING event
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
new_key, b'\xf8\xff\xfe', ssrc=300, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver2._on_packet(packet)
|
||||
|
||||
receiver2._last_packet_time[300] = time.monotonic() - 3.0
|
||||
completed = receiver2.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
|
||||
|
||||
class TestMultiGuildIsolation:
|
||||
"""Each guild has independent voice state."""
|
||||
|
||||
def test_separate_receivers_independent(self):
|
||||
"""Two receivers (different guilds) don't interfere."""
|
||||
key1 = _make_secret_key()
|
||||
key2 = _make_secret_key()
|
||||
|
||||
receiver1 = _make_voice_receiver(key1, bot_ssrc=1111)
|
||||
receiver2 = _make_voice_receiver(key2, bot_ssrc=2222)
|
||||
|
||||
receiver1.map_ssrc(100, 42)
|
||||
receiver2.map_ssrc(200, 99)
|
||||
|
||||
# Send to receiver1
|
||||
for seq in range(1, 10):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key1, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver1._on_packet(packet)
|
||||
|
||||
# receiver2 should be empty
|
||||
assert len(receiver2._buffers) == 0
|
||||
assert 100 in receiver1._buffers
|
||||
|
||||
def test_stop_one_doesnt_affect_other(self):
|
||||
"""Stopping one receiver doesn't affect another."""
|
||||
key1 = _make_secret_key()
|
||||
key2 = _make_secret_key()
|
||||
|
||||
receiver1 = _make_voice_receiver(key1)
|
||||
receiver2 = _make_voice_receiver(key2)
|
||||
|
||||
receiver1.map_ssrc(100, 42)
|
||||
receiver2.map_ssrc(200, 99)
|
||||
|
||||
for seq in range(1, 10):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key2, b'\xf8\xff\xfe', ssrc=200, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver2._on_packet(packet)
|
||||
|
||||
receiver1.stop()
|
||||
|
||||
# receiver2 still has data
|
||||
assert receiver2._running is True
|
||||
assert len(receiver2._buffers[200]) > 0
|
||||
|
||||
|
||||
class TestEchoPreventionFlow:
|
||||
"""Receiver pause/resume during TTS playback prevents echo."""
|
||||
|
||||
def test_audio_during_pause_ignored(self):
|
||||
"""Audio arriving while paused is completely ignored."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
receiver.pause()
|
||||
|
||||
for seq in range(1, 30):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
def test_audio_after_resume_processed(self):
|
||||
"""Audio arriving after resume is processed normally."""
|
||||
key = _make_secret_key()
|
||||
receiver = _make_voice_receiver(key)
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
# Pause → send packets → resume → send more packets
|
||||
receiver.pause()
|
||||
for seq in range(1, 5):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
assert len(receiver._buffers.get(100, b"")) == 0
|
||||
|
||||
receiver.resume()
|
||||
for seq in range(5, 35):
|
||||
packet = _build_encrypted_rtp_packet(
|
||||
key, b'\xf8\xff\xfe', ssrc=100, seq=seq, timestamp=960 * seq
|
||||
)
|
||||
receiver._on_packet(packet)
|
||||
|
||||
assert len(receiver._buffers[100]) > 0
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
assert completed[0][0] == 42
|
||||
@@ -0,0 +1,203 @@
|
||||
"""Regression tests for Google Workspace OAuth setup.
|
||||
|
||||
These tests cover the headless/manual auth-code flow where the browser step and
|
||||
code exchange happen in separate process invocations.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
SCRIPT_PATH = (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "skills/productivity/google-workspace/scripts/setup.py"
|
||||
)
|
||||
|
||||
|
||||
class FakeCredentials:
|
||||
def __init__(self, payload=None):
|
||||
self._payload = payload or {
|
||||
"token": "access-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
"scopes": ["scope-a"],
|
||||
}
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(self._payload)
|
||||
|
||||
|
||||
class FakeFlow:
|
||||
created = []
|
||||
default_state = "generated-state"
|
||||
default_verifier = "generated-code-verifier"
|
||||
credentials_payload = None
|
||||
fetch_error = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_secrets_file,
|
||||
scopes,
|
||||
*,
|
||||
redirect_uri=None,
|
||||
state=None,
|
||||
code_verifier=None,
|
||||
autogenerate_code_verifier=False,
|
||||
):
|
||||
self.client_secrets_file = client_secrets_file
|
||||
self.scopes = scopes
|
||||
self.redirect_uri = redirect_uri
|
||||
self.state = state
|
||||
self.code_verifier = code_verifier
|
||||
self.autogenerate_code_verifier = autogenerate_code_verifier
|
||||
self.authorization_kwargs = None
|
||||
self.fetch_token_calls = []
|
||||
self.credentials = FakeCredentials(self.credentials_payload)
|
||||
|
||||
if autogenerate_code_verifier and not self.code_verifier:
|
||||
self.code_verifier = self.default_verifier
|
||||
if not self.state:
|
||||
self.state = self.default_state
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
cls.created = []
|
||||
cls.default_state = "generated-state"
|
||||
cls.default_verifier = "generated-code-verifier"
|
||||
cls.credentials_payload = None
|
||||
cls.fetch_error = None
|
||||
|
||||
@classmethod
|
||||
def from_client_secrets_file(cls, client_secrets_file, scopes, **kwargs):
|
||||
inst = cls(client_secrets_file, scopes, **kwargs)
|
||||
cls.created.append(inst)
|
||||
return inst
|
||||
|
||||
def authorization_url(self, **kwargs):
|
||||
self.authorization_kwargs = kwargs
|
||||
return f"https://auth.example/authorize?state={self.state}", self.state
|
||||
|
||||
def fetch_token(self, **kwargs):
|
||||
self.fetch_token_calls.append(kwargs)
|
||||
if self.fetch_error:
|
||||
raise self.fetch_error
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_module(monkeypatch, tmp_path):
|
||||
FakeFlow.reset()
|
||||
|
||||
google_auth_module = types.ModuleType("google_auth_oauthlib")
|
||||
flow_module = types.ModuleType("google_auth_oauthlib.flow")
|
||||
flow_module.Flow = FakeFlow
|
||||
google_auth_module.flow = flow_module
|
||||
monkeypatch.setitem(sys.modules, "google_auth_oauthlib", google_auth_module)
|
||||
monkeypatch.setitem(sys.modules, "google_auth_oauthlib.flow", flow_module)
|
||||
|
||||
spec = importlib.util.spec_from_file_location("google_workspace_setup_test", SCRIPT_PATH)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
monkeypatch.setattr(module, "_ensure_deps", lambda: None)
|
||||
monkeypatch.setattr(module, "CLIENT_SECRET_PATH", tmp_path / "google_client_secret.json")
|
||||
monkeypatch.setattr(module, "TOKEN_PATH", tmp_path / "google_token.json")
|
||||
monkeypatch.setattr(module, "PENDING_AUTH_PATH", tmp_path / "google_oauth_pending.json", raising=False)
|
||||
|
||||
client_secret = {
|
||||
"installed": {
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
}
|
||||
}
|
||||
module.CLIENT_SECRET_PATH.write_text(json.dumps(client_secret))
|
||||
return module
|
||||
|
||||
|
||||
class TestGetAuthUrl:
|
||||
def test_persists_state_and_code_verifier_for_later_exchange(self, setup_module, capsys):
|
||||
setup_module.get_auth_url()
|
||||
|
||||
out = capsys.readouterr().out.strip()
|
||||
assert out == "https://auth.example/authorize?state=generated-state"
|
||||
|
||||
saved = json.loads(setup_module.PENDING_AUTH_PATH.read_text())
|
||||
assert saved["state"] == "generated-state"
|
||||
assert saved["code_verifier"] == "generated-code-verifier"
|
||||
|
||||
flow = FakeFlow.created[-1]
|
||||
assert flow.autogenerate_code_verifier is True
|
||||
assert flow.authorization_kwargs == {"access_type": "offline", "prompt": "consent"}
|
||||
|
||||
|
||||
class TestExchangeAuthCode:
|
||||
def test_reuses_saved_pkce_material_for_plain_code(self, setup_module):
|
||||
setup_module.PENDING_AUTH_PATH.write_text(
|
||||
json.dumps({"state": "saved-state", "code_verifier": "saved-verifier"})
|
||||
)
|
||||
|
||||
setup_module.exchange_auth_code("4/test-auth-code")
|
||||
|
||||
flow = FakeFlow.created[-1]
|
||||
assert flow.state == "saved-state"
|
||||
assert flow.code_verifier == "saved-verifier"
|
||||
assert flow.fetch_token_calls == [{"code": "4/test-auth-code"}]
|
||||
assert json.loads(setup_module.TOKEN_PATH.read_text())["token"] == "access-token"
|
||||
assert not setup_module.PENDING_AUTH_PATH.exists()
|
||||
|
||||
def test_extracts_code_from_redirect_url_and_checks_state(self, setup_module):
|
||||
setup_module.PENDING_AUTH_PATH.write_text(
|
||||
json.dumps({"state": "saved-state", "code_verifier": "saved-verifier"})
|
||||
)
|
||||
|
||||
setup_module.exchange_auth_code(
|
||||
"http://localhost:1/?code=4/extracted-code&state=saved-state&scope=gmail"
|
||||
)
|
||||
|
||||
flow = FakeFlow.created[-1]
|
||||
assert flow.fetch_token_calls == [{"code": "4/extracted-code"}]
|
||||
|
||||
def test_rejects_state_mismatch(self, setup_module, capsys):
|
||||
setup_module.PENDING_AUTH_PATH.write_text(
|
||||
json.dumps({"state": "saved-state", "code_verifier": "saved-verifier"})
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
setup_module.exchange_auth_code(
|
||||
"http://localhost:1/?code=4/extracted-code&state=wrong-state"
|
||||
)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "state mismatch" in out.lower()
|
||||
assert not setup_module.TOKEN_PATH.exists()
|
||||
|
||||
def test_requires_pending_auth_session(self, setup_module, capsys):
|
||||
with pytest.raises(SystemExit):
|
||||
setup_module.exchange_auth_code("4/test-auth-code")
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "run --auth-url first" in out.lower()
|
||||
assert not setup_module.TOKEN_PATH.exists()
|
||||
|
||||
def test_keeps_pending_auth_session_when_exchange_fails(self, setup_module, capsys):
|
||||
setup_module.PENDING_AUTH_PATH.write_text(
|
||||
json.dumps({"state": "saved-state", "code_verifier": "saved-verifier"})
|
||||
)
|
||||
FakeFlow.fetch_error = Exception("invalid_grant: Missing code verifier")
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
setup_module.exchange_auth_code("4/test-auth-code")
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "token exchange failed" in out.lower()
|
||||
assert setup_module.PENDING_AUTH_PATH.exists()
|
||||
assert not setup_module.TOKEN_PATH.exists()
|
||||
@@ -484,3 +484,22 @@ class TestResizeToolPool:
|
||||
"""resize_tool_pool should not raise."""
|
||||
resize_tool_pool(16) # Small pool for testing
|
||||
resize_tool_pool(128) # Restore default
|
||||
|
||||
def test_resize_shuts_down_previous_executor(self, monkeypatch):
|
||||
"""Replacing the global tool executor should shut down the old pool."""
|
||||
import environments.agent_loop as agent_loop_module
|
||||
|
||||
old_executor = MagicMock()
|
||||
new_executor = MagicMock()
|
||||
|
||||
monkeypatch.setattr(agent_loop_module, "_tool_executor", old_executor)
|
||||
monkeypatch.setattr(
|
||||
agent_loop_module.concurrent.futures,
|
||||
"ThreadPoolExecutor",
|
||||
MagicMock(return_value=new_executor),
|
||||
)
|
||||
|
||||
resize_tool_pool(16)
|
||||
|
||||
old_executor.shutdown.assert_called_once_with(wait=False)
|
||||
assert agent_loop_module._tool_executor is new_executor
|
||||
|
||||
@@ -16,6 +16,7 @@ from agent.anthropic_adapter import (
|
||||
build_anthropic_kwargs,
|
||||
convert_messages_to_anthropic,
|
||||
convert_tools_to_anthropic,
|
||||
get_anthropic_token_source,
|
||||
is_claude_code_token_valid,
|
||||
normalize_anthropic_response,
|
||||
normalize_model_name,
|
||||
@@ -87,16 +88,25 @@ class TestReadClaudeCodeCredentials:
|
||||
cred_file.parent.mkdir(parents=True)
|
||||
cred_file.write_text(json.dumps({
|
||||
"claudeAiOauth": {
|
||||
"accessToken": "sk-ant-oat01-test-token",
|
||||
"refreshToken": "sk-ant-ort01-refresh",
|
||||
"accessToken": "sk-ant-oat01-token",
|
||||
"refreshToken": "sk-ant-oat01-refresh",
|
||||
"expiresAt": int(time.time() * 1000) + 3600_000,
|
||||
}
|
||||
}))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
creds = read_claude_code_credentials()
|
||||
assert creds is not None
|
||||
assert creds["accessToken"] == "sk-ant-oat01-test-token"
|
||||
assert creds["refreshToken"] == "sk-ant-ort01-refresh"
|
||||
assert creds["accessToken"] == "sk-ant-oat01-token"
|
||||
assert creds["refreshToken"] == "sk-ant-oat01-refresh"
|
||||
assert creds["source"] == "claude_code_credentials_file"
|
||||
|
||||
def test_ignores_primary_api_key_for_native_anthropic_resolution(self, tmp_path, monkeypatch):
|
||||
claude_json = tmp_path / ".claude.json"
|
||||
claude_json.write_text(json.dumps({"primaryApiKey": "sk-ant-api03-primary"}))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
|
||||
creds = read_claude_code_credentials()
|
||||
assert creds is None
|
||||
|
||||
def test_returns_none_for_missing_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
@@ -139,6 +149,24 @@ class TestResolveAnthropicToken:
|
||||
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-mytoken")
|
||||
assert resolve_anthropic_token() == "sk-ant-oat01-mytoken"
|
||||
|
||||
def test_reports_claude_json_primary_key_source(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
(tmp_path / ".claude.json").write_text(json.dumps({"primaryApiKey": "sk-ant-api03-primary"}))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
|
||||
assert get_anthropic_token_source("sk-ant-api03-primary") == "claude_json_primary_api_key"
|
||||
|
||||
def test_does_not_resolve_primary_api_key_as_native_anthropic_token(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
(tmp_path / ".claude.json").write_text(json.dumps({"primaryApiKey": "sk-ant-api03-primary"}))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
|
||||
assert resolve_anthropic_token() is None
|
||||
|
||||
def test_falls_back_to_api_key_when_no_oauth_sources_exist(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-mykey")
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
@@ -181,6 +209,33 @@ class TestResolveAnthropicToken:
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
assert resolve_anthropic_token() == "cc-auto-token"
|
||||
|
||||
def test_prefers_refreshable_claude_code_credentials_over_static_anthropic_token(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-static-token")
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
cred_file = tmp_path / ".claude" / ".credentials.json"
|
||||
cred_file.parent.mkdir(parents=True)
|
||||
cred_file.write_text(json.dumps({
|
||||
"claudeAiOauth": {
|
||||
"accessToken": "cc-auto-token",
|
||||
"refreshToken": "refresh-token",
|
||||
"expiresAt": int(time.time() * 1000) + 3600_000,
|
||||
}
|
||||
}))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
|
||||
assert resolve_anthropic_token() == "cc-auto-token"
|
||||
|
||||
def test_keeps_static_anthropic_token_when_only_non_refreshable_claude_key_exists(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-static-token")
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
claude_json = tmp_path / ".claude.json"
|
||||
claude_json.write_text(json.dumps({"primaryApiKey": "sk-ant-api03-managed-key"}))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
|
||||
assert resolve_anthropic_token() == "sk-ant-oat01-static-token"
|
||||
|
||||
|
||||
class TestRefreshOauthToken:
|
||||
def test_returns_none_without_refresh_token(self):
|
||||
@@ -279,6 +334,27 @@ class TestResolveWithRefresh:
|
||||
|
||||
assert result == "refreshed-token"
|
||||
|
||||
def test_static_env_oauth_token_does_not_block_refreshable_claude_creds(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-expired-env-token")
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
|
||||
cred_file = tmp_path / ".claude" / ".credentials.json"
|
||||
cred_file.parent.mkdir(parents=True)
|
||||
cred_file.write_text(json.dumps({
|
||||
"claudeAiOauth": {
|
||||
"accessToken": "expired-claude-creds-token",
|
||||
"refreshToken": "valid-refresh",
|
||||
"expiresAt": int(time.time() * 1000) - 3600_000,
|
||||
}
|
||||
}))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
|
||||
with patch("agent.anthropic_adapter._refresh_oauth_token", return_value="refreshed-token"):
|
||||
result = resolve_anthropic_token()
|
||||
|
||||
assert result == "refreshed-token"
|
||||
|
||||
|
||||
class TestRunOauthSetupToken:
|
||||
def test_raises_when_claude_not_installed(self, monkeypatch):
|
||||
@@ -419,6 +495,59 @@ class TestConvertMessages:
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_converts_user_image_url_blocks_to_anthropic_image_blocks(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you see this?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you see this?"},
|
||||
{"type": "image", "source": {"type": "url", "url": "https://example.com/cat.png"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def test_converts_data_url_image_blocks_to_base64_anthropic_image_blocks(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "input_text", "text": "What is in this screenshot?"},
|
||||
{"type": "input_image", "image_url": "data:image/png;base64,AAAA"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is in this screenshot?"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "AAAA",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def test_converts_tool_calls(self):
|
||||
messages = [
|
||||
{
|
||||
@@ -519,6 +648,56 @@ class TestConvertMessages:
|
||||
assert tool_block["content"] == "result"
|
||||
assert tool_block["cache_control"] == {"type": "ephemeral"}
|
||||
|
||||
def test_converts_data_url_image_to_anthropic_image_block(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this image"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/png;base64,ZmFrZQ=="},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
blocks = result[0]["content"]
|
||||
assert blocks[0] == {"type": "text", "text": "Describe this image"}
|
||||
assert blocks[1] == {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "ZmFrZQ==",
|
||||
},
|
||||
}
|
||||
|
||||
def test_converts_remote_image_url_to_anthropic_image_block(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this image"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/cat.png"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
blocks = result[0]["content"]
|
||||
assert blocks[1] == {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": "https://example.com/cat.png",
|
||||
},
|
||||
}
|
||||
|
||||
def test_empty_cached_assistant_tool_turn_converts_without_empty_text_block(self):
|
||||
messages = apply_anthropic_cache_control([
|
||||
{"role": "system", "content": "System prompt"},
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Tests for Anthropic OAuth setup flow behavior."""
|
||||
|
||||
from hermes_cli.config import load_env, save_env_value
|
||||
|
||||
|
||||
def test_run_anthropic_oauth_flow_prefers_claude_code_credentials(tmp_path, monkeypatch, capsys):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.run_oauth_setup_token",
|
||||
lambda: "sk-ant-oat01-from-claude-setup",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_claude_code_credentials",
|
||||
lambda: {
|
||||
"accessToken": "cc-access-token",
|
||||
"refreshToken": "cc-refresh-token",
|
||||
"expiresAt": 9999999999999,
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.is_claude_code_token_valid",
|
||||
lambda creds: True,
|
||||
)
|
||||
|
||||
from hermes_cli.main import _run_anthropic_oauth_flow
|
||||
|
||||
save_env_value("ANTHROPIC_TOKEN", "stale-env-token")
|
||||
assert _run_anthropic_oauth_flow(save_env_value) is True
|
||||
|
||||
env_vars = load_env()
|
||||
assert env_vars["ANTHROPIC_TOKEN"] == ""
|
||||
assert env_vars["ANTHROPIC_API_KEY"] == ""
|
||||
output = capsys.readouterr().out
|
||||
assert "Claude Code credentials linked" in output
|
||||
|
||||
|
||||
def test_run_anthropic_oauth_flow_manual_token_still_persists(tmp_path, monkeypatch, capsys):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setattr("agent.anthropic_adapter.run_oauth_setup_token", lambda: None)
|
||||
monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)
|
||||
monkeypatch.setattr("agent.anthropic_adapter.is_claude_code_token_valid", lambda creds: False)
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": "sk-ant-oat01-manual-token")
|
||||
|
||||
from hermes_cli.main import _run_anthropic_oauth_flow
|
||||
|
||||
assert _run_anthropic_oauth_flow(save_env_value) is True
|
||||
|
||||
env_vars = load_env()
|
||||
assert env_vars["ANTHROPIC_TOKEN"] == "sk-ant-oat01-manual-token"
|
||||
output = capsys.readouterr().out
|
||||
assert "Setup-token saved" in output
|
||||
@@ -17,6 +17,21 @@ def test_save_anthropic_oauth_token_uses_token_slot_and_clears_api_key(tmp_path,
|
||||
assert env_vars["ANTHROPIC_API_KEY"] == ""
|
||||
|
||||
|
||||
def test_use_anthropic_claude_code_credentials_clears_env_slots(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from hermes_cli.config import save_anthropic_oauth_token, use_anthropic_claude_code_credentials
|
||||
|
||||
save_anthropic_oauth_token("sk-ant-oat01-token")
|
||||
use_anthropic_claude_code_credentials()
|
||||
|
||||
env_vars = load_env()
|
||||
assert env_vars["ANTHROPIC_TOKEN"] == ""
|
||||
assert env_vars["ANTHROPIC_API_KEY"] == ""
|
||||
|
||||
|
||||
def test_save_anthropic_api_key_uses_api_key_slot_and_clears_token(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
@@ -24,8 +39,8 @@ def test_save_anthropic_api_key_uses_api_key_slot_and_clears_token(tmp_path, mon
|
||||
|
||||
from hermes_cli.config import save_anthropic_api_key
|
||||
|
||||
save_anthropic_api_key("sk-ant-api03-test-key")
|
||||
save_anthropic_api_key("sk-ant-api03-key")
|
||||
|
||||
env_vars = load_env()
|
||||
assert env_vars["ANTHROPIC_API_KEY"] == "sk-ant-api03-test-key"
|
||||
assert env_vars["ANTHROPIC_API_KEY"] == "sk-ant-api03-key"
|
||||
assert env_vars["ANTHROPIC_TOKEN"] == ""
|
||||
|
||||
@@ -426,3 +426,30 @@ class TestKimiCodeCredentialAutoDetect:
|
||||
monkeypatch.setenv("GLM_API_KEY", "sk-kimi-looks-like-kimi-but-isnt")
|
||||
creds = resolve_api_key_provider_credentials("zai")
|
||||
assert creds["base_url"] == "https://api.z.ai/api/paas/v4"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kimi / Moonshot model list isolation tests
|
||||
# =============================================================================
|
||||
|
||||
class TestKimiMoonshotModelListIsolation:
|
||||
"""Moonshot (legacy) users must not see Coding Plan-only models."""
|
||||
|
||||
def test_moonshot_list_excludes_coding_plan_only_models(self):
|
||||
from hermes_cli.main import _PROVIDER_MODELS
|
||||
moonshot_models = _PROVIDER_MODELS["moonshot"]
|
||||
coding_plan_only = {"kimi-for-coding", "kimi-k2-thinking-turbo"}
|
||||
leaked = set(moonshot_models) & coding_plan_only
|
||||
assert not leaked, f"Moonshot list contains Coding Plan-only models: {leaked}"
|
||||
|
||||
def test_moonshot_list_contains_shared_models(self):
|
||||
from hermes_cli.main import _PROVIDER_MODELS
|
||||
moonshot_models = _PROVIDER_MODELS["moonshot"]
|
||||
assert "kimi-k2.5" in moonshot_models
|
||||
assert "kimi-k2-thinking" in moonshot_models
|
||||
|
||||
def test_coding_plan_list_contains_plan_specific_models(self):
|
||||
from hermes_cli.main import _PROVIDER_MODELS
|
||||
coding_models = _PROVIDER_MODELS["kimi-coding"]
|
||||
assert "kimi-for-coding" in coding_models
|
||||
assert "kimi-k2-thinking-turbo" in coding_models
|
||||
|
||||
@@ -25,7 +25,9 @@ def _run_auxiliary_bridge(config_dict, monkeypatch):
|
||||
# Clear env vars
|
||||
for key in (
|
||||
"AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL",
|
||||
"AUXILIARY_VISION_BASE_URL", "AUXILIARY_VISION_API_KEY",
|
||||
"AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL",
|
||||
"AUXILIARY_WEB_EXTRACT_BASE_URL", "AUXILIARY_WEB_EXTRACT_API_KEY",
|
||||
"CONTEXT_COMPRESSION_PROVIDER", "CONTEXT_COMPRESSION_MODEL",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
@@ -47,19 +49,35 @@ def _run_auxiliary_bridge(config_dict, monkeypatch):
|
||||
auxiliary_cfg = config_dict.get("auxiliary", {})
|
||||
if auxiliary_cfg and isinstance(auxiliary_cfg, dict):
|
||||
aux_task_env = {
|
||||
"vision": ("AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL"),
|
||||
"web_extract": ("AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL"),
|
||||
"vision": {
|
||||
"provider": "AUXILIARY_VISION_PROVIDER",
|
||||
"model": "AUXILIARY_VISION_MODEL",
|
||||
"base_url": "AUXILIARY_VISION_BASE_URL",
|
||||
"api_key": "AUXILIARY_VISION_API_KEY",
|
||||
},
|
||||
"web_extract": {
|
||||
"provider": "AUXILIARY_WEB_EXTRACT_PROVIDER",
|
||||
"model": "AUXILIARY_WEB_EXTRACT_MODEL",
|
||||
"base_url": "AUXILIARY_WEB_EXTRACT_BASE_URL",
|
||||
"api_key": "AUXILIARY_WEB_EXTRACT_API_KEY",
|
||||
},
|
||||
}
|
||||
for task_key, (prov_env, model_env) in aux_task_env.items():
|
||||
for task_key, env_map in aux_task_env.items():
|
||||
task_cfg = auxiliary_cfg.get(task_key, {})
|
||||
if not isinstance(task_cfg, dict):
|
||||
continue
|
||||
prov = str(task_cfg.get("provider", "")).strip()
|
||||
model = str(task_cfg.get("model", "")).strip()
|
||||
base_url = str(task_cfg.get("base_url", "")).strip()
|
||||
api_key = str(task_cfg.get("api_key", "")).strip()
|
||||
if prov and prov != "auto":
|
||||
os.environ[prov_env] = prov
|
||||
os.environ[env_map["provider"]] = prov
|
||||
if model:
|
||||
os.environ[model_env] = model
|
||||
os.environ[env_map["model"]] = model
|
||||
if base_url:
|
||||
os.environ[env_map["base_url"]] = base_url
|
||||
if api_key:
|
||||
os.environ[env_map["api_key"]] = api_key
|
||||
|
||||
|
||||
# ── Config bridging tests ────────────────────────────────────────────────────
|
||||
@@ -101,6 +119,21 @@ class TestAuxiliaryConfigBridge:
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_PROVIDER") == "nous"
|
||||
assert os.environ.get("AUXILIARY_WEB_EXTRACT_MODEL") == "gemini-2.5-flash"
|
||||
|
||||
def test_direct_endpoint_bridged(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"vision": {
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"api_key": "local-key",
|
||||
"model": "qwen2.5-vl",
|
||||
}
|
||||
}
|
||||
}
|
||||
_run_auxiliary_bridge(config, monkeypatch)
|
||||
assert os.environ.get("AUXILIARY_VISION_BASE_URL") == "http://localhost:1234/v1"
|
||||
assert os.environ.get("AUXILIARY_VISION_API_KEY") == "local-key"
|
||||
assert os.environ.get("AUXILIARY_VISION_MODEL") == "qwen2.5-vl"
|
||||
|
||||
def test_compression_provider_bridged(self, monkeypatch):
|
||||
config = {
|
||||
"compression": {
|
||||
@@ -200,8 +233,12 @@ class TestGatewayBridgeCodeParity:
|
||||
# Check for key patterns that indicate the bridge is present
|
||||
assert "AUXILIARY_VISION_PROVIDER" in content
|
||||
assert "AUXILIARY_VISION_MODEL" in content
|
||||
assert "AUXILIARY_VISION_BASE_URL" in content
|
||||
assert "AUXILIARY_VISION_API_KEY" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_PROVIDER" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_MODEL" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_BASE_URL" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_API_KEY" in content
|
||||
|
||||
def test_gateway_has_compression_provider(self):
|
||||
"""Gateway must bridge compression.summary_provider."""
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def _make_cli_stub():
|
||||
cli = HermesCLI.__new__(HermesCLI)
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
cli._approval_lock = threading.Lock()
|
||||
cli._invalidate = MagicMock()
|
||||
cli._app = SimpleNamespace(invalidate=MagicMock())
|
||||
return cli
|
||||
|
||||
|
||||
class TestCliApprovalUi:
|
||||
def test_approval_callback_includes_view_for_long_commands(self):
|
||||
cli = _make_cli_stub()
|
||||
command = "sudo dd if=/tmp/githubcli-keyring.gpg of=/usr/share/keyrings/githubcli-archive-keyring.gpg bs=4M status=progress"
|
||||
result = {}
|
||||
|
||||
def _run_callback():
|
||||
result["value"] = cli._approval_callback(command, "disk copy")
|
||||
|
||||
thread = threading.Thread(target=_run_callback, daemon=True)
|
||||
thread.start()
|
||||
|
||||
deadline = time.time() + 2
|
||||
while cli._approval_state is None and time.time() < deadline:
|
||||
time.sleep(0.01)
|
||||
|
||||
assert cli._approval_state is not None
|
||||
assert "view" in cli._approval_state["choices"]
|
||||
|
||||
cli._approval_state["response_queue"].put("deny")
|
||||
thread.join(timeout=2)
|
||||
assert result["value"] == "deny"
|
||||
|
||||
def test_handle_approval_selection_view_expands_in_place(self):
|
||||
cli = _make_cli_stub()
|
||||
cli._approval_state = {
|
||||
"command": "sudo dd if=/tmp/in of=/usr/share/keyrings/githubcli-archive-keyring.gpg bs=4M status=progress",
|
||||
"description": "disk copy",
|
||||
"choices": ["once", "session", "always", "deny", "view"],
|
||||
"selected": 4,
|
||||
"response_queue": queue.Queue(),
|
||||
}
|
||||
|
||||
cli._handle_approval_selection()
|
||||
|
||||
assert cli._approval_state is not None
|
||||
assert cli._approval_state["show_full"] is True
|
||||
assert "view" not in cli._approval_state["choices"]
|
||||
assert cli._approval_state["selected"] == 3
|
||||
assert cli._approval_state["response_queue"].empty()
|
||||
|
||||
def test_approval_display_places_title_inside_box_not_border(self):
|
||||
cli = _make_cli_stub()
|
||||
cli._approval_state = {
|
||||
"command": "sudo dd if=/tmp/in of=/usr/share/keyrings/githubcli-archive-keyring.gpg bs=4M status=progress",
|
||||
"description": "disk copy",
|
||||
"choices": ["once", "session", "always", "deny", "view"],
|
||||
"selected": 0,
|
||||
"response_queue": queue.Queue(),
|
||||
}
|
||||
|
||||
fragments = cli._get_approval_display_fragments()
|
||||
rendered = "".join(text for _style, text in fragments)
|
||||
lines = rendered.splitlines()
|
||||
|
||||
assert lines[0].startswith("╭")
|
||||
assert "Dangerous Command" not in lines[0]
|
||||
assert any("Dangerous Command" in line for line in lines[1:3])
|
||||
assert "Show full command" in rendered
|
||||
assert "githubcli-archive-keyring.gpg" not in rendered
|
||||
|
||||
def test_approval_display_shows_full_command_after_view(self):
|
||||
cli = _make_cli_stub()
|
||||
full_command = "sudo dd if=/tmp/in of=/usr/share/keyrings/githubcli-archive-keyring.gpg bs=4M status=progress"
|
||||
cli._approval_state = {
|
||||
"command": full_command,
|
||||
"description": "disk copy",
|
||||
"choices": ["once", "session", "always", "deny"],
|
||||
"selected": 0,
|
||||
"show_full": True,
|
||||
"response_queue": queue.Queue(),
|
||||
}
|
||||
|
||||
fragments = cli._get_approval_display_fragments()
|
||||
rendered = "".join(text for _style, text in fragments)
|
||||
|
||||
assert "..." not in rendered
|
||||
assert "githubcli-" in rendered
|
||||
assert "archive-" in rendered
|
||||
assert "keyring.gpg" in rendered
|
||||
assert "status=progress" in rendered
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user