diff --git a/gateway/run.py b/gateway/run.py index 6f043d448..b8df4deca 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -990,7 +990,7 @@ class GatewayRunner: # Memory flush before reset: load the old transcript and let a # temporary agent save memories before the session is wiped. try: - old_entry = self.session_store._sessions.get(session_key) + old_entry = self.session_store._entries.get(session_key) if old_entry: old_history = self.session_store.load_transcript(old_entry.session_id) if old_history: @@ -1222,9 +1222,9 @@ class GatewayRunner: if not last_user_msg: return "No previous message to retry." - # Truncate history to before the last user message + # Truncate history to before the last user message and persist truncated = history[:last_user_idx] - session_entry.conversation_history = truncated + self.session_store.rewrite_transcript(session_entry.session_id, truncated) # Re-send by creating a fake text event with the old message retry_event = MessageEvent( @@ -1256,7 +1256,7 @@ class GatewayRunner: removed_msg = history[last_user_idx].get("content", "") removed_count = len(history) - last_user_idx - session_entry.conversation_history = history[:last_user_idx] + self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx]) preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\"" @@ -1330,7 +1330,7 @@ class GatewayRunner: lambda: tmp_agent._compress_context(msgs, "", approx_tokens=approx_tokens), ) - session_entry.conversation_history = compressed + self.session_store.rewrite_transcript(session_entry.session_id, compressed) new_count = len(compressed) new_tokens = estimate_messages_tokens_rough(compressed) diff --git a/gateway/session.py b/gateway/session.py index 65528cdd8..c93aba24a 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -567,6 +567,34 @@ class SessionStore: with open(transcript_path, "a") as f: f.write(json.dumps(message, ensure_ascii=False) + "\n") + def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None: + """Replace the entire transcript for a session with new messages. + + Used by /retry, /undo, and /compress to persist modified conversation history. + Rewrites both SQLite and legacy JSONL storage. + """ + # SQLite: clear old messages and re-insert + if self._db: + try: + self._db.clear_messages(session_id) + for msg in messages: + self._db.append_message( + session_id=session_id, + role=msg.get("role", "unknown"), + content=msg.get("content"), + tool_name=msg.get("tool_name"), + tool_calls=msg.get("tool_calls"), + tool_call_id=msg.get("tool_call_id"), + ) + except Exception as e: + logger.debug("Failed to rewrite transcript in DB: %s", e) + + # JSONL: overwrite the file + transcript_path = self.get_transcript_path(session_id) + with open(transcript_path, "w") as f: + for msg in messages: + f.write(json.dumps(msg, ensure_ascii=False) + "\n") + def load_transcript(self, session_id: str) -> List[Dict[str, Any]]: """Load all messages from a session's transcript.""" # Try SQLite first diff --git a/hermes_state.py b/hermes_state.py index ebb3f1dd7..1d1f951c0 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -476,6 +476,17 @@ class SessionDB: results.append({**session, "messages": messages}) return results + def clear_messages(self, session_id: str) -> None: + """Delete all messages for a session and reset its counters.""" + self._conn.execute( + "DELETE FROM messages WHERE session_id = ?", (session_id,) + ) + self._conn.execute( + "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", + (session_id,), + ) + self._conn.commit() + def delete_session(self, session_id: str) -> bool: """Delete a session and all its messages. Returns True if found.""" cursor = self._conn.execute(