diff --git a/gateway/run.py b/gateway/run.py index b8df4deca..8154b76f1 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -984,8 +984,7 @@ class GatewayRunner: source = event.source # Get existing session key - session_key = f"agent:main:{source.platform.value}:" + \ - (f"dm" if source.chat_type == "dm" else f"{source.chat_type}:{source.chat_id}") + session_key = self.session_store._generate_session_key(source) # Memory flush before reset: load the old transcript and let a # temporary agent save memories before the session is wiped. diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 2f5f4e4a5..979ee6d44 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -1,9 +1,13 @@ """Tests for gateway session management.""" +import json import pytest +from pathlib import Path +from unittest.mock import patch, MagicMock from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig from gateway.session import ( SessionSource, + SessionStore, build_session_context, build_session_context_prompt, ) @@ -199,3 +203,59 @@ class TestBuildSessionContextPrompt: prompt = build_session_context_prompt(ctx) assert "WhatsApp" in prompt or "whatsapp" in prompt.lower() + + +class TestSessionStoreRewriteTranscript: + """Regression: /retry and /undo must persist truncated history to disk.""" + + @pytest.fixture() + def store(self, tmp_path): + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + s = SessionStore(sessions_dir=tmp_path, config=config) + s._db = None # no SQLite for these tests + s._loaded = True + return s + + def test_rewrite_replaces_jsonl(self, store, tmp_path): + session_id = "test_session_1" + # Write initial transcript + for msg in [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": "undo this"}, + {"role": "assistant", "content": "ok"}, + ]: + store.append_to_transcript(session_id, msg) + + # Rewrite with truncated history + store.rewrite_transcript(session_id, [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ]) + + reloaded = store.load_transcript(session_id) + assert len(reloaded) == 2 + assert reloaded[0]["content"] == "hello" + assert reloaded[1]["content"] == "hi" + + def test_rewrite_with_empty_list(self, store): + session_id = "test_session_2" + store.append_to_transcript(session_id, {"role": "user", "content": "hi"}) + + store.rewrite_transcript(session_id, []) + + reloaded = store.load_transcript(session_id) + assert reloaded == [] + + +class TestSessionStoreEntriesAttribute: + """Regression: /reset must access _entries, not _sessions.""" + + def test_entries_attribute_exists(self): + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=Path("/tmp"), config=config) + store._loaded = True + assert hasattr(store, "_entries") + assert not hasattr(store, "_sessions")