diff --git a/tests/infrastructure/test_db_pool.py b/tests/infrastructure/test_db_pool.py index 9e1f9bac..c1212042 100644 --- a/tests/infrastructure/test_db_pool.py +++ b/tests/infrastructure/test_db_pool.py @@ -242,6 +242,145 @@ class TestCloseAll: conn.execute("SELECT 1") +class TestConnectionLeaks: + """Test that connections do not leak.""" + + def test_get_connection_after_close_returns_fresh_connection(self, tmp_path): + """After close, get_connection() returns a new working connection.""" + pool = ConnectionPool(tmp_path / "test.db") + conn1 = pool.get_connection() + pool.close_connection() + + conn2 = pool.get_connection() + assert conn2 is not conn1 + # New connection must be usable + cursor = conn2.execute("SELECT 1") + assert cursor.fetchone()[0] == 1 + pool.close_connection() + + def test_context_manager_does_not_leak_connection(self, tmp_path): + """After context manager exit, thread-local conn is cleared.""" + pool = ConnectionPool(tmp_path / "test.db") + with pool.connection(): + pass + # Thread-local should be cleaned up + assert pool._local.conn is None + + def test_context_manager_exception_does_not_leak_connection(self, tmp_path): + """Connection is cleaned up even when an exception occurs.""" + pool = ConnectionPool(tmp_path / "test.db") + try: + with pool.connection(): + raise RuntimeError("boom") + except RuntimeError: + pass + assert pool._local.conn is None + + def test_threads_do_not_leak_into_each_other(self, tmp_path): + """A connection opened in one thread is invisible to another.""" + pool = ConnectionPool(tmp_path / "test.db") + # Open a connection on main thread + pool.get_connection() + + visible_from_other_thread = [] + + def check(): + has_conn = hasattr(pool._local, "conn") and pool._local.conn is not None + visible_from_other_thread.append(has_conn) + + t = threading.Thread(target=check) + t.start() + t.join() + + assert visible_from_other_thread == [False] + pool.close_connection() + + def test_repeated_open_close_cycles(self, tmp_path): + """Repeated open/close cycles do not accumulate leaked connections.""" + pool = ConnectionPool(tmp_path / "test.db") + for _ in range(50): + with pool.connection() as conn: + conn.execute("SELECT 1") + # After each cycle, connection should be cleaned up + assert pool._local.conn is None + + +class TestPragmaApplication: + """Test that SQLite pragmas can be applied and persist on pooled connections. + + The codebase uses WAL journal mode and busy_timeout pragmas on connections + obtained from the pool. These tests verify that pattern works correctly. + """ + + def test_wal_journal_mode_persists(self, tmp_path): + """WAL journal mode set on a pooled connection persists for its lifetime.""" + pool = ConnectionPool(tmp_path / "test.db") + conn = pool.get_connection() + conn.execute("PRAGMA journal_mode=WAL") + mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + assert mode == "wal" + + # Same connection should retain the pragma + same_conn = pool.get_connection() + mode2 = same_conn.execute("PRAGMA journal_mode").fetchone()[0] + assert mode2 == "wal" + pool.close_connection() + + def test_busy_timeout_persists(self, tmp_path): + """busy_timeout pragma set on a pooled connection persists.""" + pool = ConnectionPool(tmp_path / "test.db") + conn = pool.get_connection() + conn.execute("PRAGMA busy_timeout=5000") + timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0] + assert timeout == 5000 + pool.close_connection() + + def test_pragmas_apply_per_connection(self, tmp_path): + """Pragmas set on one thread's connection are independent of another's.""" + pool = ConnectionPool(tmp_path / "test.db") + conn_main = pool.get_connection() + conn_main.execute("PRAGMA cache_size=9999") + + other_cache = [] + + def check_pragma(): + conn = pool.get_connection() + # Don't set cache_size — should get the default, not 9999 + val = conn.execute("PRAGMA cache_size").fetchone()[0] + other_cache.append(val) + pool.close_connection() + + t = threading.Thread(target=check_pragma) + t.start() + t.join() + + # Other thread's connection should NOT have our custom cache_size + assert other_cache[0] != 9999 + pool.close_connection() + + def test_session_pragma_resets_on_new_connection(self, tmp_path): + """Session-level pragmas (cache_size) reset on a new connection.""" + pool = ConnectionPool(tmp_path / "test.db") + conn1 = pool.get_connection() + conn1.execute("PRAGMA cache_size=9999") + assert conn1.execute("PRAGMA cache_size").fetchone()[0] == 9999 + pool.close_connection() + + conn2 = pool.get_connection() + cache = conn2.execute("PRAGMA cache_size").fetchone()[0] + # New connection gets default cache_size, not the previous value + assert cache != 9999 + pool.close_connection() + + def test_wal_mode_via_context_manager(self, tmp_path): + """WAL mode can be set within a context manager block.""" + pool = ConnectionPool(tmp_path / "test.db") + with pool.connection() as conn: + conn.execute("PRAGMA journal_mode=WAL") + mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + assert mode == "wal" + + class TestIntegration: """Integration tests for real-world usage patterns."""