"""Regression tests for /retry replacement semantics.""" from unittest.mock import AsyncMock, MagicMock, patch import pytest from gateway.config import GatewayConfig from gateway.platforms.base import MessageEvent, MessageType from gateway.run import GatewayRunner from gateway.session import SessionStore @pytest.mark.asyncio async def test_gateway_retry_replaces_last_user_turn_in_transcript(tmp_path): config = GatewayConfig() with patch("gateway.session.SessionStore._ensure_loaded"): store = SessionStore(sessions_dir=tmp_path, config=config) store._db = None store._loaded = True session_id = "retry_session" for msg in [ {"role": "session_meta", "tools": []}, {"role": "user", "content": "first question"}, {"role": "assistant", "content": "first answer"}, {"role": "user", "content": "retry me"}, {"role": "assistant", "content": "old answer"}, ]: store.append_to_transcript(session_id, msg) gw = GatewayRunner.__new__(GatewayRunner) gw.config = config gw.session_store = store session_entry = MagicMock(session_id=session_id) session_entry.last_prompt_tokens = 111 gw.session_store.get_or_create_session = MagicMock(return_value=session_entry) async def fake_handle_message(event): assert event.text == "retry me" transcript_before = store.load_transcript(session_id) assert [m.get("content") for m in transcript_before if m.get("role") == "user"] == [ "first question" ] store.append_to_transcript(session_id, {"role": "user", "content": event.text}) store.append_to_transcript(session_id, {"role": "assistant", "content": "new answer"}) return "new answer" gw._handle_message = AsyncMock(side_effect=fake_handle_message) result = await gw._handle_retry_command( MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock()) ) assert result == "new answer" transcript_after = store.load_transcript(session_id) assert [m.get("content") for m in transcript_after if m.get("role") == "user"] == [ "first question", "retry me", ] assert [m.get("content") for m in transcript_after if m.get("role") == "assistant"] == [ "first answer", "new answer", ] @pytest.mark.asyncio async def test_gateway_retry_replays_original_text_not_retry_command(tmp_path): config = MagicMock() config.sessions_dir = tmp_path config.max_context_messages = 20 gw = GatewayRunner.__new__(GatewayRunner) gw.config = config gw.session_store = MagicMock() session_entry = MagicMock(session_id="test-session") session_entry.last_prompt_tokens = 55 gw.session_store.get_or_create_session.return_value = session_entry gw.session_store.load_transcript.return_value = [ {"role": "user", "content": "real message"}, {"role": "assistant", "content": "answer"}, ] gw.session_store.rewrite_transcript = MagicMock() captured = {} async def fake_handle_message(event): captured["text"] = event.text return "ok" gw._handle_message = AsyncMock(side_effect=fake_handle_message) await gw._handle_retry_command( MessageEvent(text="/retry", message_type=MessageType.TEXT, source=MagicMock()) ) assert captured["text"] == "real message"