From b7f8a17c24b66fcc2b6b36c292b58111535fcd8b Mon Sep 17 00:00:00 2001 From: Farukest Date: Sun, 1 Mar 2026 01:12:58 +0300 Subject: [PATCH] fix(gateway): persist transcript changes in /retry, /undo and fix /reset /retry and /undo set session_entry.conversation_history which does not exist on SessionEntry. The truncated history was never written to disk, so the next message reload picked up the full unmodified transcript. Added SessionStore.rewrite_transcript() that persists changes to both the JSONL file and SQLite database, and updated both commands to use it. /reset accessed self.session_store._sessions which does not exist on SessionStore (the correct attribute is _entries). Also replaced the hand-coded session key with _generate_session_key() to fix WhatsApp DM sessions using the wrong key format. Closes #210 --- gateway/run.py | 9 +++--- gateway/session.py | 31 ++++++++++++++++++ tests/gateway/test_session.py | 60 +++++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 5 deletions(-) diff --git a/gateway/run.py b/gateway/run.py index 4f4a81bad..484d65fec 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -901,13 +901,12 @@ class GatewayRunner: source = event.source # Get existing session key - session_key = f"agent:main:{source.platform.value}:" + \ - (f"dm" if source.chat_type == "dm" else f"{source.chat_type}:{source.chat_id}") + session_key = self.session_store._generate_session_key(source) # Memory flush before reset: load the old transcript and let a # temporary agent save memories before the session is wiped. try: - old_entry = self.session_store._sessions.get(session_key) + old_entry = self.session_store._entries.get(session_key) if old_entry: old_history = self.session_store.load_transcript(old_entry.session_id) if old_history: @@ -1135,7 +1134,7 @@ class GatewayRunner: # Truncate history to before the last user message truncated = history[:last_user_idx] - session_entry.conversation_history = truncated + self.session_store.rewrite_transcript(session_entry.session_id, truncated) # Re-send by creating a fake text event with the old message retry_event = MessageEvent( @@ -1167,7 +1166,7 @@ class GatewayRunner: removed_msg = history[last_user_idx].get("content", "") removed_count = len(history) - last_user_idx - session_entry.conversation_history = history[:last_user_idx] + self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx]) preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\"" diff --git a/gateway/session.py b/gateway/session.py index 65528cdd8..ad4bb3319 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -567,6 +567,37 @@ class SessionStore: with open(transcript_path, "a") as f: f.write(json.dumps(message, ensure_ascii=False) + "\n") + def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None: + """Replace a session's transcript with the given messages.""" + # Rewrite SQLite + if self._db: + try: + self._db._conn.execute( + "DELETE FROM messages WHERE session_id = ?", (session_id,) + ) + self._db._conn.execute( + "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", + (session_id,), + ) + self._db._conn.commit() + for msg in messages: + self._db.append_message( + session_id=session_id, + role=msg.get("role", "unknown"), + content=msg.get("content"), + tool_name=msg.get("tool_name"), + tool_calls=msg.get("tool_calls"), + tool_call_id=msg.get("tool_call_id"), + ) + except Exception as e: + logger.debug("Session DB rewrite failed: %s", e) + + # Rewrite legacy JSONL + transcript_path = self.get_transcript_path(session_id) + with open(transcript_path, "w") as f: + for msg in messages: + f.write(json.dumps(msg, ensure_ascii=False) + "\n") + def load_transcript(self, session_id: str) -> List[Dict[str, Any]]: """Load all messages from a session's transcript.""" # Try SQLite first diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 2f5f4e4a5..979ee6d44 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -1,9 +1,13 @@ """Tests for gateway session management.""" +import json import pytest +from pathlib import Path +from unittest.mock import patch, MagicMock from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig from gateway.session import ( SessionSource, + SessionStore, build_session_context, build_session_context_prompt, ) @@ -199,3 +203,59 @@ class TestBuildSessionContextPrompt: prompt = build_session_context_prompt(ctx) assert "WhatsApp" in prompt or "whatsapp" in prompt.lower() + + +class TestSessionStoreRewriteTranscript: + """Regression: /retry and /undo must persist truncated history to disk.""" + + @pytest.fixture() + def store(self, tmp_path): + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + s = SessionStore(sessions_dir=tmp_path, config=config) + s._db = None # no SQLite for these tests + s._loaded = True + return s + + def test_rewrite_replaces_jsonl(self, store, tmp_path): + session_id = "test_session_1" + # Write initial transcript + for msg in [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": "undo this"}, + {"role": "assistant", "content": "ok"}, + ]: + store.append_to_transcript(session_id, msg) + + # Rewrite with truncated history + store.rewrite_transcript(session_id, [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ]) + + reloaded = store.load_transcript(session_id) + assert len(reloaded) == 2 + assert reloaded[0]["content"] == "hello" + assert reloaded[1]["content"] == "hi" + + def test_rewrite_with_empty_list(self, store): + session_id = "test_session_2" + store.append_to_transcript(session_id, {"role": "user", "content": "hi"}) + + store.rewrite_transcript(session_id, []) + + reloaded = store.load_transcript(session_id) + assert reloaded == [] + + +class TestSessionStoreEntriesAttribute: + """Regression: /reset must access _entries, not _sessions.""" + + def test_entries_attribute_exists(self): + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=Path("/tmp"), config=config) + store._loaded = True + assert hasattr(store, "_entries") + assert not hasattr(store, "_sessions")