Compare commits
1 Commits
feat/head-
...
feat/disco
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8240143b6 |
@@ -120,9 +120,23 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Ignore bot's own messages
|
||||
# Always ignore our own messages
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
|
||||
# Bot message filtering (DISCORD_ALLOW_BOTS):
|
||||
# "none" — ignore all other bots (default)
|
||||
# "mentions" — accept bot messages only when they @mention us
|
||||
# "all" — accept all bot messages
|
||||
if getattr(message.author, "bot", False):
|
||||
allow_bots = os.getenv("DISCORD_ALLOW_BOTS", "none").lower().strip()
|
||||
if allow_bots == "none":
|
||||
return
|
||||
elif allow_bots == "mentions":
|
||||
if not self._client.user or self._client.user not in message.mentions:
|
||||
return
|
||||
# "all" falls through to handle_message
|
||||
|
||||
await self._handle_message(message)
|
||||
|
||||
# Register slash commands
|
||||
|
||||
117
tests/gateway/test_discord_bot_filter.py
Normal file
117
tests/gateway/test_discord_bot_filter.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Tests for Discord bot message filtering (DISCORD_ALLOW_BOTS)."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
def _make_author(*, bot: bool = False, is_self: bool = False):
|
||||
"""Create a mock Discord author."""
|
||||
author = MagicMock()
|
||||
author.bot = bot
|
||||
author.id = 99999 if is_self else 12345
|
||||
author.name = "TestBot" if bot else "TestUser"
|
||||
author.display_name = author.name
|
||||
return author
|
||||
|
||||
|
||||
def _make_message(*, author=None, content="hello", mentions=None, is_dm=False):
|
||||
"""Create a mock Discord message."""
|
||||
msg = MagicMock()
|
||||
msg.author = author or _make_author()
|
||||
msg.content = content
|
||||
msg.attachments = []
|
||||
msg.mentions = mentions or []
|
||||
if is_dm:
|
||||
import discord
|
||||
msg.channel = MagicMock(spec=discord.DMChannel)
|
||||
msg.channel.id = 111
|
||||
else:
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.id = 222
|
||||
msg.channel.name = "test-channel"
|
||||
msg.channel.guild = MagicMock()
|
||||
msg.channel.guild.name = "TestServer"
|
||||
# Make isinstance checks fail for DMChannel and Thread
|
||||
type(msg.channel).__name__ = "TextChannel"
|
||||
return msg
|
||||
|
||||
|
||||
class TestDiscordBotFilter(unittest.TestCase):
|
||||
"""Test the DISCORD_ALLOW_BOTS filtering logic."""
|
||||
|
||||
def _run_filter(self, message, allow_bots="none", client_user=None):
|
||||
"""Simulate the on_message filter logic and return whether message was accepted."""
|
||||
# Replicate the exact filter logic from discord.py on_message
|
||||
if message.author == client_user:
|
||||
return False # own messages always ignored
|
||||
|
||||
if getattr(message.author, "bot", False):
|
||||
allow = allow_bots.lower().strip()
|
||||
if allow == "none":
|
||||
return False
|
||||
elif allow == "mentions":
|
||||
if not client_user or client_user not in message.mentions:
|
||||
return False
|
||||
# "all" falls through
|
||||
|
||||
return True # message accepted
|
||||
|
||||
def test_own_messages_always_ignored(self):
|
||||
"""Bot's own messages are always ignored regardless of allow_bots."""
|
||||
bot_user = _make_author(is_self=True)
|
||||
msg = _make_message(author=bot_user)
|
||||
self.assertFalse(self._run_filter(msg, "all", bot_user))
|
||||
|
||||
def test_human_messages_always_accepted(self):
|
||||
"""Human messages are always accepted regardless of allow_bots."""
|
||||
human = _make_author(bot=False)
|
||||
msg = _make_message(author=human)
|
||||
self.assertTrue(self._run_filter(msg, "none"))
|
||||
self.assertTrue(self._run_filter(msg, "mentions"))
|
||||
self.assertTrue(self._run_filter(msg, "all"))
|
||||
|
||||
def test_allow_bots_none_rejects_bots(self):
|
||||
"""With allow_bots=none, all other bot messages are rejected."""
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot)
|
||||
self.assertFalse(self._run_filter(msg, "none"))
|
||||
|
||||
def test_allow_bots_all_accepts_bots(self):
|
||||
"""With allow_bots=all, all bot messages are accepted."""
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot)
|
||||
self.assertTrue(self._run_filter(msg, "all"))
|
||||
|
||||
def test_allow_bots_mentions_rejects_without_mention(self):
|
||||
"""With allow_bots=mentions, bot messages without @mention are rejected."""
|
||||
our_user = _make_author(is_self=True)
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot, mentions=[])
|
||||
self.assertFalse(self._run_filter(msg, "mentions", our_user))
|
||||
|
||||
def test_allow_bots_mentions_accepts_with_mention(self):
|
||||
"""With allow_bots=mentions, bot messages with @mention are accepted."""
|
||||
our_user = _make_author(is_self=True)
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot, mentions=[our_user])
|
||||
self.assertTrue(self._run_filter(msg, "mentions", our_user))
|
||||
|
||||
def test_default_is_none(self):
|
||||
"""Default behavior (no env var) should be 'none'."""
|
||||
default = os.getenv("DISCORD_ALLOW_BOTS", "none")
|
||||
self.assertEqual(default, "none")
|
||||
|
||||
def test_case_insensitive(self):
|
||||
"""Allow_bots value should be case-insensitive."""
|
||||
bot = _make_author(bot=True)
|
||||
msg = _make_message(author=bot)
|
||||
self.assertTrue(self._run_filter(msg, "ALL"))
|
||||
self.assertTrue(self._run_filter(msg, "All"))
|
||||
self.assertFalse(self._run_filter(msg, "NONE"))
|
||||
self.assertFalse(self._run_filter(msg, "None"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -393,56 +393,5 @@ class TestStubSchemaDrift(unittest.TestCase):
|
||||
self.assertIn("mode", src)
|
||||
|
||||
|
||||
class TestHeadTailTruncation(unittest.TestCase):
|
||||
"""Tests for head+tail truncation of large stdout in execute_code."""
|
||||
|
||||
def _run(self, code):
|
||||
with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
|
||||
result = execute_code(
|
||||
code=code,
|
||||
task_id="test-task",
|
||||
enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
|
||||
)
|
||||
return json.loads(result)
|
||||
|
||||
def test_short_output_not_truncated(self):
|
||||
"""Output under MAX_STDOUT_BYTES should not be truncated."""
|
||||
result = self._run('print("small output")')
|
||||
self.assertEqual(result["status"], "success")
|
||||
self.assertIn("small output", result["output"])
|
||||
self.assertNotIn("TRUNCATED", result["output"])
|
||||
|
||||
def test_large_output_preserves_head_and_tail(self):
|
||||
"""Output exceeding MAX_STDOUT_BYTES keeps both head and tail."""
|
||||
code = '''
|
||||
# Print HEAD marker, then filler, then TAIL marker
|
||||
print("HEAD_MARKER_START")
|
||||
for i in range(15000):
|
||||
print(f"filler_line_{i:06d}_padding_to_fill_buffer")
|
||||
print("TAIL_MARKER_END")
|
||||
'''
|
||||
result = self._run(code)
|
||||
self.assertEqual(result["status"], "success")
|
||||
output = result["output"]
|
||||
# Head should be preserved
|
||||
self.assertIn("HEAD_MARKER_START", output)
|
||||
# Tail should be preserved (this is the key improvement)
|
||||
self.assertIn("TAIL_MARKER_END", output)
|
||||
# Truncation notice should be present
|
||||
self.assertIn("TRUNCATED", output)
|
||||
|
||||
def test_truncation_notice_format(self):
|
||||
"""Truncation notice includes character counts."""
|
||||
code = '''
|
||||
for i in range(15000):
|
||||
print(f"padding_line_{i:06d}_xxxxxxxxxxxxxxxxxxxxxxxxxx")
|
||||
'''
|
||||
result = self._run(code)
|
||||
output = result["output"]
|
||||
if "TRUNCATED" in output:
|
||||
self.assertIn("chars omitted", output)
|
||||
self.assertIn("total", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -457,17 +457,11 @@ def execute_code(
|
||||
|
||||
# --- Poll loop: watch for exit, timeout, and interrupt ---
|
||||
deadline = time.monotonic() + timeout
|
||||
stdout_chunks: list = []
|
||||
stderr_chunks: list = []
|
||||
|
||||
# Background readers to avoid pipe buffer deadlocks.
|
||||
# For stdout we use a head+tail strategy: keep the first HEAD_BYTES
|
||||
# and a rolling window of the last TAIL_BYTES so the final print()
|
||||
# output is never lost. Stderr keeps head-only (errors appear early).
|
||||
_STDOUT_HEAD_BYTES = int(MAX_STDOUT_BYTES * 0.4) # 40% head
|
||||
_STDOUT_TAIL_BYTES = MAX_STDOUT_BYTES - _STDOUT_HEAD_BYTES # 60% tail
|
||||
|
||||
# Background readers to avoid pipe buffer deadlocks
|
||||
def _drain(pipe, chunks, max_bytes):
|
||||
"""Simple head-only drain (used for stderr)."""
|
||||
total = 0
|
||||
try:
|
||||
while True:
|
||||
@@ -481,48 +475,8 @@ def execute_code(
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
|
||||
stdout_total_bytes = [0] # mutable ref for total bytes seen
|
||||
|
||||
def _drain_head_tail(pipe, head_chunks, tail_chunks, head_bytes, tail_bytes, total_ref):
|
||||
"""Drain stdout keeping both head and tail data."""
|
||||
head_collected = 0
|
||||
from collections import deque
|
||||
tail_buf = deque()
|
||||
tail_collected = 0
|
||||
try:
|
||||
while True:
|
||||
data = pipe.read(4096)
|
||||
if not data:
|
||||
break
|
||||
total_ref[0] += len(data)
|
||||
# Fill head buffer first
|
||||
if head_collected < head_bytes:
|
||||
keep = min(len(data), head_bytes - head_collected)
|
||||
head_chunks.append(data[:keep])
|
||||
head_collected += keep
|
||||
data = data[keep:] # remaining goes to tail
|
||||
if not data:
|
||||
continue
|
||||
# Everything past head goes into rolling tail buffer
|
||||
tail_buf.append(data)
|
||||
tail_collected += len(data)
|
||||
# Evict old tail data to stay within tail_bytes budget
|
||||
while tail_collected > tail_bytes and tail_buf:
|
||||
oldest = tail_buf.popleft()
|
||||
tail_collected -= len(oldest)
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
# Transfer final tail to output list
|
||||
tail_chunks.extend(tail_buf)
|
||||
|
||||
stdout_head_chunks: list = []
|
||||
stdout_tail_chunks: list = []
|
||||
|
||||
stdout_reader = threading.Thread(
|
||||
target=_drain_head_tail,
|
||||
args=(proc.stdout, stdout_head_chunks, stdout_tail_chunks,
|
||||
_STDOUT_HEAD_BYTES, _STDOUT_TAIL_BYTES, stdout_total_bytes),
|
||||
daemon=True
|
||||
target=_drain, args=(proc.stdout, stdout_chunks, MAX_STDOUT_BYTES), daemon=True
|
||||
)
|
||||
stderr_reader = threading.Thread(
|
||||
target=_drain, args=(proc.stderr, stderr_chunks, MAX_STDERR_BYTES), daemon=True
|
||||
@@ -546,21 +500,12 @@ def execute_code(
|
||||
stdout_reader.join(timeout=3)
|
||||
stderr_reader.join(timeout=3)
|
||||
|
||||
stdout_head = b"".join(stdout_head_chunks).decode("utf-8", errors="replace")
|
||||
stdout_tail = b"".join(stdout_tail_chunks).decode("utf-8", errors="replace")
|
||||
stdout_text = b"".join(stdout_chunks).decode("utf-8", errors="replace")
|
||||
stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace")
|
||||
|
||||
# Assemble stdout with head+tail truncation
|
||||
total_stdout = stdout_total_bytes[0]
|
||||
if total_stdout > MAX_STDOUT_BYTES and stdout_tail:
|
||||
omitted = total_stdout - len(stdout_head) - len(stdout_tail)
|
||||
truncated_notice = (
|
||||
f"\n\n... [OUTPUT TRUNCATED - {omitted:,} chars omitted "
|
||||
f"out of {total_stdout:,} total] ...\n\n"
|
||||
)
|
||||
stdout_text = stdout_head + truncated_notice + stdout_tail
|
||||
else:
|
||||
stdout_text = stdout_head + stdout_tail
|
||||
# Truncation notice
|
||||
if len(stdout_text) >= MAX_STDOUT_BYTES:
|
||||
stdout_text = stdout_text[:MAX_STDOUT_BYTES] + "\n[output truncated at 50KB]"
|
||||
|
||||
exit_code = proc.returncode if proc.returncode is not None else -1
|
||||
duration = round(time.monotonic() - exec_start, 2)
|
||||
|
||||
Reference in New Issue
Block a user