fix(streaming): filter <think> blocks from gateway stream consumer
Models like MiniMax emit inline <think>...</think> reasoning blocks in their content field. The CLI already suppresses these via a state machine in _stream_delta, but the gateway's GatewayStreamConsumer had no equivalent filtering, so raw think blocks were streamed directly to Discord/Telegram/Slack.

The fix adds a _filter_and_accumulate() method that mirrors the CLI's approach: a state machine tracks whether we're inside a reasoning block and silently discards its content. It includes the same block-boundary check (the tag must appear at line start or after a whitespace-only prefix) to avoid false positives when models mention <think> in prose.

Handles all tag variants: <think>, <thinking>, <THINKING>, <thought>, <reasoning>, <REASONING_SCRATCHPAD>.

Also handles edge cases:
- Tags split across streaming deltas (partial tag buffering)
- Unclosed blocks (content suppressed until stream ends)
- Multiple consecutive blocks
- _flush_think_buffer on stream end for held-back partial tags

Adds 22 unit tests + 1 integration test covering all scenarios.
This commit is contained in:
@@ -64,6 +64,18 @@ class GatewayStreamConsumer:
|
||||
# progressive edits for the remainder of the stream.
|
||||
_MAX_FLOOD_STRIKES = 3
|
||||
|
||||
# Inline reasoning/thinking markers that some models place directly in
# their content stream.
# Keep this list aligned with cli.py _OPEN_TAGS/_CLOSE_TAGS and with the
# tag variants handled by run_agent.py _strip_think_blocks().
_OPEN_THINK_TAGS = (
    "<REASONING_SCRATCHPAD>",
    "<think>",
    "<reasoning>",
    "<THINKING>",
    "<thinking>",
    "<thought>",
)
# Each closing tag is its opening tag with "/" inserted after "<",
# listed in the same order as the opening tags.
_CLOSE_THINK_TAGS = tuple("</" + tag[1:] for tag in _OPEN_THINK_TAGS)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter: Any,
|
||||
@@ -88,6 +100,10 @@ class GatewayStreamConsumer:
|
||||
self._current_edit_interval = self.cfg.edit_interval # Adaptive backoff
|
||||
self._final_response_sent = False
|
||||
|
||||
# Think-block filter state (mirrors CLI's _stream_delta tag suppression)
|
||||
self._in_think_block = False
|
||||
self._think_buffer = ""
|
||||
|
||||
@property
|
||||
def already_sent(self) -> bool:
|
||||
"""True if at least one message was sent or edited during the run."""
|
||||
@@ -132,6 +148,112 @@ class GatewayStreamConsumer:
|
||||
"""Signal that the stream is complete."""
|
||||
self._queue.put(_DONE)
|
||||
|
||||
# ── Think-block filtering ────────────────────────────────────────
|
||||
# Models like MiniMax emit inline <think>...</think> blocks in their
|
||||
# content. The CLI's _stream_delta suppresses these via a state
|
||||
# machine; we do the same here so gateway users never see raw
|
||||
# reasoning tags. The agent also strips them from the final
|
||||
# response (run_agent.py _strip_think_blocks), but the stream
|
||||
# consumer sends intermediate edits before that stripping happens.
|
||||
|
||||
def _filter_and_accumulate(self, text: str) -> None:
    """Add a text delta to the accumulated buffer, suppressing think blocks.

    A small state machine tracks whether the stream is currently inside
    a reasoning/thinking block (``_in_think_block``). Text inside such a
    block is discarded; everything else is appended to ``_accumulated``.
    A tail that could be the start of a tag split across deltas is held
    back in ``_think_buffer`` and re-examined on the next call.

    An opening tag only starts a block when it sits at a block boundary:
    everything before it on its line (including text already emitted by
    earlier deltas) is whitespace. This keeps prose mentions such as
    "the <think> tag is used for…" intact.

    Args:
        text: The next content delta from the stream.
    """

    def at_block_boundary(preceding: str) -> bool:
        # ``preceding`` is the not-yet-emitted text in front of the tag.
        last_nl = preceding.rfind("\n")
        if last_nl != -1:
            # The tag's line starts inside the pending buffer.
            return not preceding[last_nl + 1:].strip()
        if preceding.strip():
            return False
        # The tag's line started in text emitted by an earlier delta.
        # Inspect that whole line rather than only its last character,
        # so an already-emitted whitespace-only indent still counts.
        emitted = self._accumulated
        return not emitted[emitted.rfind("\n") + 1:].strip()

    buf = self._think_buffer + text
    self._think_buffer = ""
    max_close_len = max(len(tag) for tag in self._CLOSE_THINK_TAGS)

    while buf:
        if self._in_think_block:
            # Look for the earliest closing tag.
            best_idx = -1
            best_len = 0
            for tag in self._CLOSE_THINK_TAGS:
                idx = buf.find(tag)
                if idx != -1 and (best_idx == -1 or idx < best_idx):
                    best_idx = idx
                    best_len = len(tag)

            if not best_len:
                # No closing tag yet: keep only a tail long enough to
                # contain a partial closing tag, discard the rest.
                self._think_buffer = buf[-max_close_len:]
                return

            # Closing tag found: drop the block and keep scanning.
            self._in_think_block = False
            buf = buf[best_idx + best_len:]
            continue

        # Outside a block: find the earliest opening tag that sits at a
        # block boundary.
        best_idx = -1
        best_len = 0
        for tag in self._OPEN_THINK_TAGS:
            idx = buf.find(tag)
            while idx != -1:
                if at_block_boundary(buf[:idx]):
                    # Later occurrences of this tag cannot be earlier.
                    if best_idx == -1 or idx < best_idx:
                        best_idx = idx
                        best_len = len(tag)
                    break
                idx = buf.find(tag, idx + 1)

        if best_len:
            # Emit the text before the tag, then enter the think block.
            self._accumulated += buf[:best_idx]
            self._in_think_block = True
            buf = buf[best_idx + best_len:]
            continue

        # No opening tag: hold back the longest tail that is a proper
        # prefix of some opening tag, in case the tag is split across
        # deltas.
        held_back = 0
        for tag in self._OPEN_THINK_TAGS:
            for i in range(1, len(tag)):
                if i > held_back and buf.endswith(tag[:i]):
                    held_back = i
        cut = len(buf) - held_back
        self._accumulated += buf[:cut]
        self._think_buffer = buf[cut:]
        return
|
||||
|
||||
def _flush_think_buffer(self) -> None:
    """Release a held-back partial tag into the accumulated text.

    Intended for stream end: text that was withheld because it looked
    like the start of an opening tag is appended to ``_accumulated`` and
    the hold-back buffer is emptied. Nothing happens when the buffer is
    empty or while the filter is still inside a think block.
    """
    pending = self._think_buffer
    if self._in_think_block or not pending:
        return
    self._accumulated += pending
    self._think_buffer = ""
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Async task that drains the queue and edits the platform message."""
|
||||
# Platform message length limit — leave room for cursor + formatting
|
||||
@@ -156,10 +278,16 @@ class GatewayStreamConsumer:
|
||||
if isinstance(item, tuple) and len(item) == 2 and item[0] is _COMMENTARY:
|
||||
commentary_text = item[1]
|
||||
break
|
||||
self._accumulated += item
|
||||
self._filter_and_accumulate(item)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Flush any held-back partial-tag buffer on stream end
|
||||
# so trailing text that was waiting for a potential open
|
||||
# tag is not lost.
|
||||
if got_done:
|
||||
self._flush_think_buffer()
|
||||
|
||||
# Decide whether to flush an edit
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_edit_time
|
||||
|
||||
@@ -680,3 +680,202 @@ class TestCancelledConsumerSetsFlags:
|
||||
# Without a successful send, final_response_sent should stay False
|
||||
# so the normal gateway send path can deliver the response.
|
||||
assert consumer.final_response_sent is False
|
||||
|
||||
|
||||
# ── Think-block filtering unit tests ─────────────────────────────────────
|
||||
|
||||
|
||||
def _make_consumer() -> GatewayStreamConsumer:
    """Build a consumer around a mock adapter for exercising the filter alone."""
    return GatewayStreamConsumer(MagicMock(), "chat_test")
|
||||
|
||||
|
||||
class TestFilterAndAccumulate:
    """Unit tests for _filter_and_accumulate think-block suppression."""

    def test_plain_text_passes_through(self):
        c = _make_consumer()
        c._filter_and_accumulate("Hello world")
        assert c._accumulated == "Hello world"

    def test_complete_think_block_stripped(self):
        c = _make_consumer()
        c._filter_and_accumulate("<think>internal reasoning</think>Answer here")
        assert c._accumulated == "Answer here"

    def test_think_block_in_middle(self):
        c = _make_consumer()
        c._filter_and_accumulate("Prefix\n<think>reasoning</think>\nSuffix")
        assert c._accumulated == "Prefix\n\nSuffix"

    def test_think_block_split_across_deltas(self):
        c = _make_consumer()
        c._filter_and_accumulate("<think>start of")
        c._filter_and_accumulate(" reasoning</think>visible text")
        assert c._accumulated == "visible text"

    def test_opening_tag_split_across_deltas(self):
        c = _make_consumer()
        c._filter_and_accumulate("<thi")
        # Partial tag is held back rather than emitted.
        assert c._accumulated == ""
        assert c._think_buffer == "<thi"
        c._filter_and_accumulate("nk>hidden</think>shown")
        assert c._accumulated == "shown"

    def test_closing_tag_split_across_deltas(self):
        c = _make_consumer()
        c._filter_and_accumulate("<think>hidden</thi")
        assert c._accumulated == ""
        c._filter_and_accumulate("nk>shown")
        assert c._accumulated == "shown"

    def test_multiple_think_blocks(self):
        c = _make_consumer()
        # Consecutive blocks with no text between them: both stripped.
        c._filter_and_accumulate(
            "<think>block1</think><think>block2</think>visible"
        )
        assert c._accumulated == "visible"

    def test_multiple_think_blocks_with_text_between(self):
        """Think tag after non-whitespace is NOT a boundary (prose safety)."""
        c = _make_consumer()
        c._filter_and_accumulate(
            "<think>block1</think>A<think>block2</think>B"
        )
        # The second <think> follows 'A' on the same line, so it is kept
        # verbatim as prose. Pin the exact output so truncation or a
        # partial leak cannot slip through.
        assert c._accumulated == "A<think>block2</think>B"

    def test_thinking_tag_variant(self):
        c = _make_consumer()
        c._filter_and_accumulate("<thinking>deep thought</thinking>Result")
        assert c._accumulated == "Result"

    def test_thought_tag_variant(self):
        c = _make_consumer()
        c._filter_and_accumulate("<thought>Gemma style</thought>Output")
        assert c._accumulated == "Output"

    def test_reasoning_scratchpad_variant(self):
        c = _make_consumer()
        c._filter_and_accumulate(
            "<REASONING_SCRATCHPAD>long plan</REASONING_SCRATCHPAD>Done"
        )
        assert c._accumulated == "Done"

    def test_case_insensitive_THINKING(self):
        """The upper-case <THINKING> variant is recognised.

        Matching is by explicit tag list, not case-folding; this covers
        the upper-case entry in that list.
        """
        c = _make_consumer()
        c._filter_and_accumulate("<THINKING>caps</THINKING>answer")
        assert c._accumulated == "answer"

    def test_prose_mention_not_stripped(self):
        """<think> mentioned mid-line in prose should NOT trigger filtering."""
        c = _make_consumer()
        c._filter_and_accumulate("The <think> tag is used for reasoning")
        assert c._accumulated == "The <think> tag is used for reasoning"

    def test_prose_mention_after_text(self):
        """<think> after non-whitespace on same line is not a block boundary."""
        c = _make_consumer()
        c._filter_and_accumulate("Try using <think>some content</think> tags")
        assert c._accumulated == "Try using <think>some content</think> tags"

    def test_think_at_line_start_is_stripped(self):
        """<think> at start of a new line IS a block boundary."""
        c = _make_consumer()
        c._filter_and_accumulate("Previous line\n<think>reasoning</think>Next")
        assert c._accumulated == "Previous line\nNext"

    def test_think_with_only_whitespace_before(self):
        """<think> preceded by only whitespace on its line is a boundary."""
        c = _make_consumer()
        c._filter_and_accumulate(" <think>hidden</think>visible")
        # Leading whitespace before the tag is emitted, then the block is stripped.
        assert c._accumulated == " visible"

    def test_flush_think_buffer_on_non_tag(self):
        """Partial tag that turns out not to be a tag is flushed."""
        c = _make_consumer()
        c._filter_and_accumulate("<thi")
        assert c._accumulated == ""
        # Flush explicitly (simulates stream end).
        c._flush_think_buffer()
        assert c._accumulated == "<thi"
        assert c._think_buffer == ""

    def test_flush_think_buffer_when_inside_block(self):
        """Flush while inside a think block does NOT emit buffered content."""
        c = _make_consumer()
        c._filter_and_accumulate("<think>still thinking")
        c._flush_think_buffer()
        assert c._accumulated == ""

    def test_unclosed_think_block_suppresses(self):
        """An unclosed <think> suppresses all subsequent content."""
        c = _make_consumer()
        c._filter_and_accumulate("Before\n<think>reasoning that never ends...")
        assert c._accumulated == "Before\n"

    def test_multiline_think_block(self):
        c = _make_consumer()
        c._filter_and_accumulate(
            "<think>\nLine 1\nLine 2\nLine 3\n</think>Final answer"
        )
        assert c._accumulated == "Final answer"

    def test_segment_reset_preserves_think_state(self):
        """_reset_segment_state should NOT clear think-block filter state."""
        c = _make_consumer()
        c._filter_and_accumulate("<think>start")
        c._reset_segment_state()
        # Still inside the think block: following text stays suppressed
        # until the closing tag arrives.
        c._filter_and_accumulate("still hidden</think>visible")
        assert c._accumulated == "visible"
|
||||
|
||||
|
||||
class TestFilterAndAccumulateIntegration:
    """Integration: verify think blocks don't leak through the full run() path."""

    @pytest.mark.asyncio
    async def test_think_block_not_sent_to_platform(self):
        """Think blocks should be filtered before any platform send/edit."""
        adapter = MagicMock()
        adapter.send = AsyncMock(
            return_value=SimpleNamespace(success=True, message_id="msg_1")
        )
        adapter.edit_message = AsyncMock(
            return_value=SimpleNamespace(success=True)
        )
        adapter.MAX_MESSAGE_LENGTH = 4096

        consumer = GatewayStreamConsumer(
            adapter,
            "chat_test",
            StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
        )

        # Simulate streaming: a think block followed by visible text.
        consumer.on_delta("<think>deep reasoning here</think>")
        consumer.on_delta("The answer is 42.")
        consumer.finish()

        task = asyncio.create_task(consumer.run())
        try:
            await asyncio.sleep(0.15)

            all_calls = list(adapter.send.call_args_list) + list(
                adapter.edit_message.call_args_list
            )
            # Inspect every string argument, positional or keyword. The
            # adapter signature is not pinned here, so checking a single
            # guessed position could silently look at the chat id only.
            # NOTE(review): consider also asserting that at least one
            # send/edit happened so this cannot pass vacuously.
            for call in all_calls:
                args, kwargs = call
                for value in (*args, *kwargs.values()):
                    if not isinstance(value, str):
                        continue
                    assert "<think>" not in value, f"Think tag leaked: {value}"
                    assert "deep reasoning" not in value
        finally:
            # Always stop the consumer task, even when an assertion fails.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
|
||||
|
||||
Reference in New Issue
Block a user