#!/usr/bin/env python3 """ SQLite State Store for Hermes Agent. Provides persistent session storage with FTS5 full-text search, replacing the per-session JSONL file approach. Stores session metadata, full message history, and model configuration for CLI and gateway sessions. Key design decisions: - WAL mode for concurrent readers + one writer (gateway multi-platform) - FTS5 virtual table for fast text search across all session messages - Compression-triggered session splitting via parent_session_id chains - Batch runner and RL trajectories are NOT stored here (separate systems) - Session source tagging ('cli', 'telegram', 'discord', etc.) for filtering - Connection pooling for concurrent reads with dedicated write connection - Write queue with batching to reduce lock contention """ import json import logging import os import queue import re import sqlite3 import threading import time from pathlib import Path from hermes_constants import get_hermes_home from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from dataclasses import dataclass, field from contextlib import contextmanager logger = logging.getLogger(__name__) T = TypeVar("T") DEFAULT_DB_PATH = get_hermes_home() / "state.db" SCHEMA_VERSION = 6 SCHEMA_SQL = """ CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER NOT NULL ); CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, source TEXT NOT NULL, user_id TEXT, model TEXT, model_config TEXT, system_prompt TEXT, parent_session_id TEXT, started_at REAL NOT NULL, ended_at REAL, end_reason TEXT, message_count INTEGER DEFAULT 0, tool_call_count INTEGER DEFAULT 0, input_tokens INTEGER DEFAULT 0, output_tokens INTEGER DEFAULT 0, cache_read_tokens INTEGER DEFAULT 0, cache_write_tokens INTEGER DEFAULT 0, reasoning_tokens INTEGER DEFAULT 0, billing_provider TEXT, billing_base_url TEXT, billing_mode TEXT, estimated_cost_usd REAL, actual_cost_usd REAL, cost_status TEXT, cost_source TEXT, pricing_version TEXT, title TEXT, FOREIGN KEY (parent_session_id) REFERENCES sessions(id) ); CREATE TABLE IF NOT EXISTS messages ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id TEXT NOT NULL REFERENCES sessions(id), role TEXT NOT NULL, content TEXT, tool_call_id TEXT, tool_calls TEXT, tool_name TEXT, timestamp REAL NOT NULL, token_count INTEGER, finish_reason TEXT, reasoning TEXT, reasoning_details TEXT, codex_reasoning_items TEXT ); CREATE INDEX IF NOT EXISTS idx_sessions_source ON sessions(source); CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id); CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC); CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp); """ FTS_SQL = """ CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5( content, content=messages, content_rowid=id ); CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN INSERT INTO messages_fts(rowid, content) VALUES (new.id, new.content); END; CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN INSERT INTO messages_fts(messages_fts, rowid, content) VALUES('delete', old.id, old.content); END; CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN INSERT INTO messages_fts(messages_fts, rowid, content) VALUES('delete', old.id, old.content); INSERT INTO messages_fts(rowid, content) VALUES (new.id, new.content); END; """ @dataclass class WriteOperation: """Represents a single write operation to be batched.""" fn: Callable[[sqlite3.Connection], Any] result_queue: queue.Queue = field(default_factory=queue.Queue) error: Optional[Exception] = None class ConnectionPool: """ Manages a pool of SQLite connections for concurrent reads. Uses separate read and write connections: - Write connection: dedicated single connection with exclusive access - Read connections: pool of connections for concurrent reads (WAL mode allows this) """ def __init__( self, db_path: Path, pool_size: int = 5, timeout: float = 30.0, ): self.db_path = db_path self.pool_size = pool_size self.timeout = timeout # Write connection (dedicated) self._write_conn: Optional[sqlite3.Connection] = None self._write_lock = threading.Lock() # Read connection pool self._read_pool: queue.Queue[sqlite3.Connection] = queue.Queue(maxsize=pool_size) self._read_pool_lock = threading.Lock() self._connections: List[sqlite3.Connection] = [] self._initialized = False self._closed = False def initialize(self) -> None: """Initialize the connection pool.""" if self._initialized: return # Create write connection self._write_conn = self._create_connection() self._connections.append(self._write_conn) # Create read connections for _ in range(self.pool_size): conn = self._create_connection() self._read_pool.put(conn) self._connections.append(conn) self._initialized = True logger.debug(f"Connection pool initialized with {self.pool_size} read connections") def _create_connection(self) -> sqlite3.Connection: """Create a new SQLite connection with proper settings.""" conn = sqlite3.connect( str(self.db_path), check_same_thread=False, timeout=self.timeout, isolation_level=None, ) conn.row_factory = sqlite3.Row # WAL mode is set per-connection, but once set on the database it persists conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA foreign_keys=ON") # WAL mode optimizations conn.execute("PRAGMA synchronous=NORMAL") # Faster than FULL, still safe with WAL conn.execute("PRAGMA temp_store=MEMORY") conn.execute("PRAGMA mmap_size=268435456") # 256MB mmap return conn @contextmanager def get_write_connection(self): """Get the dedicated write connection.""" if self._closed: raise RuntimeError("Connection pool is closed") if not self._initialized: self.initialize() with self._write_lock: yield self._write_conn @contextmanager def get_read_connection(self): """Get a read connection from the pool.""" if self._closed: raise RuntimeError("Connection pool is closed") if not self._initialized: self.initialize() conn = None try: conn = self._read_pool.get(timeout=self.timeout) yield conn finally: if conn is not None: self._read_pool.put(conn) def close(self) -> None: """Close all connections in the pool.""" self._closed = True # Close write connection if self._write_conn: try: self._write_conn.execute("PRAGMA wal_checkpoint(PASSIVE)") except Exception: pass self._write_conn.close() self._write_conn = None # Close all read connections with self._read_pool_lock: while not self._read_pool.empty(): try: conn = self._read_pool.get_nowait() conn.close() except queue.Empty: break for conn in self._connections: try: conn.close() except Exception: pass self._connections.clear() class WriteBatcher: """ Batches write operations and flushes them periodically or when threshold is reached. This reduces lock contention by: 1. Accumulating writes in a queue 2. Executing them in a single transaction when batching conditions are met 3. Allowing concurrent reads while writes are being processed """ def __init__( self, pool: ConnectionPool, max_batch_size: int = 50, max_wait_ms: float = 100.0, enable_batching: bool = True, ): self.pool = pool self.max_batch_size = max_batch_size self.max_wait_ms = max_wait_ms self.enable_batching = enable_batching self._write_queue: queue.Queue[WriteOperation] = queue.Queue() self._lock = threading.Lock() self._flush_event = threading.Event() self._shutdown = False # Statistics self._stats_lock = threading.Lock() self._total_writes = 0 self._batched_writes = 0 self._batch_count = 0 # Start background flusher thread if batching is enabled if enable_batching: self._flusher_thread = threading.Thread(target=self._flush_loop, daemon=True) self._flusher_thread.start() def execute(self, fn: Callable[[sqlite3.Connection], T]) -> T: """Execute a write operation, either batched or immediate.""" if not self.enable_batching: # Direct execution without batching return self._execute_immediate(fn) # Create operation and queue it op = WriteOperation(fn=fn) self._write_queue.put(op) # Signal potential batch flush if self._write_queue.qsize() >= self.max_batch_size: self._flush_event.set() # Wait for result result = op.result_queue.get() if isinstance(result, Exception): raise result return result def _execute_immediate(self, fn: Callable[[sqlite3.Connection], T]) -> T: """Execute a write immediately without batching.""" with self.pool.get_write_connection() as conn: conn.execute("BEGIN IMMEDIATE") try: result = fn(conn) conn.commit() return result except Exception: conn.rollback() raise def _flush_loop(self) -> None: """Background thread that periodically flushes the write queue.""" while not self._shutdown: # Wait for flush signal or timeout self._flush_event.wait(timeout=self.max_wait_ms / 1000.0) self._flush_event.clear() if self._shutdown: break # Flush if queue has items if not self._write_queue.empty(): self._flush_batch() def _flush_batch(self) -> None: """Flush pending write operations as a batch.""" # Collect operations from queue operations: List[WriteOperation] = [] with self._lock: while len(operations) < self.max_batch_size and not self._write_queue.empty(): try: op = self._write_queue.get_nowait() operations.append(op) except queue.Empty: break if not operations: return # Execute all operations in a single transaction with self.pool.get_write_connection() as conn: try: conn.execute("BEGIN IMMEDIATE") for op in operations: try: result = op.fn(conn) op.result_queue.put(result) except Exception as e: op.result_queue.put(e) conn.commit() # Update stats with self._stats_lock: self._total_writes += len(operations) self._batched_writes += len(operations) self._batch_count += 1 # Periodic checkpoint if self._batch_count % 10 == 0: self._try_checkpoint(conn) except Exception as e: conn.rollback() # Propagate error to all pending operations for op in operations: op.result_queue.put(e) def _try_checkpoint(self, conn: sqlite3.Connection) -> None: """Attempt a passive WAL checkpoint.""" try: result = conn.execute("PRAGMA wal_checkpoint(PASSIVE)").fetchone() if result and result[1] > 0: logger.debug(f"WAL checkpoint: {result[2]}/{result[1]} pages") except Exception: pass def flush(self) -> None: """Force flush all pending writes.""" if self.enable_batching: self._flush_event.set() # Wait for queue to empty while not self._write_queue.empty(): time.sleep(0.01) def shutdown(self) -> None: """Shutdown the batcher and flush remaining writes.""" self._shutdown = True self._flush_event.set() if self.enable_batching: self.flush() if hasattr(self, '_flusher_thread') and self._flusher_thread.is_alive(): self._flusher_thread.join(timeout=5.0) def get_stats(self) -> Dict[str, int]: """Get batcher statistics.""" with self._stats_lock: return { "total_writes": self._total_writes, "batched_writes": self._batched_writes, "batch_count": self._batch_count, "pending_writes": self._write_queue.qsize(), } class SessionDB: """ SQLite-backed session storage with FTS5 search. Optimized for high concurrency with: - Connection pooling for concurrent reads - Dedicated write connection with batching - WAL mode for maximum concurrency - Separate read/write paths to minimize contention """ # ── Configuration ── # Connection pool size (number of concurrent read connections) DEFAULT_POOL_SIZE = 5 # Write batching configuration DEFAULT_BATCH_SIZE = 50 DEFAULT_BATCH_WAIT_MS = 100.0 # Retry configuration (simplified with batching) _WRITE_MAX_RETRIES = 3 _WRITE_RETRY_MIN_S = 0.010 _WRITE_RETRY_MAX_S = 0.050 # Checkpoint every N batches _CHECKPOINT_EVERY_N_BATCHES = 10 def __init__( self, db_path: Path = None, pool_size: int = None, batch_size: int = None, batch_wait_ms: float = None, enable_batching: bool = True, ): """ Initialize SessionDB with connection pooling and write batching. Args: db_path: Path to the SQLite database file pool_size: Number of connections in the read pool (default: 5) batch_size: Maximum number of writes to batch together (default: 50) batch_wait_ms: Maximum time to wait before flushing batch (default: 100ms) enable_batching: Whether to enable write batching (default: True) """ self.db_path = db_path or DEFAULT_DB_PATH self.db_path.parent.mkdir(parents=True, exist_ok=True) # Initialize connection pool self._pool = ConnectionPool( db_path=self.db_path, pool_size=pool_size or self.DEFAULT_POOL_SIZE, ) self._pool.initialize() # Initialize write batcher self._write_batcher = WriteBatcher( pool=self._pool, max_batch_size=batch_size or self.DEFAULT_BATCH_SIZE, max_wait_ms=batch_wait_ms or self.DEFAULT_BATCH_WAIT_MS, enable_batching=enable_batching, ) # Initialize schema self._init_schema() # Write count for checkpointing self._write_count = 0 self._stats_lock = threading.Lock() # ── Core write helper ── def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T: """Execute a write operation through the batcher. The batcher accumulates writes and executes them in batches, reducing lock contention and improving throughput. """ last_err: Optional[Exception] = None for attempt in range(self._WRITE_MAX_RETRIES): try: result = self._write_batcher.execute(fn) with self._stats_lock: self._write_count += 1 return result except sqlite3.OperationalError as exc: err_msg = str(exc).lower() if "locked" in err_msg or "busy" in err_msg: last_err = exc if attempt < self._WRITE_MAX_RETRIES - 1: # Shorter jitter since we're using batching jitter = random.uniform( self._WRITE_RETRY_MIN_S, self._WRITE_RETRY_MAX_S, ) time.sleep(jitter) continue raise raise last_err or sqlite3.OperationalError( "database is locked after max retries" ) def _execute_read(self, fn: Callable[[sqlite3.Connection], T]) -> T: """Execute a read operation using a connection from the pool.""" with self._pool.get_read_connection() as conn: return fn(conn) def _try_wal_checkpoint(self) -> None: """Best-effort PASSIVE WAL checkpoint.""" try: with self._pool.get_write_connection() as conn: result = conn.execute("PRAGMA wal_checkpoint(PASSIVE)").fetchone() if result and result[1] > 0: logger.debug( "WAL checkpoint: %d/%d pages checkpointed", result[2], result[1], ) except Exception: pass def close(self): """Close the database connection pool and flush pending writes.""" # Shutdown batcher (flushes pending writes) self._write_batcher.shutdown() # Close connection pool self._pool.close() def _init_schema(self): """Create tables and FTS if they don't exist, run migrations.""" def _do_schema(conn): cursor = conn.cursor() cursor.executescript(SCHEMA_SQL) # Check schema version and run migrations cursor.execute("SELECT version FROM schema_version LIMIT 1") row = cursor.fetchone() if row is None: cursor.execute("INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,)) else: current_version = row["version"] if isinstance(row, sqlite3.Row) else row[0] self._run_migrations(cursor, current_version) # Unique title index try: cursor.execute( "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique " "ON sessions(title) WHERE title IS NOT NULL" ) except sqlite3.OperationalError: pass # FTS5 setup try: cursor.execute("SELECT * FROM messages_fts LIMIT 0") except sqlite3.OperationalError: cursor.executescript(FTS_SQL) conn.commit() with self._pool.get_write_connection() as conn: _do_schema(conn) def _run_migrations(self, cursor, current_version: int) -> None: """Run database migrations.""" if current_version < 2: try: cursor.execute("ALTER TABLE messages ADD COLUMN finish_reason TEXT") except sqlite3.OperationalError: pass cursor.execute("UPDATE schema_version SET version = 2") if current_version < 3: try: cursor.execute("ALTER TABLE sessions ADD COLUMN title TEXT") except sqlite3.OperationalError: pass cursor.execute("UPDATE schema_version SET version = 3") if current_version < 4: try: cursor.execute( "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique " "ON sessions(title) WHERE title IS NOT NULL" ) except sqlite3.OperationalError: pass cursor.execute("UPDATE schema_version SET version = 4") if current_version < 5: new_columns = [ ("cache_read_tokens", "INTEGER DEFAULT 0"), ("cache_write_tokens", "INTEGER DEFAULT 0"), ("reasoning_tokens", "INTEGER DEFAULT 0"), ("billing_provider", "TEXT"), ("billing_base_url", "TEXT"), ("billing_mode", "TEXT"), ("estimated_cost_usd", "REAL"), ("actual_cost_usd", "REAL"), ("cost_status", "TEXT"), ("cost_source", "TEXT"), ("pricing_version", "TEXT"), ] for name, column_type in new_columns: try: safe_name = name.replace('"', '""') cursor.execute(f'ALTER TABLE sessions ADD COLUMN "{safe_name}" {column_type}') except sqlite3.OperationalError: pass cursor.execute("UPDATE schema_version SET version = 5") if current_version < 6: for col_name, col_type in [ ("reasoning", "TEXT"), ("reasoning_details", "TEXT"), ("codex_reasoning_items", "TEXT"), ]: try: safe = col_name.replace('"', '""') cursor.execute( f'ALTER TABLE messages ADD COLUMN "{safe}" {col_type}' ) except sqlite3.OperationalError: pass cursor.execute("UPDATE schema_version SET version = 6") # ========================================================================= # Session lifecycle # ========================================================================= def create_session( self, session_id: str, source: str, model: str = None, model_config: Dict[str, Any] = None, system_prompt: str = None, user_id: str = None, parent_session_id: str = None, ) -> str: """Create a new session record. Returns the session_id.""" def _do(conn): conn.execute( """INSERT OR IGNORE INTO sessions (id, source, user_id, model, model_config, system_prompt, parent_session_id, started_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", ( session_id, source, user_id, model, json.dumps(model_config) if model_config else None, system_prompt, parent_session_id, time.time(), ), ) self._execute_write(_do) return session_id def end_session(self, session_id: str, end_reason: str) -> None: """Mark a session as ended.""" def _do(conn): conn.execute( "UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?", (time.time(), end_reason, session_id), ) self._execute_write(_do) def reopen_session(self, session_id: str) -> None: """Clear ended_at/end_reason so a session can be resumed.""" def _do(conn): conn.execute( "UPDATE sessions SET ended_at = NULL, end_reason = NULL WHERE id = ?", (session_id,), ) self._execute_write(_do) def update_system_prompt(self, session_id: str, system_prompt: str) -> None: """Store the full assembled system prompt snapshot.""" def _do(conn): conn.execute( "UPDATE sessions SET system_prompt = ? WHERE id = ?", (system_prompt, session_id), ) self._execute_write(_do) def update_token_counts( self, session_id: str, input_tokens: int = 0, output_tokens: int = 0, model: str = None, cache_read_tokens: int = 0, cache_write_tokens: int = 0, reasoning_tokens: int = 0, estimated_cost_usd: Optional[float] = None, actual_cost_usd: Optional[float] = None, cost_status: Optional[str] = None, cost_source: Optional[str] = None, pricing_version: Optional[str] = None, billing_provider: Optional[str] = None, billing_base_url: Optional[str] = None, billing_mode: Optional[str] = None, absolute: bool = False, ) -> None: """Update token counters and backfill model if not already set. When *absolute* is False (default), values are **incremented** — use this for per-API-call deltas (CLI path). When *absolute* is True, values are **set directly** — use this when the caller already holds cumulative totals (gateway path, where the cached agent accumulates across messages). """ if absolute: sql = """UPDATE sessions SET input_tokens=?, output_tokens=?, cache_read_tokens=?, cache_write_tokens=?, reasoning_tokens=?, estimated_cost_usd = COALESCE(?, 0), actual_cost_usd = CASE WHEN ? IS NULL THEN actual_cost_usd ELSE ? END, cost_status = COALESCE(?, cost_status), cost_source = COALESCE(?, cost_source), pricing_version = COALESCE(?, pricing_version), billing_provider = COALESCE(billing_provider, ?), billing_base_url = COALESCE(billing_base_url, ?), billing_mode = COALESCE(billing_mode, ?), model = COALESCE(model, ?) WHERE id = ?""" else: sql = """UPDATE sessions SET input_tokens=input_tokens + ?, output_tokens=output_tokens + ?, cache_read_tokens=cache_read_tokens + ?, cache_write_tokens=cache_write_tokens + ?, reasoning_tokens=reasoning_tokens + ?, estimated_cost_usd = COALESCE(estimated_cost_usd, 0) + COALESCE(?, 0), actual_cost_usd = CASE WHEN ? IS NULL THEN actual_cost_usd ELSE COALESCE(actual_cost_usd, 0) + ? END, cost_status = COALESCE(?, cost_status), cost_source = COALESCE(?, cost_source), pricing_version = COALESCE(?, pricing_version), billing_provider = COALESCE(billing_provider, ?), billing_base_url = COALESCE(billing_base_url, ?), billing_mode = COALESCE(billing_mode, ?), model = COALESCE(model, ?) WHERE id = ?""" params = ( input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, reasoning_tokens, estimated_cost_usd, actual_cost_usd, actual_cost_usd, cost_status, cost_source, pricing_version, billing_provider, billing_base_url, billing_mode, model, session_id, ) def _do(conn): conn.execute(sql, params) self._execute_write(_do) def ensure_session( self, session_id: str, source: str = "unknown", model: str = None, ) -> None: """Ensure a session row exists, creating it with minimal metadata if absent. Used by _flush_messages_to_session_db to recover from a failed create_session() call (e.g. transient SQLite lock at agent startup). INSERT OR IGNORE is safe to call even when the row already exists. """ def _do(conn): conn.execute( """INSERT OR IGNORE INTO sessions (id, source, model, started_at) VALUES (?, ?, ?, ?)""", (session_id, source, model, time.time()), ) self._execute_write(_do) def set_token_counts( self, session_id: str, input_tokens: int = 0, output_tokens: int = 0, model: str = None, cache_read_tokens: int = 0, cache_write_tokens: int = 0, reasoning_tokens: int = 0, estimated_cost_usd: Optional[float] = None, actual_cost_usd: Optional[float] = None, cost_status: Optional[str] = None, cost_source: Optional[str] = None, pricing_version: Optional[str] = None, billing_provider: Optional[str] = None, billing_base_url: Optional[str] = None, billing_mode: Optional[str] = None, ) -> None: """Set token counters to absolute values (not increment). Use this when the caller provides cumulative totals from a completed conversation run (e.g. the gateway, where the cached agent's session_prompt_tokens already reflects the running total). """ def _do(conn): conn.execute( """UPDATE sessions SET input_tokens=?, output_tokens=?, cache_read_tokens=?, cache_write_tokens=?, reasoning_tokens=?, estimated_cost_usd = ?, actual_cost_usd = CASE WHEN ? IS NULL THEN actual_cost_usd ELSE ? END, cost_status = COALESCE(?, cost_status), cost_source = COALESCE(?, cost_source), pricing_version = COALESCE(?, pricing_version), billing_provider = COALESCE(billing_provider, ?), billing_base_url = COALESCE(billing_base_url, ?), billing_mode = COALESCE(billing_mode, ?), model = COALESCE(model, ?) WHERE id = ?""", ( input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, reasoning_tokens, estimated_cost_usd, actual_cost_usd, actual_cost_usd, cost_status, cost_source, pricing_version, billing_provider, billing_base_url, billing_mode, model, session_id, ), ) self._execute_write(_do) def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """Get a session by ID.""" def _do(conn): cursor = conn.execute( "SELECT * FROM sessions WHERE id = ?", (session_id,) ) row = cursor.fetchone() return dict(row) if row else None return self._execute_read(_do) def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]: """Resolve an exact or uniquely prefixed session ID to the full ID. Returns the exact ID when it exists. Otherwise treats the input as a prefix and returns the single matching session ID if the prefix is unambiguous. Returns None for no matches or ambiguous prefixes. """ exact = self.get_session(session_id_or_prefix) if exact: return exact["id"] escaped = ( session_id_or_prefix .replace("\\", "\\\\") .replace("%", "\\%") .replace("_", "\\_") ) def _do(conn): cursor = conn.execute( "SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\\\' ORDER BY started_at DESC LIMIT 2", (f"{escaped}%",), ) matches = [row["id"] for row in cursor.fetchall()] if len(matches) == 1: return matches[0] return None return self._execute_read(_do) # Maximum length for session titles MAX_TITLE_LENGTH = 100 @staticmethod def sanitize_title(title: Optional[str]) -> Optional[str]: """Validate and sanitize a session title. - Strips leading/trailing whitespace - Removes ASCII control characters (0x00-0x1F, 0x7F) and problematic Unicode control chars (zero-width, RTL/LTR overrides, etc.) - Collapses internal whitespace runs to single spaces - Normalizes empty/whitespace-only strings to None - Enforces MAX_TITLE_LENGTH Returns the cleaned title string or None. Raises ValueError if the title exceeds MAX_TITLE_LENGTH after cleaning. """ if not title: return None # Remove ASCII control characters (0x00-0x1F, 0x7F) but keep # whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be # normalized to spaces by the whitespace collapsing step below cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title) # Remove problematic Unicode control characters: # - Zero-width chars (U+200B-U+200F, U+FEFF) # - Directional overrides (U+202A-U+202E, U+2066-U+2069) # - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB) cleaned = re.sub( r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]', '', cleaned, ) # Collapse internal whitespace runs and strip cleaned = re.sub(r'\s+', ' ', cleaned).strip() if not cleaned: return None if len(cleaned) > SessionDB.MAX_TITLE_LENGTH: raise ValueError( f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})" ) return cleaned def set_session_title(self, session_id: str, title: str) -> bool: """Set or update a session's title. Returns True if session was found and title was set. Raises ValueError if title is already in use by another session, or if the title fails validation (too long, invalid characters). Empty/whitespace-only strings are normalized to None (clearing the title). """ title = self.sanitize_title(title) def _do(conn): if title: # Check uniqueness (allow the same session to keep its own title) cursor = conn.execute( "SELECT id FROM sessions WHERE title = ? AND id != ?", (title, session_id), ) conflict = cursor.fetchone() if conflict: raise ValueError( f"Title '{title}' is already in use by session {conflict['id']}" ) cursor = conn.execute( "UPDATE sessions SET title = ? WHERE id = ?", (title, session_id), ) return cursor.rowcount rowcount = self._execute_write(_do) return rowcount > 0 def get_session_title(self, session_id: str) -> Optional[str]: """Get the title for a session, or None.""" def _do(conn): cursor = conn.execute( "SELECT title FROM sessions WHERE id = ?", (session_id,) ) row = cursor.fetchone() return row["title"] if row else None return self._execute_read(_do) def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]: """Look up a session by exact title. Returns session dict or None.""" def _do(conn): cursor = conn.execute( "SELECT * FROM sessions WHERE title = ?", (title,) ) row = cursor.fetchone() return dict(row) if row else None return self._execute_read(_do) def resolve_session_by_title(self, title: str) -> Optional[str]: """Resolve a title to a session ID, preferring the latest in a lineage. If the exact title exists, returns that session's ID. If not, searches for "title #N" variants and returns the latest one. If the exact title exists AND numbered variants exist, returns the latest numbered variant (the most recent continuation). """ # First try exact match exact = self.get_session_by_title(title) # Also search for numbered variants: "title #2", "title #3", etc. # Escape SQL LIKE wildcards (%, _) in the title to prevent false matches escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") def _do(conn): cursor = conn.execute( "SELECT id, title, started_at FROM sessions " "WHERE title LIKE ? ESCAPE '\\\\' ORDER BY started_at DESC", (f"{escaped} #%",), ) numbered = cursor.fetchall() return numbered numbered = self._execute_read(_do) if numbered: # Return the most recent numbered variant return numbered[0]["id"] elif exact: return exact["id"] return None def get_next_title_in_lineage(self, base_title: str) -> str: """Generate the next title in a lineage (e.g., "my session" → "my session #2"). Strips any existing " #N" suffix to find the base name, then finds the highest existing number and increments. """ # Strip existing #N suffix to find the true base match = re.match(r'^(.*?) #(\d+)$', base_title) if match: base = match.group(1) else: base = base_title # Find all existing numbered variants # Escape SQL LIKE wildcards (%, _) in the base to prevent false matches escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") def _do(conn): cursor = conn.execute( "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\\\'", (base, f"{escaped} #%"), ) return [row["title"] for row in cursor.fetchall()] existing = self._execute_read(_do) if not existing: return base # No conflict, use the base name as-is # Find the highest number max_num = 1 # The unnumbered original counts as #1 for t in existing: m = re.match(r'^.* #(\d+)$', t) if m: max_num = max(max_num, int(m.group(1))) return f"{base} #{max_num + 1}" def list_sessions_rich( self, source: str = None, exclude_sources: List[str] = None, limit: int = 20, offset: int = 0, ) -> List[Dict[str, Any]]: """List sessions with preview (first user message) and last active timestamp. Returns dicts with keys: id, source, model, title, started_at, ended_at, message_count, preview (first 60 chars of first user message), last_active (timestamp of last message). Uses a single query with correlated subqueries instead of N+2 queries. """ where_clauses = [] params = [] if source: where_clauses.append("s.source = ?") params.append(source) if exclude_sources: placeholders = ",".join("?" for _ in exclude_sources) where_clauses.append(f"s.source NOT IN ({placeholders})") params.extend(exclude_sources) where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" query = f""" SELECT s.*, COALESCE( (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63) FROM messages m WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL ORDER BY m.timestamp, m.id LIMIT 1), '' ) AS _preview_raw, COALESCE( (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id), s.started_at ) AS last_active FROM sessions s {where_sql} ORDER BY s.started_at DESC LIMIT ? OFFSET ? """ params.extend([limit, offset]) def _do(conn): cursor = conn.execute(query, params) rows = cursor.fetchall() sessions = [] for row in rows: s = dict(row) # Build the preview from the raw substring raw = s.pop("_preview_raw", "").strip() if raw: text = raw[:60] s["preview"] = text + ("..." if len(raw) > 60 else "") else: s["preview"] = "" sessions.append(s) return sessions return self._execute_read(_do) # ========================================================================= # Message storage # ========================================================================= def append_message( self, session_id: str, role: str, content: str = None, tool_name: str = None, tool_calls: Any = None, tool_call_id: str = None, token_count: int = None, finish_reason: str = None, reasoning: str = None, reasoning_details: Any = None, codex_reasoning_items: Any = None, ) -> int: """ Append a message to a session. Returns the message row ID. Also increments the session's message_count (and tool_call_count if role is 'tool' or tool_calls is present). """ # Serialize structured fields to JSON before entering the write txn reasoning_details_json = ( json.dumps(reasoning_details) if reasoning_details else None ) codex_items_json = ( json.dumps(codex_reasoning_items) if codex_reasoning_items else None ) tool_calls_json = json.dumps(tool_calls) if tool_calls else None # Pre-compute tool call count num_tool_calls = 0 if tool_calls is not None: num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1 def _do(conn): cursor = conn.execute( """INSERT INTO messages (session_id, role, content, tool_call_id, tool_calls, tool_name, timestamp, token_count, finish_reason, reasoning, reasoning_details, codex_reasoning_items) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( session_id, role, content, tool_call_id, tool_calls_json, tool_name, time.time(), token_count, finish_reason, reasoning, reasoning_details_json, codex_items_json, ), ) msg_id = cursor.lastrowid # Update counters if num_tool_calls > 0: conn.execute( """UPDATE sessions SET message_count = message_count + 1, tool_call_count = tool_call_count + ? WHERE id = ?""", (num_tool_calls, session_id), ) else: conn.execute( "UPDATE sessions SET message_count = message_count + 1 WHERE id = ?", (session_id,), ) return msg_id return self._execute_write(_do) def get_messages(self, session_id: str) -> List[Dict[str, Any]]: """Load all messages for a session, ordered by timestamp.""" def _do(conn): cursor = conn.execute( "SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id", (session_id,), ) rows = cursor.fetchall() result = [] for row in rows: msg = dict(row) if msg.get("tool_calls"): try: msg["tool_calls"] = json.loads(msg["tool_calls"]) except (json.JSONDecodeError, TypeError): pass result.append(msg) return result return self._execute_read(_do) def get_messages_as_conversation(self, session_id: str) -> List[Dict[str, Any]]: """ Load messages in the OpenAI conversation format (role + content dicts). Used by the gateway to restore conversation history. """ def _do(conn): cursor = conn.execute( "SELECT role, content, tool_call_id, tool_calls, tool_name, " "reasoning, reasoning_details, codex_reasoning_items " "FROM messages WHERE session_id = ? ORDER BY timestamp, id", (session_id,), ) rows = cursor.fetchall() messages = [] for row in rows: msg = {"role": row["role"], "content": row["content"]} if row["tool_call_id"]: msg["tool_call_id"] = row["tool_call_id"] if row["tool_name"]: msg["tool_name"] = row["tool_name"] if row["tool_calls"]: try: msg["tool_calls"] = json.loads(row["tool_calls"]) except (json.JSONDecodeError, TypeError): pass # Restore reasoning fields on assistant messages so providers # that replay reasoning (OpenRouter, OpenAI, Nous) receive # coherent multi-turn reasoning context. if row["role"] == "assistant": if row["reasoning"]: msg["reasoning"] = row["reasoning"] if row["reasoning_details"]: try: msg["reasoning_details"] = json.loads(row["reasoning_details"]) except (json.JSONDecodeError, TypeError): pass if row["codex_reasoning_items"]: try: msg["codex_reasoning_items"] = json.loads(row["codex_reasoning_items"]) except (json.JSONDecodeError, TypeError): pass messages.append(msg) return messages return self._execute_read(_do) # ========================================================================= # Search # ========================================================================= @staticmethod def _sanitize_fts5_query(query: str) -> str: """Sanitize user input for safe use in FTS5 MATCH queries. FTS5 has its own query syntax where characters like ``"``, ``(``, ``)``, ``+``, ``*``, ``{``, ``}`` and bare boolean operators (``AND``, ``OR``, ``NOT``) have special meaning. Passing raw user input directly to MATCH can cause ``sqlite3.OperationalError``. Strategy: - Preserve properly paired quoted phrases (``"exact phrase"``) - Strip unmatched FTS5-special characters that would cause errors - Wrap unquoted hyphenated terms in quotes so FTS5 matches them as exact phrases instead of splitting on the hyphen """ # Step 1: Extract balanced double-quoted phrases and protect them # from further processing via numbered placeholders. _quoted_parts: list = [] def _preserve_quoted(m: re.Match) -> str: _quoted_parts.append(m.group(0)) return f"\x00Q{len(_quoted_parts) - 1}\x00" sanitized = re.sub(r'"[^"]*"', _preserve_quoted, query) # Step 2: Strip remaining (unmatched) FTS5-special characters sanitized = re.sub(r'[+{}()\"^]', " ", sanitized) # Step 3: Collapse repeated * (e.g. "***") into a single one, # and remove leading * (prefix-only needs at least one char before *) sanitized = re.sub(r"\*+", "*", sanitized) sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized) # Step 4: Remove dangling boolean operators at start/end that would # cause syntax errors (e.g. "hello AND" or "OR world") sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip()) sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip()) # Step 5: Wrap unquoted hyphenated terms (e.g. ``chat-send``) in # double quotes. FTS5's tokenizer splits on hyphens, turning # ``chat-send`` into ``chat AND send``. Quoting preserves the # intended phrase match. sanitized = re.sub(r"\b(\w+(?:-\w+)+)\b", r'"\1"', sanitized) # Step 6: Restore preserved quoted phrases for i, quoted in enumerate(_quoted_parts): sanitized = sanitized.replace(f"\x00Q{i}\x00", quoted) return sanitized.strip() def search_messages( self, query: str, source_filter: List[str] = None, exclude_sources: List[str] = None, role_filter: List[str] = None, limit: int = 20, offset: int = 0, ) -> List[Dict[str, Any]]: """ Full-text search across session messages using FTS5. Supports FTS5 query syntax: - Simple keywords: "docker deployment" - Phrases: '"exact phrase"' - Boolean: "docker OR kubernetes", "python NOT java" - Prefix: "deploy*" Returns matching messages with session metadata, content snippet, and surrounding context (1 message before and after the match). """ if not query or not query.strip(): return [] query = self._sanitize_fts5_query(query) if not query: return [] # Build WHERE clauses dynamically where_clauses = ["messages_fts MATCH ?"] params: list = [query] if source_filter is not None: source_placeholders = ",".join("?" for _ in source_filter) where_clauses.append(f"s.source IN ({source_placeholders})") params.extend(source_filter) if exclude_sources is not None: exclude_placeholders = ",".join("?" for _ in exclude_sources) where_clauses.append(f"s.source NOT IN ({exclude_placeholders})") params.extend(exclude_sources) if role_filter: role_placeholders = ",".join("?" for _ in role_filter) where_clauses.append(f"m.role IN ({role_placeholders})") params.extend(role_filter) where_sql = " AND ".join(where_clauses) params.extend([limit, offset]) sql = f""" SELECT m.id, m.session_id, m.role, snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet, m.content, m.timestamp, m.tool_name, s.source, s.model, s.started_at AS session_started FROM messages_fts JOIN messages m ON m.id = messages_fts.rowid JOIN sessions s ON s.id = m.session_id WHERE {where_sql} ORDER BY rank LIMIT ? OFFSET ? """ def _do_search(conn): try: cursor = conn.execute(sql, params) except sqlite3.OperationalError: # FTS5 query syntax error despite sanitization — return empty return [] return [dict(row) for row in cursor.fetchall()] matches = self._execute_read(_do_search) # Add surrounding context (1 message before + after each match). # Done outside the lock so we don't hold it across N sequential queries. for match in matches: try: def _do_context(conn): ctx_cursor = conn.execute( """SELECT role, content FROM messages WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1 ORDER BY id""", (match["session_id"], match["id"], match["id"]), ) return [ {"role": r["role"], "content": (r["content"] or "")[:200]} for r in ctx_cursor.fetchall() ] match["context"] = self._execute_read(_do_context) except Exception: match["context"] = [] # Remove full content from result (snippet is enough, saves tokens) for match in matches: match.pop("content", None) return matches def search_sessions( self, source: str = None, limit: int = 20, offset: int = 0, ) -> List[Dict[str, Any]]: """List sessions, optionally filtered by source.""" def _do(conn): if source: cursor = conn.execute( "SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?", (source, limit, offset), ) else: cursor = conn.execute( "SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?", (limit, offset), ) return [dict(row) for row in cursor.fetchall()] return self._execute_read(_do) # ========================================================================= # Utility # ========================================================================= def session_count(self, source: str = None) -> int: """Count sessions, optionally filtered by source.""" def _do(conn): if source: cursor = conn.execute( "SELECT COUNT(*) FROM sessions WHERE source = ?", (source,) ) else: cursor = conn.execute("SELECT COUNT(*) FROM sessions") return cursor.fetchone()[0] return self._execute_read(_do) def message_count(self, session_id: str = None) -> int: """Count messages, optionally for a specific session.""" def _do(conn): if session_id: cursor = conn.execute( "SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,) ) else: cursor = conn.execute("SELECT COUNT(*) FROM messages") return cursor.fetchone()[0] return self._execute_read(_do) # ========================================================================= # Export and cleanup # ========================================================================= def export_session(self, session_id: str) -> Optional[Dict[str, Any]]: """Export a single session with all its messages as a dict.""" session = self.get_session(session_id) if not session: return None messages = self.get_messages(session_id) return {**session, "messages": messages} def export_all(self, source: str = None) -> List[Dict[str, Any]]: """ Export all sessions (with messages) as a list of dicts. Suitable for writing to a JSONL file for backup/analysis. """ sessions = self.search_sessions(source=source, limit=100000) results = [] for session in sessions: messages = self.get_messages(session["id"]) results.append({**session, "messages": messages}) return results def clear_messages(self, session_id: str) -> None: """Delete all messages for a session and reset its counters.""" def _do(conn): conn.execute( "DELETE FROM messages WHERE session_id = ?", (session_id,) ) conn.execute( "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", (session_id,), ) self._execute_write(_do) def delete_session(self, session_id: str) -> bool: """Delete a session and all its messages. Returns True if found.""" def _do(conn): cursor = conn.execute( "SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,) ) if cursor.fetchone()[0] == 0: return False conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) return True return self._execute_write(_do) def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int: """ Delete sessions older than N days. Returns count of deleted sessions. Only prunes ended sessions (not active ones). """ cutoff = time.time() - (older_than_days * 86400) def _do(conn): if source: cursor = conn.execute( """SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""", (cutoff, source), ) else: cursor = conn.execute( "SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL", (cutoff,), ) session_ids = [row["id"] for row in cursor.fetchall()] for sid in session_ids: conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) conn.execute("DELETE FROM sessions WHERE id = ?", (sid,)) return len(session_ids) return self._execute_write(_do) # ========================================================================= # Statistics and diagnostics # ========================================================================= def get_stats(self) -> Dict[str, Any]: """Get database statistics including write batcher stats.""" batcher_stats = self._write_batcher.get_stats() def _do(conn): # Get database stats cursor = conn.execute("PRAGMA wal_checkpoint") checkpoint_info = cursor.fetchone() cursor = conn.execute("SELECT COUNT(*) FROM sessions") session_count = cursor.fetchone()[0] cursor = conn.execute("SELECT COUNT(*) FROM messages") message_count = cursor.fetchone()[0] return { "checkpoint": { "busy": checkpoint_info[0] if checkpoint_info else None, "log": checkpoint_info[1] if checkpoint_info else None, "checkpointed": checkpoint_info[2] if checkpoint_info else None, }, "sessions": session_count, "messages": message_count, } db_stats = self._execute_read(_do) return { **batcher_stats, **db_stats, "pool_size": self._pool.pool_size, "write_count": self._write_count, }