From 892d773f6329f513a99ac991de5f6fda02bf9dd8 Mon Sep 17 00:00:00 2001 From: kimi Date: Sat, 21 Mar 2026 18:00:34 -0400 Subject: [PATCH] feat: add ConnectionPool class with thread-local SQLite connections - Add src/infrastructure/db_pool.py with ConnectionPool class - Implements thread-local storage for SQLite connections - Provides get_connection(), close_connection(), and connection() context manager - Multiple calls from same thread reuse the same connection - Different threads get different connections Refs #769 --- .gitignore | 1 - src/infrastructure/db_pool.py | 84 ++++++++ tests/infrastructure/test_db_pool.py | 288 +++++++++++++++++++++++++++ 3 files changed, 372 insertions(+), 1 deletion(-) create mode 100644 src/infrastructure/db_pool.py create mode 100644 tests/infrastructure/test_db_pool.py diff --git a/.gitignore b/.gitignore index 49e143e3..c3590105 100644 --- a/.gitignore +++ b/.gitignore @@ -73,7 +73,6 @@ morning_briefing.txt markdown_report.md data/timmy_soul.jsonl scripts/migrate_to_zeroclaw.py -src/infrastructure/db_pool.py workspace/ # Loop orchestration state diff --git a/src/infrastructure/db_pool.py b/src/infrastructure/db_pool.py new file mode 100644 index 00000000..1927a703 --- /dev/null +++ b/src/infrastructure/db_pool.py @@ -0,0 +1,84 @@ +"""Thread-local SQLite connection pool. + +Provides a ConnectionPool class that manages SQLite connections per thread, +with support for context managers and automatic cleanup. +""" + +import sqlite3 +import threading +from collections.abc import Generator +from contextlib import contextmanager +from pathlib import Path + + +class ConnectionPool: + """Thread-local SQLite connection pool. + + Each thread gets its own connection, which is reused for subsequent + requests from the same thread. Connections are automatically cleaned + up when close_connection() is called or the context manager exits. + """ + + def __init__(self, db_path: Path | str) -> None: + """Initialize the connection pool. + + Args: + db_path: Path to the SQLite database file. + """ + self._db_path = Path(db_path) + self._local = threading.local() + + def _ensure_db_exists(self) -> None: + """Ensure the database directory exists.""" + self._db_path.parent.mkdir(parents=True, exist_ok=True) + + def get_connection(self) -> sqlite3.Connection: + """Get a connection for the current thread. + + Creates a new connection if one doesn't exist for this thread, + otherwise returns the existing connection. + + Returns: + A sqlite3 Connection object. + """ + if not hasattr(self._local, "conn") or self._local.conn is None: + self._ensure_db_exists() + self._local.conn = sqlite3.connect(str(self._db_path), check_same_thread=False) + self._local.conn.row_factory = sqlite3.Row + return self._local.conn + + def close_connection(self) -> None: + """Close the connection for the current thread. + + Cleans up the thread-local storage. Safe to call even if + no connection exists for this thread. + """ + if hasattr(self._local, "conn") and self._local.conn is not None: + self._local.conn.close() + self._local.conn = None + + @contextmanager + def connection(self) -> Generator[sqlite3.Connection, None, None]: + """Context manager for getting and automatically closing a connection. + + Yields: + A sqlite3 Connection object. + + Example: + with pool.connection() as conn: + cursor = conn.execute("SELECT 1") + result = cursor.fetchone() + """ + conn = self.get_connection() + try: + yield conn + finally: + self.close_connection() + + def close_all(self) -> None: + """Close all connections (useful for testing). + + Note: This only closes the connection for the current thread. + In a multi-threaded environment, each thread must close its own. + """ + self.close_connection() diff --git a/tests/infrastructure/test_db_pool.py b/tests/infrastructure/test_db_pool.py new file mode 100644 index 00000000..9e1f9bac --- /dev/null +++ b/tests/infrastructure/test_db_pool.py @@ -0,0 +1,288 @@ +"""Tests for infrastructure.db_pool module.""" + +import sqlite3 +import threading +import time +from pathlib import Path + +import pytest + +from infrastructure.db_pool import ConnectionPool + + +class TestConnectionPoolInit: + """Test ConnectionPool initialization.""" + + def test_init_with_string_path(self, tmp_path): + """Pool can be initialized with a string path.""" + db_path = str(tmp_path / "test.db") + pool = ConnectionPool(db_path) + assert pool._db_path == Path(db_path) + + def test_init_with_path_object(self, tmp_path): + """Pool can be initialized with a Path object.""" + db_path = tmp_path / "test.db" + pool = ConnectionPool(db_path) + assert pool._db_path == db_path + + def test_init_creates_thread_local(self, tmp_path): + """Pool initializes thread-local storage.""" + pool = ConnectionPool(tmp_path / "test.db") + assert hasattr(pool, "_local") + assert isinstance(pool._local, threading.local) + + +class TestGetConnection: + """Test get_connection() method.""" + + def test_get_connection_returns_valid_sqlite3_connection(self, tmp_path): + """get_connection() returns a valid sqlite3 connection.""" + pool = ConnectionPool(tmp_path / "test.db") + conn = pool.get_connection() + assert isinstance(conn, sqlite3.Connection) + # Verify it's a working connection + cursor = conn.execute("SELECT 1") + assert cursor.fetchone()[0] == 1 + + def test_get_connection_creates_db_file(self, tmp_path): + """get_connection() creates the database file if it doesn't exist.""" + db_path = tmp_path / "subdir" / "test.db" + assert not db_path.exists() + pool = ConnectionPool(db_path) + pool.get_connection() + assert db_path.exists() + + def test_get_connection_sets_row_factory(self, tmp_path): + """get_connection() sets row_factory to sqlite3.Row.""" + pool = ConnectionPool(tmp_path / "test.db") + conn = pool.get_connection() + assert conn.row_factory is sqlite3.Row + + def test_multiple_calls_same_thread_reuse_connection(self, tmp_path): + """Multiple calls from same thread reuse the same connection.""" + pool = ConnectionPool(tmp_path / "test.db") + conn1 = pool.get_connection() + conn2 = pool.get_connection() + assert conn1 is conn2 + + def test_different_threads_get_different_connections(self, tmp_path): + """Different threads get different connections.""" + pool = ConnectionPool(tmp_path / "test.db") + connections = [] + + def get_conn(): + connections.append(pool.get_connection()) + + t1 = threading.Thread(target=get_conn) + t2 = threading.Thread(target=get_conn) + t1.start() + t2.start() + t1.join() + t2.join() + + assert len(connections) == 2 + assert connections[0] is not connections[1] + + +class TestCloseConnection: + """Test close_connection() method.""" + + def test_close_connection_closes_sqlite_connection(self, tmp_path): + """close_connection() closes the underlying sqlite connection.""" + pool = ConnectionPool(tmp_path / "test.db") + conn = pool.get_connection() + pool.close_connection() + # Connection should be closed + with pytest.raises(sqlite3.ProgrammingError): + conn.execute("SELECT 1") + + def test_close_connection_cleans_up_thread_local(self, tmp_path): + """close_connection() cleans up thread-local storage.""" + pool = ConnectionPool(tmp_path / "test.db") + pool.get_connection() + assert hasattr(pool._local, "conn") + assert pool._local.conn is not None + + pool.close_connection() + + # Should either not have the attr or it should be None + assert not hasattr(pool._local, "conn") or pool._local.conn is None + + def test_close_connection_without_getting_connection_is_safe(self, tmp_path): + """close_connection() is safe to call even without getting a connection first.""" + pool = ConnectionPool(tmp_path / "test.db") + # Should not raise + pool.close_connection() + + def test_close_connection_multiple_calls_is_safe(self, tmp_path): + """close_connection() can be called multiple times safely.""" + pool = ConnectionPool(tmp_path / "test.db") + pool.get_connection() + pool.close_connection() + # Should not raise + pool.close_connection() + + +class TestContextManager: + """Test the connection() context manager.""" + + def test_connection_yields_valid_connection(self, tmp_path): + """connection() context manager yields a valid sqlite3 connection.""" + pool = ConnectionPool(tmp_path / "test.db") + with pool.connection() as conn: + assert isinstance(conn, sqlite3.Connection) + cursor = conn.execute("SELECT 42") + assert cursor.fetchone()[0] == 42 + + def test_connection_closes_on_exit(self, tmp_path): + """connection() context manager closes connection on exit.""" + pool = ConnectionPool(tmp_path / "test.db") + with pool.connection() as conn: + pass + # Connection should be closed after context exit + with pytest.raises(sqlite3.ProgrammingError): + conn.execute("SELECT 1") + + def test_connection_closes_on_exception(self, tmp_path): + """connection() context manager closes connection even on exception.""" + pool = ConnectionPool(tmp_path / "test.db") + conn_ref = None + try: + with pool.connection() as conn: + conn_ref = conn + raise ValueError("Test exception") + except ValueError: + pass + # Connection should still be closed + with pytest.raises(sqlite3.ProgrammingError): + conn_ref.execute("SELECT 1") + + def test_connection_context_manager_is_reusable(self, tmp_path): + """connection() context manager can be used multiple times.""" + pool = ConnectionPool(tmp_path / "test.db") + + with pool.connection() as conn1: + result1 = conn1.execute("SELECT 1").fetchone()[0] + + with pool.connection() as conn2: + result2 = conn2.execute("SELECT 2").fetchone()[0] + + assert result1 == 1 + assert result2 == 2 + + +class TestThreadSafety: + """Test thread-safety of the connection pool.""" + + def test_concurrent_access(self, tmp_path): + """Multiple threads can use the pool concurrently.""" + pool = ConnectionPool(tmp_path / "test.db") + results = [] + errors = [] + + def worker(worker_id): + try: + with pool.connection() as conn: + conn.execute("CREATE TABLE IF NOT EXISTS test (id INTEGER)") + conn.execute("INSERT INTO test VALUES (?)", (worker_id,)) + conn.commit() + time.sleep(0.01) # Small delay to increase contention + results.append(worker_id) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0, f"Errors occurred: {errors}" + assert len(results) == 5 + + def test_thread_isolation(self, tmp_path): + """Each thread has isolated connections (verified by thread-local data).""" + pool = ConnectionPool(tmp_path / "test.db") + results = [] + + def worker(worker_id): + # Get connection and write worker-specific data + conn = pool.get_connection() + conn.execute("CREATE TABLE IF NOT EXISTS isolation_test (thread_id INTEGER)") + conn.execute("DELETE FROM isolation_test") # Clear previous data + conn.execute("INSERT INTO isolation_test VALUES (?)", (worker_id,)) + conn.commit() + # Read back the data + result = conn.execute("SELECT thread_id FROM isolation_test").fetchone()[0] + results.append((worker_id, result)) + pool.close_connection() + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Each thread should have written and read its own ID + assert len(results) == 3 + for worker_id, read_id in results: + assert worker_id == read_id, f"Thread {worker_id} read {read_id} instead" + + +class TestCloseAll: + """Test close_all() method.""" + + def test_close_all_closes_current_thread_connection(self, tmp_path): + """close_all() closes the connection for the current thread.""" + pool = ConnectionPool(tmp_path / "test.db") + conn = pool.get_connection() + pool.close_all() + # Connection should be closed + with pytest.raises(sqlite3.ProgrammingError): + conn.execute("SELECT 1") + + +class TestIntegration: + """Integration tests for real-world usage patterns.""" + + def test_basic_crud_operations(self, tmp_path): + """Can perform basic CRUD operations through the pool.""" + pool = ConnectionPool(tmp_path / "test.db") + + with pool.connection() as conn: + # Create table + conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)") + # Insert + conn.execute("INSERT INTO users (name) VALUES (?)", ("Alice",)) + conn.execute("INSERT INTO users (name) VALUES (?)", ("Bob",)) + conn.commit() + # Query + cursor = conn.execute("SELECT * FROM users ORDER BY id") + rows = cursor.fetchall() + assert len(rows) == 2 + assert rows[0]["name"] == "Alice" + assert rows[1]["name"] == "Bob" + + def test_multiple_pools_different_databases(self, tmp_path): + """Multiple pools can manage different databases independently.""" + pool1 = ConnectionPool(tmp_path / "db1.db") + pool2 = ConnectionPool(tmp_path / "db2.db") + + with pool1.connection() as conn1: + conn1.execute("CREATE TABLE test (val INTEGER)") + conn1.execute("INSERT INTO test VALUES (1)") + conn1.commit() + + with pool2.connection() as conn2: + conn2.execute("CREATE TABLE test (val INTEGER)") + conn2.execute("INSERT INTO test VALUES (2)") + conn2.commit() + + # Verify isolation + with pool1.connection() as conn1: + result = conn1.execute("SELECT val FROM test").fetchone()[0] + assert result == 1 + + with pool2.connection() as conn2: + result = conn2.execute("SELECT val FROM test").fetchone()[0] + assert result == 2 -- 2.43.0