Compare commits

..

1 Commits

Author SHA1 Message Date
teknium1
f8240143b6 feat(discord): add DISCORD_ALLOW_BOTS config for bot message filtering (inspired by openclaw)
Add configurable bot message filtering via DISCORD_ALLOW_BOTS env var:

- 'none' (default): Ignore all other bot messages — matches previous
  behavior where only our own bot was filtered, but now ALL bots are
  filtered by default for cleaner channels
- 'mentions': Accept bot messages only when they @mention our bot —
  useful for bot-to-bot workflows triggered by mentions
- 'all': Accept all bot messages — for setups where bots need to
  interact freely

Previously, we only ignored our own bot's messages, allowing all other
bots through. This could cause noisy loops in channels with multiple bots.

8 new tests covering all filter modes and edge cases.

Inspired by openclaw v2026.3.7 Discord allowBots: 'mentions' config.
2026-03-09 02:20:57 -07:00
4 changed files with 139 additions and 114 deletions

View File

@@ -120,9 +120,23 @@ class DiscordAdapter(BasePlatformAdapter):
@self._client.event
async def on_message(message: DiscordMessage):
# Ignore bot's own messages
# Always ignore our own messages
if message.author == self._client.user:
return
# Bot message filtering (DISCORD_ALLOW_BOTS):
# "none" — ignore all other bots (default)
# "mentions" — accept bot messages only when they @mention us
# "all" — accept all bot messages
if getattr(message.author, "bot", False):
allow_bots = os.getenv("DISCORD_ALLOW_BOTS", "none").lower().strip()
if allow_bots == "none":
return
elif allow_bots == "mentions":
if not self._client.user or self._client.user not in message.mentions:
return
# "all" falls through to handle_message
await self._handle_message(message)
# Register slash commands

View File

@@ -0,0 +1,117 @@
"""Tests for Discord bot message filtering (DISCORD_ALLOW_BOTS)."""
import asyncio
import os
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
def _make_author(*, bot: bool = False, is_self: bool = False):
"""Create a mock Discord author."""
author = MagicMock()
author.bot = bot
author.id = 99999 if is_self else 12345
author.name = "TestBot" if bot else "TestUser"
author.display_name = author.name
return author
def _make_message(*, author=None, content="hello", mentions=None, is_dm=False):
"""Create a mock Discord message."""
msg = MagicMock()
msg.author = author or _make_author()
msg.content = content
msg.attachments = []
msg.mentions = mentions or []
if is_dm:
import discord
msg.channel = MagicMock(spec=discord.DMChannel)
msg.channel.id = 111
else:
msg.channel = MagicMock()
msg.channel.id = 222
msg.channel.name = "test-channel"
msg.channel.guild = MagicMock()
msg.channel.guild.name = "TestServer"
# Make isinstance checks fail for DMChannel and Thread
type(msg.channel).__name__ = "TextChannel"
return msg
class TestDiscordBotFilter(unittest.TestCase):
"""Test the DISCORD_ALLOW_BOTS filtering logic."""
def _run_filter(self, message, allow_bots="none", client_user=None):
"""Simulate the on_message filter logic and return whether message was accepted."""
# Replicate the exact filter logic from discord.py on_message
if message.author == client_user:
return False # own messages always ignored
if getattr(message.author, "bot", False):
allow = allow_bots.lower().strip()
if allow == "none":
return False
elif allow == "mentions":
if not client_user or client_user not in message.mentions:
return False
# "all" falls through
return True # message accepted
def test_own_messages_always_ignored(self):
"""Bot's own messages are always ignored regardless of allow_bots."""
bot_user = _make_author(is_self=True)
msg = _make_message(author=bot_user)
self.assertFalse(self._run_filter(msg, "all", bot_user))
def test_human_messages_always_accepted(self):
"""Human messages are always accepted regardless of allow_bots."""
human = _make_author(bot=False)
msg = _make_message(author=human)
self.assertTrue(self._run_filter(msg, "none"))
self.assertTrue(self._run_filter(msg, "mentions"))
self.assertTrue(self._run_filter(msg, "all"))
def test_allow_bots_none_rejects_bots(self):
"""With allow_bots=none, all other bot messages are rejected."""
bot = _make_author(bot=True)
msg = _make_message(author=bot)
self.assertFalse(self._run_filter(msg, "none"))
def test_allow_bots_all_accepts_bots(self):
"""With allow_bots=all, all bot messages are accepted."""
bot = _make_author(bot=True)
msg = _make_message(author=bot)
self.assertTrue(self._run_filter(msg, "all"))
def test_allow_bots_mentions_rejects_without_mention(self):
"""With allow_bots=mentions, bot messages without @mention are rejected."""
our_user = _make_author(is_self=True)
bot = _make_author(bot=True)
msg = _make_message(author=bot, mentions=[])
self.assertFalse(self._run_filter(msg, "mentions", our_user))
def test_allow_bots_mentions_accepts_with_mention(self):
"""With allow_bots=mentions, bot messages with @mention are accepted."""
our_user = _make_author(is_self=True)
bot = _make_author(bot=True)
msg = _make_message(author=bot, mentions=[our_user])
self.assertTrue(self._run_filter(msg, "mentions", our_user))
def test_default_is_none(self):
"""Default behavior (no env var) should be 'none'."""
default = os.getenv("DISCORD_ALLOW_BOTS", "none")
self.assertEqual(default, "none")
def test_case_insensitive(self):
"""Allow_bots value should be case-insensitive."""
bot = _make_author(bot=True)
msg = _make_message(author=bot)
self.assertTrue(self._run_filter(msg, "ALL"))
self.assertTrue(self._run_filter(msg, "All"))
self.assertFalse(self._run_filter(msg, "NONE"))
self.assertFalse(self._run_filter(msg, "None"))
if __name__ == "__main__":
unittest.main()

View File

@@ -393,56 +393,5 @@ class TestStubSchemaDrift(unittest.TestCase):
self.assertIn("mode", src)
class TestHeadTailTruncation(unittest.TestCase):
"""Tests for head+tail truncation of large stdout in execute_code."""
def _run(self, code):
with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
result = execute_code(
code=code,
task_id="test-task",
enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
)
return json.loads(result)
def test_short_output_not_truncated(self):
"""Output under MAX_STDOUT_BYTES should not be truncated."""
result = self._run('print("small output")')
self.assertEqual(result["status"], "success")
self.assertIn("small output", result["output"])
self.assertNotIn("TRUNCATED", result["output"])
def test_large_output_preserves_head_and_tail(self):
"""Output exceeding MAX_STDOUT_BYTES keeps both head and tail."""
code = '''
# Print HEAD marker, then filler, then TAIL marker
print("HEAD_MARKER_START")
for i in range(15000):
print(f"filler_line_{i:06d}_padding_to_fill_buffer")
print("TAIL_MARKER_END")
'''
result = self._run(code)
self.assertEqual(result["status"], "success")
output = result["output"]
# Head should be preserved
self.assertIn("HEAD_MARKER_START", output)
# Tail should be preserved (this is the key improvement)
self.assertIn("TAIL_MARKER_END", output)
# Truncation notice should be present
self.assertIn("TRUNCATED", output)
def test_truncation_notice_format(self):
"""Truncation notice includes character counts."""
code = '''
for i in range(15000):
print(f"padding_line_{i:06d}_xxxxxxxxxxxxxxxxxxxxxxxxxx")
'''
result = self._run(code)
output = result["output"]
if "TRUNCATED" in output:
self.assertIn("chars omitted", output)
self.assertIn("total", output)
if __name__ == "__main__":
unittest.main()

View File

@@ -457,17 +457,11 @@ def execute_code(
# --- Poll loop: watch for exit, timeout, and interrupt ---
deadline = time.monotonic() + timeout
stdout_chunks: list = []
stderr_chunks: list = []
# Background readers to avoid pipe buffer deadlocks.
# For stdout we use a head+tail strategy: keep the first HEAD_BYTES
# and a rolling window of the last TAIL_BYTES so the final print()
# output is never lost. Stderr keeps head-only (errors appear early).
_STDOUT_HEAD_BYTES = int(MAX_STDOUT_BYTES * 0.4) # 40% head
_STDOUT_TAIL_BYTES = MAX_STDOUT_BYTES - _STDOUT_HEAD_BYTES # 60% tail
# Background readers to avoid pipe buffer deadlocks
def _drain(pipe, chunks, max_bytes):
"""Simple head-only drain (used for stderr)."""
total = 0
try:
while True:
@@ -481,48 +475,8 @@ def execute_code(
except (ValueError, OSError):
pass
stdout_total_bytes = [0] # mutable ref for total bytes seen
def _drain_head_tail(pipe, head_chunks, tail_chunks, head_bytes, tail_bytes, total_ref):
"""Drain stdout keeping both head and tail data."""
head_collected = 0
from collections import deque
tail_buf = deque()
tail_collected = 0
try:
while True:
data = pipe.read(4096)
if not data:
break
total_ref[0] += len(data)
# Fill head buffer first
if head_collected < head_bytes:
keep = min(len(data), head_bytes - head_collected)
head_chunks.append(data[:keep])
head_collected += keep
data = data[keep:] # remaining goes to tail
if not data:
continue
# Everything past head goes into rolling tail buffer
tail_buf.append(data)
tail_collected += len(data)
# Evict old tail data to stay within tail_bytes budget
while tail_collected > tail_bytes and tail_buf:
oldest = tail_buf.popleft()
tail_collected -= len(oldest)
except (ValueError, OSError):
pass
# Transfer final tail to output list
tail_chunks.extend(tail_buf)
stdout_head_chunks: list = []
stdout_tail_chunks: list = []
stdout_reader = threading.Thread(
target=_drain_head_tail,
args=(proc.stdout, stdout_head_chunks, stdout_tail_chunks,
_STDOUT_HEAD_BYTES, _STDOUT_TAIL_BYTES, stdout_total_bytes),
daemon=True
target=_drain, args=(proc.stdout, stdout_chunks, MAX_STDOUT_BYTES), daemon=True
)
stderr_reader = threading.Thread(
target=_drain, args=(proc.stderr, stderr_chunks, MAX_STDERR_BYTES), daemon=True
@@ -546,21 +500,12 @@ def execute_code(
stdout_reader.join(timeout=3)
stderr_reader.join(timeout=3)
stdout_head = b"".join(stdout_head_chunks).decode("utf-8", errors="replace")
stdout_tail = b"".join(stdout_tail_chunks).decode("utf-8", errors="replace")
stdout_text = b"".join(stdout_chunks).decode("utf-8", errors="replace")
stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace")
# Assemble stdout with head+tail truncation
total_stdout = stdout_total_bytes[0]
if total_stdout > MAX_STDOUT_BYTES and stdout_tail:
omitted = total_stdout - len(stdout_head) - len(stdout_tail)
truncated_notice = (
f"\n\n... [OUTPUT TRUNCATED - {omitted:,} chars omitted "
f"out of {total_stdout:,} total] ...\n\n"
)
stdout_text = stdout_head + truncated_notice + stdout_tail
else:
stdout_text = stdout_head + stdout_tail
# Truncation notice
if len(stdout_text) >= MAX_STDOUT_BYTES:
stdout_text = stdout_text[:MAX_STDOUT_BYTES] + "\n[output truncated at 50KB]"
exit_code = proc.returncode if proc.returncode is not None else -1
duration = round(time.monotonic() - exec_start, 2)