Merge pull request #176 from Bartok9/fix-tts-voice-accumulation
fix(gateway): prevent TTS voice messages from accumulating across turns
This commit is contained in:
@@ -1739,6 +1739,9 @@ class GatewayRunner:
|
||||
content = f"[Delivered from {mirror_src}] {content}"
|
||||
agent_history.append({"role": role, "content": content})
|
||||
|
||||
# Track history length to only scan NEW messages for MEDIA tags
|
||||
history_len = len(agent_history)
|
||||
|
||||
result = agent.run_conversation(message, conversation_history=agent_history)
|
||||
result_holder[0] = result
|
||||
|
||||
@@ -1759,10 +1762,17 @@ class GatewayRunner:
|
||||
# doesn't include them. We collect unique tags from tool results and
|
||||
# append any that aren't already present in the final response, so the
|
||||
# adapter's extract_media() can find and deliver the files exactly once.
|
||||
#
|
||||
# IMPORTANT: Only scan messages from the CURRENT turn (after history_len),
|
||||
# not the full history. This prevents TTS voice messages from earlier
|
||||
# turns being re-attached to every subsequent reply. (Fixes #160)
|
||||
if "MEDIA:" not in final_response:
|
||||
media_tags = []
|
||||
has_voice_directive = False
|
||||
for msg in result.get("messages", []):
|
||||
all_messages = result.get("messages", [])
|
||||
# Only process new messages from this turn
|
||||
new_messages = all_messages[history_len:] if len(all_messages) > history_len else []
|
||||
for msg in new_messages:
|
||||
if msg.get("role") == "tool" or msg.get("role") == "function":
|
||||
content = msg.get("content", "")
|
||||
if "MEDIA:" in content:
|
||||
|
||||
184
tests/gateway/test_media_extraction.py
Normal file
184
tests/gateway/test_media_extraction.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Tests for MEDIA tag extraction from tool results.
|
||||
|
||||
Verifies that MEDIA tags (e.g., from TTS tool) are only extracted from
|
||||
messages in the CURRENT turn, not from the full conversation history.
|
||||
This prevents voice messages from accumulating and being sent multiple
|
||||
times per reply. (Regression test for #160)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import re
|
||||
|
||||
|
||||
def extract_media_tags_fixed(result_messages, history_len):
|
||||
"""
|
||||
Extract MEDIA tags from tool results, but ONLY from new messages
|
||||
(those added after history_len). This is the fixed behavior.
|
||||
|
||||
Args:
|
||||
result_messages: Full list of messages including history + new
|
||||
history_len: Length of history before this turn
|
||||
|
||||
Returns:
|
||||
Tuple of (media_tags list, has_voice_directive bool)
|
||||
"""
|
||||
media_tags = []
|
||||
has_voice_directive = False
|
||||
|
||||
# Only process new messages from this turn
|
||||
new_messages = result_messages[history_len:] if len(result_messages) > history_len else []
|
||||
|
||||
for msg in new_messages:
|
||||
if msg.get("role") == "tool" or msg.get("role") == "function":
|
||||
content = msg.get("content", "")
|
||||
if "MEDIA:" in content:
|
||||
for match in re.finditer(r'MEDIA:(\S+)', content):
|
||||
path = match.group(1).strip().rstrip('",}')
|
||||
if path:
|
||||
media_tags.append(f"MEDIA:{path}")
|
||||
if "[[audio_as_voice]]" in content:
|
||||
has_voice_directive = True
|
||||
|
||||
return media_tags, has_voice_directive
|
||||
|
||||
|
||||
def extract_media_tags_broken(result_messages):
|
||||
"""
|
||||
The BROKEN behavior: extract MEDIA tags from ALL messages including history.
|
||||
This causes TTS voice messages to accumulate and be re-sent on every reply.
|
||||
"""
|
||||
media_tags = []
|
||||
has_voice_directive = False
|
||||
|
||||
for msg in result_messages:
|
||||
if msg.get("role") == "tool" or msg.get("role") == "function":
|
||||
content = msg.get("content", "")
|
||||
if "MEDIA:" in content:
|
||||
for match in re.finditer(r'MEDIA:(\S+)', content):
|
||||
path = match.group(1).strip().rstrip('",}')
|
||||
if path:
|
||||
media_tags.append(f"MEDIA:{path}")
|
||||
if "[[audio_as_voice]]" in content:
|
||||
has_voice_directive = True
|
||||
|
||||
return media_tags, has_voice_directive
|
||||
|
||||
|
||||
class TestMediaExtraction:
|
||||
"""Tests for MEDIA tag extraction from tool results."""
|
||||
|
||||
def test_media_tags_not_extracted_from_history(self):
|
||||
"""MEDIA tags from previous turns should NOT be extracted again."""
|
||||
# Simulate conversation history with a TTS call from a previous turn
|
||||
history = [
|
||||
{"role": "user", "content": "Say hello as audio"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [{"id": "1", "function": {"name": "text_to_speech"}}]},
|
||||
{"role": "tool", "tool_call_id": "1", "content": '{"success": true, "media_tag": "[[audio_as_voice]]\\nMEDIA:/path/to/audio1.ogg"}'},
|
||||
{"role": "assistant", "content": "I've said hello for you!"},
|
||||
]
|
||||
|
||||
# New turn: user asks a simple question
|
||||
new_messages = [
|
||||
{"role": "user", "content": "What time is it?"},
|
||||
{"role": "assistant", "content": "It's 3:30 AM."},
|
||||
]
|
||||
|
||||
all_messages = history + new_messages
|
||||
history_len = len(history)
|
||||
|
||||
# Fixed behavior: should extract NO media tags (none in new messages)
|
||||
tags, voice_directive = extract_media_tags_fixed(all_messages, history_len)
|
||||
assert tags == [], "Fixed extraction should not find tags in history"
|
||||
assert voice_directive is False
|
||||
|
||||
# Broken behavior: would incorrectly extract the old media tag
|
||||
broken_tags, broken_voice = extract_media_tags_broken(all_messages)
|
||||
assert len(broken_tags) == 1, "Broken extraction finds tags in history"
|
||||
assert "audio1.ogg" in broken_tags[0]
|
||||
|
||||
def test_media_tags_extracted_from_current_turn(self):
|
||||
"""MEDIA tags from the current turn SHOULD be extracted."""
|
||||
# History without TTS
|
||||
history = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
# New turn with TTS call
|
||||
new_messages = [
|
||||
{"role": "user", "content": "Say goodbye as audio"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [{"id": "2", "function": {"name": "text_to_speech"}}]},
|
||||
{"role": "tool", "tool_call_id": "2", "content": '{"success": true, "media_tag": "[[audio_as_voice]]\\nMEDIA:/path/to/audio2.ogg"}'},
|
||||
{"role": "assistant", "content": "I've said goodbye!"},
|
||||
]
|
||||
|
||||
all_messages = history + new_messages
|
||||
history_len = len(history)
|
||||
|
||||
# Fixed behavior: should extract the new media tag
|
||||
tags, voice_directive = extract_media_tags_fixed(all_messages, history_len)
|
||||
assert len(tags) == 1, "Should extract media tag from current turn"
|
||||
assert "audio2.ogg" in tags[0]
|
||||
assert voice_directive is True
|
||||
|
||||
def test_multiple_tts_calls_in_history_not_accumulated(self):
|
||||
"""Multiple TTS calls in history should NOT accumulate in new responses."""
|
||||
# History with multiple TTS calls
|
||||
history = [
|
||||
{"role": "user", "content": "Say hello"},
|
||||
{"role": "tool", "tool_call_id": "1", "content": 'MEDIA:/audio/hello.ogg'},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
{"role": "user", "content": "Say goodbye"},
|
||||
{"role": "tool", "tool_call_id": "2", "content": 'MEDIA:/audio/goodbye.ogg'},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
{"role": "user", "content": "Say thanks"},
|
||||
{"role": "tool", "tool_call_id": "3", "content": 'MEDIA:/audio/thanks.ogg'},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
]
|
||||
|
||||
# New turn: no TTS
|
||||
new_messages = [
|
||||
{"role": "user", "content": "What time is it?"},
|
||||
{"role": "assistant", "content": "3 PM"},
|
||||
]
|
||||
|
||||
all_messages = history + new_messages
|
||||
history_len = len(history)
|
||||
|
||||
# Fixed: no tags
|
||||
tags, _ = extract_media_tags_fixed(all_messages, history_len)
|
||||
assert tags == [], "Should not accumulate tags from history"
|
||||
|
||||
# Broken: would have 3 tags (all the old ones)
|
||||
broken_tags, _ = extract_media_tags_broken(all_messages)
|
||||
assert len(broken_tags) == 3, "Broken version accumulates all history tags"
|
||||
|
||||
def test_deduplication_within_current_turn(self):
|
||||
"""Multiple MEDIA tags in current turn should be deduplicated."""
|
||||
history = []
|
||||
|
||||
# Current turn with multiple tool calls producing same media
|
||||
new_messages = [
|
||||
{"role": "user", "content": "Multiple TTS"},
|
||||
{"role": "tool", "tool_call_id": "1", "content": 'MEDIA:/audio/same.ogg'},
|
||||
{"role": "tool", "tool_call_id": "2", "content": 'MEDIA:/audio/same.ogg'}, # duplicate
|
||||
{"role": "tool", "tool_call_id": "3", "content": 'MEDIA:/audio/different.ogg'},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
]
|
||||
|
||||
all_messages = history + new_messages
|
||||
|
||||
tags, _ = extract_media_tags_fixed(all_messages, 0)
|
||||
# Even though same.ogg appears twice, deduplication happens after extraction
|
||||
# The extraction itself should get both, then caller deduplicates
|
||||
assert len(tags) == 3 # Raw extraction gets all
|
||||
|
||||
# Deduplication as done in the actual code:
|
||||
seen = set()
|
||||
unique = [t for t in tags if t not in seen and not seen.add(t)]
|
||||
assert len(unique) == 2 # After dedup: same.ogg and different.ogg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user