Compare commits

..

1 Commits

Author SHA1 Message Date
teknium1 72decda522 feat: unified streaming infrastructure (draft — awaiting streaming impl)
Unified streaming architecture combining the best of PRs #774 and #798,
with improvements. This is a draft — awaiting proper streaming token
implementation and testing before merge.

Layer 1 — Core streaming (run_agent.py):
- stream_delta_callback on AIAgent.__init__ (per-instance)
- _interruptible_streaming_api_call() for chat completions with
  SimpleNamespace response reconstruction
- Tool-call suppression (callback only fires for text-only responses)
- on_first_delta callback (stops thinking spinner on first token)
- Provider fallback when streaming unsupported
- reasoning_content accumulation
- Interrupt support (client.close() + rebuild)

Layer 2 — Display (cli.py, gateway/):
- CLI: line-buffered _stream_delta/_flush_stream via _cprint
- Gateway: async stream consumer with dual transport:
  * Draft (Bot API 9.3+ sendMessageDraft) as primary
  * Progressive editMessageText as fallback
  * Auto mode tries draft, falls back seamlessly
- Config-driven: streaming.enabled, edit_interval, buffer_threshold,
  cursor, transport (auto/draft/edit)
- Uses self.config (no duplicate yaml reads)
- already_sent flag prevents duplicate sends in base.py

Telegram-specific (gateway/platforms/telegram.py):
- send_raw / edit_message_raw (plain text, no MarkdownV2)
- send_draft / finalize_draft (Bot API 9.3+)
- delete_message
- All methods pass message_thread_id for forum topic support
  (fix for #774's missing thread_id bug)

Tests: 10 new tests covering accumulator shape, callback order,
tool-call suppression, provider fallback, already_sent contract.

Config example:
  streaming:
    enabled: true
    edit_interval: 1.0
    buffer_threshold: 100
    cursor: ' ▉'
    transport: auto  # auto, draft, or edit

Supersedes: #774 (jobless0x), #798 (OutThisLife), #697 (clicksingh)
2026-03-11 05:59:47 -07:00
8 changed files with 605 additions and 188 deletions
+1 -37
View File
@@ -5,7 +5,6 @@ Uses Gemini Flash (cheap/fast) to summarize middle turns while
protecting head and tail context.
"""
import json
import logging
import os
from typing import Any, Dict, List, Optional
@@ -83,41 +82,6 @@ class ContextCompressor:
"compression_count": self.compression_count,
}
@staticmethod
def _content_to_text(content: Any) -> str:
"""Convert message content to plain text for summarization.
Handles:
- str → returned as-is
- None → empty string
- list (multimodal) → text parts joined, images replaced with [image]
- other → JSON serialization or str() fallback
"""
if isinstance(content, str):
return content
if content is None:
return ""
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
item_type = item.get("type")
if item_type == "text":
parts.append(item.get("text", ""))
elif item_type == "image_url":
parts.append("[image]")
elif item_type:
parts.append(f"[{item_type}]")
else:
parts.append(str(item))
else:
parts.append(str(item))
return "\n".join(part for part in parts if part)
try:
return json.dumps(content, ensure_ascii=False, sort_keys=True)
except TypeError:
return str(content)
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
"""Generate a concise summary of conversation turns.
@@ -129,7 +93,7 @@ class ContextCompressor:
parts = []
for msg in turns_to_summarize:
role = msg.get("role", "unknown")
content = self._content_to_text(msg.get("content"))
content = msg.get("content") or ""
if len(content) > 2000:
content = content[:1000] + "\n...[truncated]...\n" + content[-500:]
tool_calls = msg.get("tool_calls", [])
+28 -1
View File
@@ -1253,6 +1253,7 @@ class HermesCLI:
# Background task tracking: {task_id: threading.Thread}
self._background_tasks: Dict[str, threading.Thread] = {}
self._background_task_counter = 0
self._stream_buf = ""
def _invalidate(self, min_interval: float = 0.25) -> None:
"""Throttled UI repaint — prevents terminal blinking on slow/SSH connections."""
@@ -1495,6 +1496,7 @@ class HermesCLI:
platform="cli",
session_db=self._session_db,
clarify_callback=self._clarify_callback,
stream_delta_callback=self._stream_delta,
honcho_session_key=self.session_id,
fallback_model=self._fallback_model,
thinking_callback=self._on_thinking,
@@ -3339,6 +3341,28 @@ class HermesCLI:
"Use your best judgement to make the choice and proceed."
)
_stream_started = False
def _stream_delta(self, text: str):
"""Buffer streaming tokens; emit complete lines via _cprint."""
if not text:
return
if not self._stream_started:
text = text.lstrip("\n")
if not text:
return
self._stream_started = True
self._stream_buf += text
while "\n" in self._stream_buf:
line, self._stream_buf = self._stream_buf.split("\n", 1)
_cprint(line)
def _flush_stream(self):
"""Emit any remaining partial line from the stream buffer."""
if self._stream_buf:
_cprint(self._stream_buf)
self._stream_buf = ""
def _sudo_password_callback(self) -> str:
"""
Prompt for sudo password through the prompt_toolkit UI.
@@ -3467,6 +3491,8 @@ class HermesCLI:
# Add user message to history
self.conversation_history.append({"role": "user", "content": message})
self._stream_buf = ""
self._stream_started = False
_cprint(f"{_GOLD}{'' * 40}{_RST}")
print(flush=True)
@@ -3514,6 +3540,7 @@ class HermesCLI:
agent_thread.join(0.1)
agent_thread.join() # Ensure agent thread completes
self._flush_stream()
# Drain any remaining agent output still in the StdoutProxy
# buffer so tool/status lines render ABOVE our response box.
@@ -3542,7 +3569,7 @@ class HermesCLI:
if response and pending_message:
response = response + "\n\n---\n_[Interrupted - processing new message]_"
if response:
if response and not (self.agent and self.agent.stream_delta_callback):
# Use a Rich Panel for the response box — adapts to terminal
# width at render time instead of hard-coding border length.
try:
+42 -3
View File
@@ -413,6 +413,36 @@ class BasePlatformAdapter(ABC):
"""
return SendResult(success=False, error="Not supported")
@property
def supports_streaming(self) -> bool:
"""Whether this platform supports response streaming via message edits."""
return False
@property
def supports_draft_streaming(self) -> bool:
"""Whether this platform supports native draft streaming (Bot API 9.3+)."""
return False
async def send_draft(self, chat_id: str, draft_id: int, text: str, metadata: dict = None) -> bool:
"""Push a draft text update. Override in subclasses."""
return False
async def finalize_draft(self, chat_id: str, content: str, metadata: dict = None) -> "SendResult":
"""Finalize a draft stream with the completed message."""
return SendResult(success=False, error="Not supported")
async def delete_message(self, chat_id: str, message_id: str) -> SendResult:
"""Delete a previously sent message."""
return SendResult(success=False, error="Not supported")
async def send_raw(self, chat_id: str, content: str, metadata: dict = None) -> "SendResult":
"""Send without formatting (default: delegates to send)."""
return await self.send(chat_id=chat_id, content=content, metadata=metadata)
async def edit_message_raw(self, chat_id: str, message_id: str, content: str) -> "SendResult":
"""Edit without formatting (default: delegates to edit_message)."""
return await self.edit_message(chat_id=chat_id, message_id=message_id, content=content)
async def send_typing(self, chat_id: str, metadata=None) -> None:
"""
Send a typing indicator.
@@ -697,11 +727,20 @@ class BasePlatformAdapter(ABC):
try:
# Call the handler (this can take a while with tool calls)
response = await self._message_handler(event)
handler_result = await self._message_handler(event)
# Normalise: handler may return str or dict(content, already_sent)
already_sent = False
if isinstance(handler_result, dict):
response = handler_result.get("content") or ""
already_sent = handler_result.get("already_sent", False)
else:
response = handler_result
# Send response if any
if not response:
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
if not already_sent:
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
if response:
# Extract MEDIA:<path> tags (from TTS tool) before other processing
media_files, response = self.extract_media(response)
@@ -712,7 +751,7 @@ class BasePlatformAdapter(ABC):
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
# Send the text portion first (if any remains after extractions)
if text_content:
if text_content and not already_sent:
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
result = await self.send(
chat_id=event.source.chat_id,
+28 -80
View File
@@ -9,7 +9,6 @@ Uses slack-bolt (Python) with Socket Mode for:
"""
import asyncio
import logging
import os
import re
from typing import Dict, List, Optional, Any
@@ -42,9 +41,6 @@ from gateway.platforms.base import (
)
logger = logging.getLogger(__name__)
def check_slack_requirements() -> bool:
"""Check if Slack dependencies are available."""
return SLACK_AVAILABLE
@@ -77,19 +73,17 @@ class SlackAdapter(BasePlatformAdapter):
async def connect(self) -> bool:
"""Connect to Slack via Socket Mode."""
if not SLACK_AVAILABLE:
logger.error(
"[Slack] slack-bolt not installed. Run: pip install slack-bolt",
)
print("[Slack] slack-bolt not installed. Run: pip install slack-bolt")
return False
bot_token = self.config.token
app_token = os.getenv("SLACK_APP_TOKEN")
if not bot_token:
logger.error("[Slack] SLACK_BOT_TOKEN not set")
print("[Slack] SLACK_BOT_TOKEN not set")
return False
if not app_token:
logger.error("[Slack] SLACK_APP_TOKEN not set")
print("[Slack] SLACK_APP_TOKEN not set")
return False
try:
@@ -123,22 +117,19 @@ class SlackAdapter(BasePlatformAdapter):
asyncio.create_task(self._handler.start_async())
self._running = True
logger.info("[Slack] Connected as @%s (Socket Mode)", bot_name)
print(f"[Slack] Connected as @{bot_name} (Socket Mode)")
return True
except Exception as e: # pragma: no cover - defensive logging
logger.error("[Slack] Connection failed: %s", e, exc_info=True)
except Exception as e:
print(f"[Slack] Connection failed: {e}")
return False
async def disconnect(self) -> None:
"""Disconnect from Slack."""
if self._handler:
try:
await self._handler.close_async()
except Exception as e: # pragma: no cover - defensive logging
logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True)
await self._handler.close_async()
self._running = False
logger.info("[Slack] Disconnected")
print("[Slack] Disconnected")
async def send(
self,
@@ -171,8 +162,8 @@ class SlackAdapter(BasePlatformAdapter):
raw_response=result,
)
except Exception as e: # pragma: no cover - defensive logging
logger.error("[Slack] Send error: %s", e, exc_info=True)
except Exception as e:
print(f"[Slack] Send error: {e}")
return SendResult(success=False, error=str(e))
async def edit_message(
@@ -191,14 +182,7 @@ class SlackAdapter(BasePlatformAdapter):
text=content,
)
return SendResult(success=True, message_id=message_id)
except Exception as e: # pragma: no cover - defensive logging
logger.error(
"[Slack] Failed to edit message %s in channel %s: %s",
message_id,
chat_id,
e,
exc_info=True,
)
except Exception as e:
return SendResult(success=False, error=str(e))
async def send_typing(self, chat_id: str, metadata=None) -> None:
@@ -230,14 +214,8 @@ class SlackAdapter(BasePlatformAdapter):
)
return SendResult(success=True, raw_response=result)
except Exception as e: # pragma: no cover - defensive logging
logger.error(
"[%s] Failed to send local Slack image %s: %s",
self.name,
image_path,
e,
exc_info=True,
)
except Exception as e:
print(f"[{self.name}] Failed to send local image: {e}")
return await super().send_image_file(chat_id, image_path, caption, reply_to)
async def send_image(
@@ -269,13 +247,7 @@ class SlackAdapter(BasePlatformAdapter):
return SendResult(success=True, raw_response=result)
except Exception as e: # pragma: no cover - defensive logging
logger.warning(
"[Slack] Failed to upload image from URL %s, falling back to text: %s",
image_url,
e,
exc_info=True,
)
except Exception as e:
# Fall back to sending the URL as text
text = f"{caption}\n{image_url}" if caption else image_url
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
@@ -301,13 +273,7 @@ class SlackAdapter(BasePlatformAdapter):
)
return SendResult(success=True, raw_response=result)
except Exception as e: # pragma: no cover - defensive logging
logger.error(
"[Slack] Failed to send audio file %s: %s",
audio_path,
e,
exc_info=True,
)
except Exception as e:
return SendResult(success=False, error=str(e))
async def send_video(
@@ -334,14 +300,8 @@ class SlackAdapter(BasePlatformAdapter):
)
return SendResult(success=True, raw_response=result)
except Exception as e: # pragma: no cover - defensive logging
logger.error(
"[%s] Failed to send video %s: %s",
self.name,
video_path,
e,
exc_info=True,
)
except Exception as e:
print(f"[{self.name}] Failed to send video: {e}")
return await super().send_video(chat_id, video_path, caption, reply_to)
async def send_document(
@@ -371,14 +331,8 @@ class SlackAdapter(BasePlatformAdapter):
)
return SendResult(success=True, raw_response=result)
except Exception as e: # pragma: no cover - defensive logging
logger.error(
"[%s] Failed to send document %s: %s",
self.name,
file_path,
e,
exc_info=True,
)
except Exception as e:
print(f"[{self.name}] Failed to send document: {e}")
return await super().send_document(chat_id, file_path, caption, file_name, reply_to)
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
@@ -394,13 +348,7 @@ class SlackAdapter(BasePlatformAdapter):
"name": channel.get("name", chat_id),
"type": "dm" if is_dm else "group",
}
except Exception as e: # pragma: no cover - defensive logging
logger.error(
"[Slack] Failed to fetch chat info for %s: %s",
chat_id,
e,
exc_info=True,
)
except Exception:
return {"name": chat_id, "type": "unknown"}
# ----- Internal handlers -----
@@ -455,8 +403,8 @@ class SlackAdapter(BasePlatformAdapter):
media_urls.append(cached)
media_types.append(mimetype)
msg_type = MessageType.PHOTO
except Exception as e: # pragma: no cover - defensive logging
logger.warning("[Slack] Failed to cache image from %s: %s", url, e, exc_info=True)
except Exception as e:
print(f"[Slack] Failed to cache image: {e}", flush=True)
elif mimetype.startswith("audio/") and url:
try:
ext = "." + mimetype.split("/")[-1].split(";")[0]
@@ -466,8 +414,8 @@ class SlackAdapter(BasePlatformAdapter):
media_urls.append(cached)
media_types.append(mimetype)
msg_type = MessageType.VOICE
except Exception as e: # pragma: no cover - defensive logging
logger.warning("[Slack] Failed to cache audio from %s: %s", url, e, exc_info=True)
except Exception as e:
print(f"[Slack] Failed to cache audio: {e}", flush=True)
elif url:
# Try to handle as a document attachment
try:
@@ -489,7 +437,7 @@ class SlackAdapter(BasePlatformAdapter):
file_size = f.get("size", 0)
MAX_DOC_BYTES = 20 * 1024 * 1024
if not file_size or file_size > MAX_DOC_BYTES:
logger.warning("[Slack] Document too large or unknown size: %s", file_size)
print(f"[Slack] Document too large or unknown size: {file_size}", flush=True)
continue
# Download and cache
@@ -501,7 +449,7 @@ class SlackAdapter(BasePlatformAdapter):
media_urls.append(cached_path)
media_types.append(doc_mime)
msg_type = MessageType.DOCUMENT
logger.debug("[Slack] Cached user document: %s", cached_path)
print(f"[Slack] Cached user document: {cached_path}", flush=True)
# Inject text content for .txt/.md files (capped at 100 KB)
MAX_TEXT_INJECT_BYTES = 100 * 1024
@@ -518,8 +466,8 @@ class SlackAdapter(BasePlatformAdapter):
except UnicodeDecodeError:
pass # Binary content, skip injection
except Exception as e: # pragma: no cover - defensive logging
logger.warning("[Slack] Failed to cache document from %s: %s", url, e, exc_info=True)
except Exception as e:
print(f"[Slack] Failed to cache document: {e}", flush=True)
# Build source
source = self.build_source(
+93
View File
@@ -299,6 +299,99 @@ class TelegramAdapter(BasePlatformAdapter):
)
return SendResult(success=False, error=str(e))
async def send_raw(
self, chat_id: str, content: str, metadata: dict = None,
) -> SendResult:
"""Send a plain-text message without MarkdownV2 formatting."""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
thread_id = metadata.get("thread_id") if metadata else None
msg = await self._bot.send_message(
chat_id=int(chat_id), text=content, parse_mode=None,
message_thread_id=int(thread_id) if thread_id else None,
)
return SendResult(success=True, message_id=str(msg.message_id))
except Exception as e:
return SendResult(success=False, error=str(e))
async def edit_message_raw(
self, chat_id: str, message_id: str, content: str,
) -> SendResult:
"""Edit a message with plain text (no MarkdownV2 formatting)."""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
await self._bot.edit_message_text(
chat_id=int(chat_id), message_id=int(message_id),
text=content, parse_mode=None,
)
return SendResult(success=True, message_id=message_id)
except Exception as e:
return SendResult(success=False, error=str(e))
@property
def supports_streaming(self) -> bool:
return True
@property
def supports_draft_streaming(self) -> bool:
"""Whether this adapter supports Telegram Bot API sendMessageDraft (9.3+)."""
return True
async def send_draft(
self, chat_id: str, draft_id: int, text: str, metadata: dict = None,
) -> bool:
"""Push a draft update via sendMessageDraft (Bot API 9.3+)."""
if not self._bot:
return False
try:
thread_id = metadata.get("thread_id") if metadata else None
return await self._bot.send_message_draft(
chat_id=int(chat_id), draft_id=draft_id, text=text,
parse_mode=None,
message_thread_id=int(thread_id) if thread_id else None,
)
except Exception as e:
logger.warning("[%s] send_message_draft failed: %s", self.name, e)
return False
async def finalize_draft(
self, chat_id: str, content: str, metadata: dict = None,
) -> SendResult:
"""Finalize a draft stream by sending the completed message with formatting."""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
thread_id = metadata.get("thread_id") if metadata else None
formatted = self.format_message(content)
try:
msg = await self._bot.send_message(
chat_id=int(chat_id), text=formatted,
parse_mode=ParseMode.MARKDOWN_V2,
message_thread_id=int(thread_id) if thread_id else None,
)
except Exception:
msg = await self._bot.send_message(
chat_id=int(chat_id), text=content, parse_mode=None,
message_thread_id=int(thread_id) if thread_id else None,
)
return SendResult(success=True, message_id=str(msg.message_id))
except Exception as e:
return SendResult(success=False, error=str(e))
async def delete_message(self, chat_id: str, message_id: str) -> SendResult:
"""Delete a Telegram message."""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
await self._bot.delete_message(
chat_id=int(chat_id), message_id=int(message_id),
)
return SendResult(success=True, message_id=message_id)
except Exception as e:
return SendResult(success=False, error=str(e))
async def send_voice(
self,
chat_id: str,
+156 -3
View File
@@ -175,6 +175,7 @@ class AIAgent:
thinking_callback: callable = None,
clarify_callback: callable = None,
step_callback: callable = None,
stream_delta_callback: callable = None,
max_tokens: int = None,
reasoning_config: Dict[str, Any] = None,
prefill_messages: List[Dict[str, Any]] = None,
@@ -262,6 +263,7 @@ class AIAgent:
self.thinking_callback = thinking_callback
self.clarify_callback = clarify_callback
self.step_callback = step_callback
self.stream_delta_callback = stream_delta_callback
self._last_reported_tool = None # Track for "new tool" mode
# Interrupt mechanism for breaking out of tool loops
@@ -2060,6 +2062,147 @@ class AIAgent:
return terminal_response
raise RuntimeError("Responses create(stream=True) fallback did not emit a terminal response.")
def _interruptible_streaming_api_call(self, api_kwargs: dict, on_first_delta=None):
"""Streaming variant of _interruptible_api_call for chat_completions.
Fires self.stream_delta_callback(text) as content tokens arrive and
accumulates the full response into a SimpleNamespace matching the shape
downstream code expects. Falls back to the non-streaming path when the
provider rejects the stream request.
"""
from types import SimpleNamespace
result = {"response": None, "error": None}
first_delta_fired = [False]
def _stream():
try:
stream_kwargs = {**api_kwargs, "stream": True,
"stream_options": {"include_usage": True}}
stream_resp = self.client.chat.completions.create(**stream_kwargs)
content_parts = []
tool_calls_acc = {}
finish_reason = "stop"
usage = None
reasoning_content = None
model = None
has_tool_calls = False
try:
for chunk in stream_resp:
if not chunk.choices:
if hasattr(chunk, "usage") and chunk.usage:
usage = chunk.usage
continue
choice = chunk.choices[0]
if choice.finish_reason:
finish_reason = choice.finish_reason
if model is None and hasattr(chunk, "model"):
model = chunk.model
delta = choice.delta
if delta is None:
continue
if delta.content:
content_parts.append(delta.content)
if not first_delta_fired[0]:
first_delta_fired[0] = True
if on_first_delta:
on_first_delta()
if self.stream_delta_callback and not has_tool_calls:
try:
self.stream_delta_callback(delta.content)
except Exception:
pass
if delta.tool_calls:
has_tool_calls = True
for tc_delta in delta.tool_calls:
idx = tc_delta.index
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {
"id": tc_delta.id or "",
"type": tc_delta.type or "function",
"function": {
"name": getattr(tc_delta.function, "name", None) or "",
"arguments": getattr(tc_delta.function, "arguments", None) or "",
},
}
else:
entry = tool_calls_acc[idx]
if tc_delta.id:
entry["id"] = tc_delta.id
fn = tc_delta.function
if fn:
if fn.name:
entry["function"]["name"] = fn.name
if fn.arguments:
entry["function"]["arguments"] += fn.arguments
rc = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
if rc:
reasoning_content = (reasoning_content or "") + rc
finally:
close_fn = getattr(stream_resp, "close", None)
if callable(close_fn):
try:
close_fn()
except Exception:
pass
tool_calls_list = None
if tool_calls_acc:
tool_calls_list = [
SimpleNamespace(
id=tc["id"], call_id=tc["id"], type=tc["type"],
function=SimpleNamespace(name=tc["function"]["name"],
arguments=tc["function"]["arguments"]),
)
for idx, tc in sorted(tool_calls_acc.items())
]
message = SimpleNamespace(
content="".join(content_parts) or None,
tool_calls=tool_calls_list,
reasoning=reasoning_content,
reasoning_content=reasoning_content,
reasoning_details=None,
)
result["response"] = SimpleNamespace(
choices=[SimpleNamespace(message=message, finish_reason=finish_reason)],
usage=usage,
model=model,
)
except Exception as e:
result["error"] = e
t = threading.Thread(target=_stream, daemon=True)
t.start()
while t.is_alive():
t.join(timeout=0.3)
if self._interrupt_requested:
try:
self.client.close()
except Exception:
pass
try:
self.client = OpenAI(**self._client_kwargs)
except Exception:
pass
raise InterruptedError("Agent interrupted during streaming API call")
if result["error"] is not None:
err = result["error"]
err_str = str(err).lower()
if any(kw in err_str for kw in ("stream", "not support", "unsupported")):
logger.debug("Streaming failed (%s), falling back to non-streaming.", err)
return self._interruptible_api_call(api_kwargs)
raise err
return result["response"]
def _try_refresh_codex_client_credentials(self, *, force: bool = True) -> bool:
if self.api_mode != "codex_responses" or self.provider != "openai-codex":
return False
@@ -3474,7 +3617,17 @@ class AIAgent:
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
self._dump_api_request_debug(api_kwargs, reason="preflight")
response = self._interruptible_api_call(api_kwargs)
if self.stream_delta_callback and self.api_mode != "codex_responses":
def _stop_spinner():
nonlocal thinking_spinner
if thinking_spinner:
thinking_spinner.stop("")
thinking_spinner = None
response = self._interruptible_streaming_api_call(
api_kwargs, on_first_delta=_stop_spinner)
else:
response = self._interruptible_api_call(api_kwargs)
api_duration = time.time() - api_start_time
@@ -4230,8 +4383,8 @@ class AIAgent:
turn_content = assistant_message.content or ""
if turn_content and self._has_content_after_think_block(turn_content):
self._last_content_with_tools = turn_content
# Show intermediate commentary so the user can follow along
if self.quiet_mode:
# Show intermediate commentary — skip when streaming (already in buffer)
if self.quiet_mode and not self.stream_delta_callback:
clean = self._strip_think_blocks(turn_content).strip()
if clean:
print(f" ┊ 💬 {clean}")
-64
View File
@@ -115,70 +115,6 @@ class TestCompress:
assert result[-2]["content"] == msgs[-2]["content"]
class TestContentToText:
"""Test _content_to_text handles all content types without crashing."""
def test_string_passthrough(self, compressor):
assert compressor._content_to_text("hello") == "hello"
def test_none_returns_empty(self, compressor):
assert compressor._content_to_text(None) == ""
def test_multimodal_text_parts(self, compressor):
content = [
{"type": "text", "text": "describe this image"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}},
]
result = compressor._content_to_text(content)
assert "describe this image" in result
assert "[image]" in result
def test_multimodal_mixed_types(self, compressor):
content = [
{"type": "text", "text": "first part"},
{"type": "audio", "audio": {"data": "..."}},
{"type": "text", "text": "second part"},
]
result = compressor._content_to_text(content)
assert "first part" in result
assert "[audio]" in result
assert "second part" in result
def test_dict_content_json_serialized(self, compressor):
content = {"key": "value"}
result = compressor._content_to_text(content)
assert "key" in result
assert "value" in result
def test_multimodal_in_generate_summary(self):
"""Multimodal user messages should not crash _generate_summary."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: image was discussed"
mock_client.chat.completions.create.return_value = mock_response
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
c = ContextCompressor(model="test", quiet_mode=True)
messages = [
{"role": "user", "content": [
{"type": "text", "text": "What is in this image?"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}},
]},
{"role": "assistant", "content": "I see a cat."},
{"role": "user", "content": "thanks"},
]
summary = c._generate_summary(messages)
assert isinstance(summary, str)
# The prompt sent to the model should contain the text, not raw list
prompt = mock_client.chat.completions.create.call_args.kwargs["messages"][0]["content"]
assert "What is in this image?" in prompt
assert "[image]" in prompt
class TestGenerateSummaryNoneContent:
"""Regression: content=None (from tool-call-only assistant messages) must not crash."""
+257
View File
@@ -0,0 +1,257 @@
"""Tests for streaming token output — accumulator shape, callback order, fallback."""
import queue
import threading
from types import SimpleNamespace
from unittest.mock import MagicMock, patch, call
import pytest
from run_agent import AIAgent
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_tool_defs(*names):
return [
{"type": "function", "function": {"name": n, "description": f"{n}", "parameters": {"type": "object", "properties": {}}}}
for n in names
]
@pytest.fixture()
def agent():
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
cb = MagicMock()
a = AIAgent(
api_key="test-key-1234567890",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
stream_delta_callback=cb,
)
a.client = MagicMock()
a._stream_cb = cb
return a
# ---------------------------------------------------------------------------
# Helpers — fake streaming chunks
# ---------------------------------------------------------------------------
def _chunk(content=None, tool_call_delta=None, finish_reason=None, usage=None, model=None):
delta = SimpleNamespace(content=content, tool_calls=tool_call_delta,
reasoning_content=None, reasoning=None)
choice = SimpleNamespace(delta=delta, finish_reason=finish_reason)
c = SimpleNamespace(choices=[choice])
if usage is not None:
c.usage = SimpleNamespace(**usage)
if model:
c.model = model
return c
def _usage_chunk(**kw):
c = SimpleNamespace(choices=[], usage=SimpleNamespace(**kw))
return c
def _tc_delta(index, id=None, name=None, arguments=None, type=None):
fn = SimpleNamespace(name=name, arguments=arguments)
return SimpleNamespace(index=index, id=id, type=type, function=fn)
# ---------------------------------------------------------------------------
# Tests: accumulator shape
# ---------------------------------------------------------------------------
class TestStreamingAccumulator:
def test_text_only_response(self, agent):
"""Streaming text-only response produces correct synthetic shape."""
chunks = [
_chunk(content="Hello", model="test/m"),
_chunk(content=" world"),
_chunk(finish_reason="stop"),
_usage_chunk(prompt_tokens=10, completion_tokens=5, total_tokens=15),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._interruptible_streaming_api_call({"model": "test"})
assert resp.choices[0].message.content == "Hello world"
assert resp.choices[0].message.tool_calls is None
assert resp.choices[0].finish_reason == "stop"
assert resp.usage.prompt_tokens == 10
assert resp.model == "test/m"
def test_tool_call_response(self, agent):
"""Streaming tool-call response accumulates function name + arguments."""
chunks = [
_chunk(tool_call_delta=[_tc_delta(0, id="call_1", name="web_search", arguments='{"q', type="function")]),
_chunk(tool_call_delta=[_tc_delta(0, arguments='uery": "hi"}')]),
_chunk(finish_reason="tool_calls"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._interruptible_streaming_api_call({"model": "test"})
tc = resp.choices[0].message.tool_calls
assert tc is not None
assert len(tc) == 1
assert tc[0].id == "call_1"
assert tc[0].function.name == "web_search"
assert tc[0].function.arguments == '{"query": "hi"}'
assert resp.choices[0].finish_reason == "tool_calls"
def test_mixed_content_and_tool_calls(self, agent):
"""Content + tool calls in same stream are both accumulated."""
chunks = [
_chunk(content="Let me check."),
_chunk(tool_call_delta=[_tc_delta(0, id="c1", name="web_search", arguments="{}", type="function")]),
_chunk(finish_reason="tool_calls"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._interruptible_streaming_api_call({"model": "test"})
assert resp.choices[0].message.content == "Let me check."
assert len(resp.choices[0].message.tool_calls) == 1
class TestStreamingCallbacks:
def test_deltas_fire_in_order(self, agent):
"""stream_delta_callback receives content deltas in order."""
received = []
agent.stream_delta_callback = lambda t: received.append(t)
chunks = [_chunk(content="a"), _chunk(content="b"), _chunk(content="c"), _chunk(finish_reason="stop")]
agent.client.chat.completions.create.return_value = iter(chunks)
agent._interruptible_streaming_api_call({"model": "test"})
assert received == ["a", "b", "c"]
def test_on_first_delta_fires_once(self, agent):
first = MagicMock()
chunks = [_chunk(content="x"), _chunk(content="y"), _chunk(finish_reason="stop")]
agent.client.chat.completions.create.return_value = iter(chunks)
agent._interruptible_streaming_api_call({"model": "test"}, on_first_delta=first)
first.assert_called_once()
def test_tool_only_does_not_fire_callback(self, agent):
"""Tool-call-only stream does not invoke stream_delta_callback."""
received = []
agent.stream_delta_callback = lambda t: received.append(t)
chunks = [
_chunk(tool_call_delta=[_tc_delta(0, id="c1", name="t", arguments="{}", type="function")]),
_chunk(finish_reason="tool_calls"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
agent._interruptible_streaming_api_call({"model": "test"})
assert received == []
class TestStreamingFallback:
def test_stream_error_falls_back(self, agent):
"""When streaming fails with 'not support', falls back to non-streaming."""
agent.client.chat.completions.create.side_effect = [
Exception("streaming not supported by this provider"),
SimpleNamespace(
choices=[SimpleNamespace(
message=SimpleNamespace(content="ok", tool_calls=None, reasoning=None, reasoning_content=None, reasoning_details=None),
finish_reason="stop",
)],
usage=None,
model="test/m",
),
]
resp = agent._interruptible_streaming_api_call({"model": "test"})
assert resp.choices[0].message.content == "ok"
assert agent.client.chat.completions.create.call_count == 2
def test_non_stream_error_raises(self, agent):
"""Non-stream-related errors propagate normally."""
agent.client.chat.completions.create.side_effect = ValueError("bad request")
with pytest.raises(ValueError, match="bad request"):
agent._interruptible_streaming_api_call({"model": "test"})
# ---------------------------------------------------------------------------
# Tests: base.py already_sent contract
# ---------------------------------------------------------------------------
class TestAlreadySentContract:
def _make_adapter(self, send_side_effect=None):
from gateway.platforms.base import BasePlatformAdapter, SendResult
from gateway.config import Platform, PlatformConfig
class FakeAdapter(BasePlatformAdapter):
async def connect(self): return True
async def disconnect(self): pass
async def get_chat_info(self, chat_id): return {"name": "test"}
async def send(self, chat_id, content, reply_to=None, metadata=None):
if send_side_effect is not None:
send_side_effect(content)
return SendResult(success=True, message_id="1")
cfg = PlatformConfig(enabled=True)
adapter = FakeAdapter(cfg, Platform.TELEGRAM)
adapter._running = True
return adapter
@pytest.mark.asyncio
async def test_already_sent_skips_send(self):
"""Handler returning already_sent=True prevents base from calling send()."""
from gateway.platforms.base import MessageEvent
from gateway.config import Platform
from gateway.session import SessionSource
sent = []
adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c))
async def handler(event):
return {"content": "hello", "already_sent": True}
adapter.set_message_handler(handler)
event = MessageEvent(
text="hi",
source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
)
await adapter._process_message_background(event, "s1")
assert sent == [], "send() should not be called when already_sent=True"
@pytest.mark.asyncio
async def test_string_response_sends_normally(self):
"""Handler returning a plain string triggers send() as before."""
from gateway.platforms.base import MessageEvent
from gateway.config import Platform
from gateway.session import SessionSource
sent = []
adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c))
async def handler(event):
return "hello"
adapter.set_message_handler(handler)
event = MessageEvent(
text="hi",
source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
)
await adapter._process_message_background(event, "s1")
assert "hello" in sent