The summary message was always injected as 'user' role, which causes consecutive user messages when the last preserved head message is also 'user'. Some APIs reject this (400 error), and it produces malformed training data. Fix: check the role of the last head message and pick the opposite role for the summary — 'user' after assistant/tool, 'assistant' after user. Based on PR #328 by johnh4098. Closes #328.
323 lines
14 KiB
Python
323 lines
14 KiB
Python
"""Tests for agent/context_compressor.py — compression logic, thresholds, truncation fallback."""
|
|
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from agent.context_compressor import ContextCompressor
|
|
|
|
|
|
@pytest.fixture()
|
|
def compressor():
|
|
"""Create a ContextCompressor with mocked dependencies."""
|
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
|
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(None, None)):
|
|
c = ContextCompressor(
|
|
model="test/model",
|
|
threshold_percent=0.85,
|
|
protect_first_n=2,
|
|
protect_last_n=2,
|
|
quiet_mode=True,
|
|
)
|
|
return c
|
|
|
|
|
|
class TestShouldCompress:
|
|
def test_below_threshold(self, compressor):
|
|
compressor.last_prompt_tokens = 50000
|
|
assert compressor.should_compress() is False
|
|
|
|
def test_above_threshold(self, compressor):
|
|
compressor.last_prompt_tokens = 90000
|
|
assert compressor.should_compress() is True
|
|
|
|
def test_exact_threshold(self, compressor):
|
|
compressor.last_prompt_tokens = 85000
|
|
assert compressor.should_compress() is True
|
|
|
|
def test_explicit_tokens(self, compressor):
|
|
assert compressor.should_compress(prompt_tokens=90000) is True
|
|
assert compressor.should_compress(prompt_tokens=50000) is False
|
|
|
|
|
|
class TestShouldCompressPreflight:
|
|
def test_short_messages(self, compressor):
|
|
msgs = [{"role": "user", "content": "short"}]
|
|
assert compressor.should_compress_preflight(msgs) is False
|
|
|
|
def test_long_messages(self, compressor):
|
|
# Each message ~100k chars / 4 = 25k tokens, need >85k threshold
|
|
msgs = [{"role": "user", "content": "x" * 400000}]
|
|
assert compressor.should_compress_preflight(msgs) is True
|
|
|
|
|
|
class TestUpdateFromResponse:
|
|
def test_updates_fields(self, compressor):
|
|
compressor.update_from_response({
|
|
"prompt_tokens": 5000,
|
|
"completion_tokens": 1000,
|
|
"total_tokens": 6000,
|
|
})
|
|
assert compressor.last_prompt_tokens == 5000
|
|
assert compressor.last_completion_tokens == 1000
|
|
assert compressor.last_total_tokens == 6000
|
|
|
|
def test_missing_fields_default_zero(self, compressor):
|
|
compressor.update_from_response({})
|
|
assert compressor.last_prompt_tokens == 0
|
|
|
|
|
|
class TestGetStatus:
|
|
def test_returns_expected_keys(self, compressor):
|
|
status = compressor.get_status()
|
|
assert "last_prompt_tokens" in status
|
|
assert "threshold_tokens" in status
|
|
assert "context_length" in status
|
|
assert "usage_percent" in status
|
|
assert "compression_count" in status
|
|
|
|
def test_usage_percent_calculation(self, compressor):
|
|
compressor.last_prompt_tokens = 50000
|
|
status = compressor.get_status()
|
|
assert status["usage_percent"] == 50.0
|
|
|
|
|
|
class TestCompress:
|
|
def _make_messages(self, n):
|
|
return [{"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"} for i in range(n)]
|
|
|
|
def test_too_few_messages_returns_unchanged(self, compressor):
|
|
msgs = self._make_messages(4) # protect_first=2 + protect_last=2 + 1 = 5 needed
|
|
result = compressor.compress(msgs)
|
|
assert result == msgs
|
|
|
|
def test_truncation_fallback_no_client(self, compressor):
|
|
# compressor has client=None, so should use truncation fallback
|
|
msgs = [{"role": "system", "content": "System prompt"}] + self._make_messages(10)
|
|
result = compressor.compress(msgs)
|
|
assert len(result) < len(msgs)
|
|
# Should keep system message and last N
|
|
assert result[0]["role"] == "system"
|
|
assert compressor.compression_count == 1
|
|
|
|
def test_compression_increments_count(self, compressor):
|
|
msgs = self._make_messages(10)
|
|
compressor.compress(msgs)
|
|
assert compressor.compression_count == 1
|
|
compressor.compress(msgs)
|
|
assert compressor.compression_count == 2
|
|
|
|
def test_protects_first_and_last(self, compressor):
|
|
msgs = self._make_messages(10)
|
|
result = compressor.compress(msgs)
|
|
# First 2 messages should be preserved (protect_first_n=2)
|
|
# Last 2 messages should be preserved (protect_last_n=2)
|
|
assert result[-1]["content"] == msgs[-1]["content"]
|
|
assert result[-2]["content"] == msgs[-2]["content"]
|
|
|
|
|
|
class TestGenerateSummaryNoneContent:
|
|
"""Regression: content=None (from tool-call-only assistant messages) must not crash."""
|
|
|
|
def test_none_content_does_not_crash(self):
|
|
mock_client = MagicMock()
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: tool calls happened"
|
|
mock_client.chat.completions.create.return_value = mock_response
|
|
|
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
|
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
|
c = ContextCompressor(model="test", quiet_mode=True)
|
|
|
|
messages = [
|
|
{"role": "user", "content": "do something"},
|
|
{"role": "assistant", "content": None, "tool_calls": [
|
|
{"function": {"name": "search"}}
|
|
]},
|
|
{"role": "tool", "content": "result"},
|
|
{"role": "assistant", "content": None},
|
|
{"role": "user", "content": "thanks"},
|
|
]
|
|
|
|
summary = c._generate_summary(messages)
|
|
assert isinstance(summary, str)
|
|
assert "CONTEXT SUMMARY" in summary
|
|
|
|
def test_none_content_in_system_message_compress(self):
|
|
"""System message with content=None should not crash during compress."""
|
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
|
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(None, None)):
|
|
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
|
|
|
msgs = [{"role": "system", "content": None}] + [
|
|
{"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"}
|
|
for i in range(10)
|
|
]
|
|
result = c.compress(msgs)
|
|
assert len(result) < len(msgs)
|
|
|
|
|
|
class TestCompressWithClient:
|
|
def test_summarization_path(self):
|
|
mock_client = MagicMock()
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: stuff happened"
|
|
mock_client.chat.completions.create.return_value = mock_response
|
|
|
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
|
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
|
c = ContextCompressor(model="test", quiet_mode=True)
|
|
|
|
msgs = [{"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"} for i in range(10)]
|
|
result = c.compress(msgs)
|
|
|
|
# Should have summary message in the middle
|
|
contents = [m.get("content", "") for m in result]
|
|
assert any("CONTEXT SUMMARY" in c for c in contents)
|
|
assert len(result) < len(msgs)
|
|
|
|
def test_summarization_does_not_split_tool_call_pairs(self):
|
|
mock_client = MagicMock()
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: compressed middle"
|
|
mock_client.chat.completions.create.return_value = mock_response
|
|
|
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
|
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
|
c = ContextCompressor(
|
|
model="test",
|
|
quiet_mode=True,
|
|
protect_first_n=3,
|
|
protect_last_n=4,
|
|
)
|
|
|
|
msgs = [
|
|
{"role": "user", "content": "Could you address the reviewer comments in PR#71"},
|
|
{
|
|
"role": "assistant",
|
|
"content": "",
|
|
"tool_calls": [
|
|
{"id": "call_a", "type": "function", "function": {"name": "skill_view", "arguments": "{}"}},
|
|
{"id": "call_b", "type": "function", "function": {"name": "skill_view", "arguments": "{}"}},
|
|
],
|
|
},
|
|
{"role": "tool", "tool_call_id": "call_a", "content": "output a"},
|
|
{"role": "tool", "tool_call_id": "call_b", "content": "output b"},
|
|
{"role": "user", "content": "later 1"},
|
|
{"role": "assistant", "content": "later 2"},
|
|
{"role": "tool", "tool_call_id": "call_x", "content": "later output"},
|
|
{"role": "assistant", "content": "later 3"},
|
|
{"role": "user", "content": "later 4"},
|
|
]
|
|
|
|
result = c.compress(msgs)
|
|
|
|
answered_ids = {
|
|
msg.get("tool_call_id")
|
|
for msg in result
|
|
if msg.get("role") == "tool" and msg.get("tool_call_id")
|
|
}
|
|
for msg in result:
|
|
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
for tc in msg["tool_calls"]:
|
|
assert tc["id"] in answered_ids
|
|
|
|
def test_summary_role_avoids_consecutive_user_messages(self):
|
|
"""Summary role should alternate with the last head message to avoid consecutive same-role messages."""
|
|
mock_client = MagicMock()
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: stuff happened"
|
|
mock_client.chat.completions.create.return_value = mock_response
|
|
|
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
|
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
|
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
|
|
|
# Last head message (index 1) is "assistant" → summary should be "user"
|
|
msgs = [
|
|
{"role": "user", "content": "msg 0"},
|
|
{"role": "assistant", "content": "msg 1"},
|
|
{"role": "user", "content": "msg 2"},
|
|
{"role": "assistant", "content": "msg 3"},
|
|
{"role": "user", "content": "msg 4"},
|
|
{"role": "assistant", "content": "msg 5"},
|
|
]
|
|
result = c.compress(msgs)
|
|
summary_msg = [m for m in result if "CONTEXT SUMMARY" in (m.get("content") or "")]
|
|
assert len(summary_msg) == 1
|
|
assert summary_msg[0]["role"] == "user"
|
|
|
|
def test_summary_role_avoids_consecutive_user_when_head_ends_with_user(self):
|
|
"""When last head message is 'user', summary must be 'assistant' to avoid two consecutive user messages."""
|
|
mock_client = MagicMock()
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: stuff happened"
|
|
mock_client.chat.completions.create.return_value = mock_response
|
|
|
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
|
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
|
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=3, protect_last_n=2)
|
|
|
|
# Last head message (index 2) is "user" → summary should be "assistant"
|
|
msgs = [
|
|
{"role": "system", "content": "system prompt"},
|
|
{"role": "user", "content": "msg 1"},
|
|
{"role": "user", "content": "msg 2"}, # last head — user
|
|
{"role": "assistant", "content": "msg 3"},
|
|
{"role": "user", "content": "msg 4"},
|
|
{"role": "assistant", "content": "msg 5"},
|
|
{"role": "user", "content": "msg 6"},
|
|
{"role": "assistant", "content": "msg 7"},
|
|
]
|
|
result = c.compress(msgs)
|
|
summary_msg = [m for m in result if "CONTEXT SUMMARY" in (m.get("content") or "")]
|
|
assert len(summary_msg) == 1
|
|
assert summary_msg[0]["role"] == "assistant"
|
|
|
|
def test_summarization_does_not_start_tail_with_tool_outputs(self):
|
|
mock_client = MagicMock()
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: compressed middle"
|
|
mock_client.chat.completions.create.return_value = mock_response
|
|
|
|
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
|
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
|
c = ContextCompressor(
|
|
model="test",
|
|
quiet_mode=True,
|
|
protect_first_n=2,
|
|
protect_last_n=3,
|
|
)
|
|
|
|
msgs = [
|
|
{"role": "user", "content": "earlier 1"},
|
|
{"role": "assistant", "content": "earlier 2"},
|
|
{"role": "user", "content": "earlier 3"},
|
|
{
|
|
"role": "assistant",
|
|
"content": "",
|
|
"tool_calls": [
|
|
{"id": "call_c", "type": "function", "function": {"name": "search_files", "arguments": "{}"}},
|
|
],
|
|
},
|
|
{"role": "tool", "tool_call_id": "call_c", "content": "output c"},
|
|
{"role": "user", "content": "latest user"},
|
|
]
|
|
|
|
result = c.compress(msgs)
|
|
|
|
called_ids = {
|
|
tc["id"]
|
|
for msg in result
|
|
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
|
for tc in msg["tool_calls"]
|
|
}
|
|
for msg in result:
|
|
if msg.get("role") == "tool" and msg.get("tool_call_id"):
|
|
assert msg["tool_call_id"] in called_ids
|