"""Tests for infrastructure.db_pool module.""" import sqlite3 import threading import time from pathlib import Path import pytest from src.config import settings from src.infrastructure.db_pool import ConnectionPool class TestConnectionPoolInit: """Test ConnectionPool initialization.""" def test_init_with_string_path(self, tmp_path): """Pool can be initialized with a string path.""" db_path = str(tmp_path / "test.db") pool = ConnectionPool(db_path) assert pool._db_path == Path(db_path) def test_init_with_path_object(self, tmp_path): """Pool can be initialized with a Path object.""" db_path = tmp_path / "test.db" pool = ConnectionPool(db_path) assert pool._db_path == db_path def test_init_creates_thread_local(self, tmp_path): """Pool initializes thread-local storage.""" pool = ConnectionPool(tmp_path / "test.db") assert hasattr(pool, "_local") assert isinstance(pool._local, threading.local) class TestGetConnection: """Test get_connection() method.""" def test_get_connection_returns_valid_sqlite3_connection(self, tmp_path): """get_connection() returns a valid sqlite3 connection.""" pool = ConnectionPool(tmp_path / "test.db") conn = pool.get_connection() assert isinstance(conn, sqlite3.Connection) # Verify it's a working connection cursor = conn.execute("SELECT 1") assert cursor.fetchone()[0] == 1 def test_get_connection_creates_db_file(self, tmp_path): """get_connection() creates the database file if it doesn't exist.""" db_path = tmp_path / "subdir" / "test.db" assert not db_path.exists() pool = ConnectionPool(db_path) pool.get_connection() assert db_path.exists() def test_get_connection_sets_row_factory(self, tmp_path): """get_connection() sets row_factory to sqlite3.Row.""" pool = ConnectionPool(tmp_path / "test.db") conn = pool.get_connection() assert conn.row_factory is sqlite3.Row def test_multiple_calls_same_thread_reuse_connection(self, tmp_path): """Multiple calls from same thread reuse the same connection.""" pool = ConnectionPool(tmp_path / "test.db") conn1 = pool.get_connection() conn2 = pool.get_connection() assert conn1 is conn2 def test_different_threads_get_different_connections(self, tmp_path): """Different threads get different connections.""" pool = ConnectionPool(tmp_path / "test.db") connections = [] def get_conn(): connections.append(pool.get_connection()) t1 = threading.Thread(target=get_conn) t2 = threading.Thread(target=get_conn) t1.start() t2.start() t1.join() t2.join() assert len(connections) == 2 assert connections[0] is not connections[1] class TestCloseConnection: """Test close_connection() method.""" def test_close_connection_closes_sqlite_connection(self, tmp_path): """close_connection() closes the underlying sqlite connection.""" pool = ConnectionPool(tmp_path / "test.db") conn = pool.get_connection() pool.close_connection() # Connection should be closed with pytest.raises(sqlite3.ProgrammingError): conn.execute("SELECT 1") def test_close_connection_cleans_up_thread_local(self, tmp_path): """close_connection() cleans up thread-local storage.""" pool = ConnectionPool(tmp_path / "test.db") pool.get_connection() assert hasattr(pool._local, "conn") assert pool._local.conn is not None pool.close_connection() # Should either not have the attr or it should be None assert not hasattr(pool._local, "conn") or pool._local.conn is None def test_close_connection_without_getting_connection_is_safe(self, tmp_path): """close_connection() is safe to call even without getting a connection first.""" pool = ConnectionPool(tmp_path / "test.db") # Should not raise pool.close_connection() def test_close_connection_multiple_calls_is_safe(self, tmp_path): """close_connection() can be called multiple times safely.""" pool = ConnectionPool(tmp_path / "test.db") pool.get_connection() pool.close_connection() # Should not raise pool.close_connection() class TestContextManager: """Test the connection() context manager.""" def test_connection_yields_valid_connection(self, tmp_path): """connection() context manager yields a valid sqlite3 connection.""" pool = ConnectionPool(tmp_path / "test.db") with pool.connection() as conn: assert isinstance(conn, sqlite3.Connection) cursor = conn.execute("SELECT 42") assert cursor.fetchone()[0] == 42 def test_connection_closes_on_exit(self, tmp_path): """connection() context manager closes connection on exit.""" pool = ConnectionPool(tmp_path / "test.db") with pool.connection() as conn: pass # Connection should be closed after context exit with pytest.raises(sqlite3.ProgrammingError): conn.execute("SELECT 1") def test_connection_closes_on_exception(self, tmp_path): """connection() context manager closes connection even on exception.""" pool = ConnectionPool(tmp_path / "test.db") conn_ref = None try: with pool.connection() as conn: conn_ref = conn raise ValueError("Test exception") except ValueError: pass # Connection should still be closed with pytest.raises(sqlite3.ProgrammingError): conn_ref.execute("SELECT 1") def test_connection_context_manager_is_reusable(self, tmp_path): """connection() context manager can be used multiple times.""" pool = ConnectionPool(tmp_path / "test.db") with pool.connection() as conn1: result1 = conn1.execute("SELECT 1").fetchone()[0] with pool.connection() as conn2: result2 = conn2.execute("SELECT 2").fetchone()[0] assert result1 == 1 assert result2 == 2 class TestThreadSafety: """Test thread-safety of the connection pool.""" def test_concurrent_access(self, tmp_path): """Multiple threads can use the pool concurrently.""" pool = ConnectionPool(tmp_path / "test.db") results = [] errors = [] def worker(worker_id): try: with pool.connection() as conn: conn.execute("CREATE TABLE IF NOT EXISTS test (id INTEGER)") conn.execute("INSERT INTO test VALUES (?)", (worker_id,)) conn.commit() time.sleep(0.01) # Small delay to increase contention results.append(worker_id) except Exception as e: errors.append(e) threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] for t in threads: t.start() for t in threads: t.join() assert len(errors) == 0, f"Errors occurred: {errors}" assert len(results) == 5 def test_thread_isolation(self, tmp_path): """Each thread has isolated connections (verified by thread-local data).""" pool = ConnectionPool(tmp_path / "test.db") results = [] def worker(worker_id): # Get connection and write worker-specific data conn = pool.get_connection() conn.execute("CREATE TABLE IF NOT EXISTS isolation_test (thread_id INTEGER)") conn.execute("DELETE FROM isolation_test") # Clear previous data conn.execute("INSERT INTO isolation_test VALUES (?)", (worker_id,)) conn.commit() # Read back the data result = conn.execute("SELECT thread_id FROM isolation_test").fetchone()[0] results.append((worker_id, result)) pool.close_connection() threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)] for t in threads: t.start() for t in threads: t.join() # Each thread should have written and read its own ID assert len(results) == 3 for worker_id, read_id in results: assert worker_id == read_id, f"Thread {worker_id} read {read_id} instead" class TestCloseAll: """Test close_all() method.""" def test_close_all_closes_current_thread_connection(self, tmp_path): """close_all() closes the connection for the current thread.""" pool = ConnectionPool(tmp_path / "test.db") conn = pool.get_connection() pool.close_all() # Connection should be closed with pytest.raises(sqlite3.ProgrammingError): conn.execute("SELECT 1") class TestConnectionLeaks: """Test that connections do not leak.""" def test_get_connection_after_close_returns_fresh_connection(self, tmp_path): """After close, get_connection() returns a new working connection.""" pool = ConnectionPool(tmp_path / "test.db") conn1 = pool.get_connection() pool.close_connection() conn2 = pool.get_connection() assert conn2 is not conn1 # New connection must be usable cursor = conn2.execute("SELECT 1") assert cursor.fetchone()[0] == 1 pool.close_connection() def test_context_manager_does_not_leak_connection(self, tmp_path): """After context manager exit, thread-local conn is cleared.""" pool = ConnectionPool(tmp_path / "test.db") with pool.connection(): pass # Thread-local should be cleaned up assert pool._local.conn is None def test_context_manager_exception_does_not_leak_connection(self, tmp_path): """Connection is cleaned up even when an exception occurs.""" pool = ConnectionPool(tmp_path / "test.db") try: with pool.connection(): raise RuntimeError("boom") except RuntimeError: pass assert pool._local.conn is None def test_threads_do_not_leak_into_each_other(self, tmp_path): """A connection opened in one thread is invisible to another.""" pool = ConnectionPool(tmp_path / "test.db") # Open a connection on main thread pool.get_connection() visible_from_other_thread = [] def check(): has_conn = hasattr(pool._local, "conn") and pool._local.conn is not None visible_from_other_thread.append(has_conn) t = threading.Thread(target=check) t.start() t.join() assert visible_from_other_thread == [False] pool.close_connection() def test_repeated_open_close_cycles(self, tmp_path): """Repeated open/close cycles do not accumulate leaked connections.""" pool = ConnectionPool(tmp_path / "test.db") for _ in range(50): with pool.connection() as conn: conn.execute("SELECT 1") # After each cycle, connection should be cleaned up assert pool._local.conn is None class TestPragmaApplication: """Test that SQLite pragmas can be applied and persist on pooled connections. The codebase uses WAL journal mode and busy_timeout pragmas on connections obtained from the pool. These tests verify that pattern works correctly. """ def test_wal_journal_mode_persists(self, tmp_path): """WAL journal mode set on a pooled connection persists for its lifetime.""" pool = ConnectionPool(tmp_path / "test.db") conn = pool.get_connection() conn.execute("PRAGMA journal_mode=WAL") mode = conn.execute("PRAGMA journal_mode").fetchone()[0] assert mode == "wal" # Same connection should retain the pragma same_conn = pool.get_connection() mode2 = same_conn.execute("PRAGMA journal_mode").fetchone()[0] assert mode2 == "wal" pool.close_connection() def test_busy_timeout_persists(self, tmp_path): """busy_timeout pragma set on a pooled connection persists.""" pool = ConnectionPool(tmp_path / "test.db") conn = pool.get_connection() conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}") timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0] assert timeout == settings.db_busy_timeout_ms pool.close_connection() def test_pragmas_apply_per_connection(self, tmp_path): """Pragmas set on one thread's connection are independent of another's.""" pool = ConnectionPool(tmp_path / "test.db") conn_main = pool.get_connection() conn_main.execute("PRAGMA cache_size=9999") other_cache = [] def check_pragma(): conn = pool.get_connection() # Don't set cache_size — should get the default, not 9999 val = conn.execute("PRAGMA cache_size").fetchone()[0] other_cache.append(val) pool.close_connection() t = threading.Thread(target=check_pragma) t.start() t.join() # Other thread's connection should NOT have our custom cache_size assert other_cache[0] != 9999 pool.close_connection() def test_session_pragma_resets_on_new_connection(self, tmp_path): """Session-level pragmas (cache_size) reset on a new connection.""" pool = ConnectionPool(tmp_path / "test.db") conn1 = pool.get_connection() conn1.execute("PRAGMA cache_size=9999") assert conn1.execute("PRAGMA cache_size").fetchone()[0] == 9999 pool.close_connection() conn2 = pool.get_connection() cache = conn2.execute("PRAGMA cache_size").fetchone()[0] # New connection gets default cache_size, not the previous value assert cache != 9999 pool.close_connection() def test_wal_mode_via_context_manager(self, tmp_path): """WAL mode can be set within a context manager block.""" pool = ConnectionPool(tmp_path / "test.db") with pool.connection() as conn: conn.execute("PRAGMA journal_mode=WAL") mode = conn.execute("PRAGMA journal_mode").fetchone()[0] assert mode == "wal" class TestIntegration: """Integration tests for real-world usage patterns.""" def test_basic_crud_operations(self, tmp_path): """Can perform basic CRUD operations through the pool.""" pool = ConnectionPool(tmp_path / "test.db") with pool.connection() as conn: # Create table conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)") # Insert conn.execute("INSERT INTO users (name) VALUES (?)", ("Alice",)) conn.execute("INSERT INTO users (name) VALUES (?)", ("Bob",)) conn.commit() # Query cursor = conn.execute("SELECT * FROM users ORDER BY id") rows = cursor.fetchall() assert len(rows) == 2 assert rows[0]["name"] == "Alice" assert rows[1]["name"] == "Bob" def test_multiple_pools_different_databases(self, tmp_path): """Multiple pools can manage different databases independently.""" pool1 = ConnectionPool(tmp_path / "db1.db") pool2 = ConnectionPool(tmp_path / "db2.db") with pool1.connection() as conn1: conn1.execute("CREATE TABLE test (val INTEGER)") conn1.execute("INSERT INTO test VALUES (1)") conn1.commit() with pool2.connection() as conn2: conn2.execute("CREATE TABLE test (val INTEGER)") conn2.execute("INSERT INTO test VALUES (2)") conn2.commit() # Verify isolation with pool1.connection() as conn1: result = conn1.execute("SELECT val FROM test").fetchone()[0] assert result == 1 with pool2.connection() as conn2: result = conn2.execute("SELECT val FROM test").fetchone()[0] assert result == 2