This commit was merged in pull request #830.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -73,7 +73,6 @@ morning_briefing.txt
|
|||||||
markdown_report.md
|
markdown_report.md
|
||||||
data/timmy_soul.jsonl
|
data/timmy_soul.jsonl
|
||||||
scripts/migrate_to_zeroclaw.py
|
scripts/migrate_to_zeroclaw.py
|
||||||
src/infrastructure/db_pool.py
|
|
||||||
workspace/
|
workspace/
|
||||||
|
|
||||||
# Loop orchestration state
|
# Loop orchestration state
|
||||||
|
|||||||
84
src/infrastructure/db_pool.py
Normal file
84
src/infrastructure/db_pool.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Thread-local SQLite connection pool.
|
||||||
|
|
||||||
|
Provides a ConnectionPool class that manages SQLite connections per thread,
|
||||||
|
with support for context managers and automatic cleanup.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from collections.abc import Generator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionPool:
|
||||||
|
"""Thread-local SQLite connection pool.
|
||||||
|
|
||||||
|
Each thread gets its own connection, which is reused for subsequent
|
||||||
|
requests from the same thread. Connections are automatically cleaned
|
||||||
|
up when close_connection() is called or the context manager exits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_path: Path | str) -> None:
|
||||||
|
"""Initialize the connection pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to the SQLite database file.
|
||||||
|
"""
|
||||||
|
self._db_path = Path(db_path)
|
||||||
|
self._local = threading.local()
|
||||||
|
|
||||||
|
def _ensure_db_exists(self) -> None:
|
||||||
|
"""Ensure the database directory exists."""
|
||||||
|
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def get_connection(self) -> sqlite3.Connection:
|
||||||
|
"""Get a connection for the current thread.
|
||||||
|
|
||||||
|
Creates a new connection if one doesn't exist for this thread,
|
||||||
|
otherwise returns the existing connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sqlite3 Connection object.
|
||||||
|
"""
|
||||||
|
if not hasattr(self._local, "conn") or self._local.conn is None:
|
||||||
|
self._ensure_db_exists()
|
||||||
|
self._local.conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
|
||||||
|
self._local.conn.row_factory = sqlite3.Row
|
||||||
|
return self._local.conn
|
||||||
|
|
||||||
|
def close_connection(self) -> None:
|
||||||
|
"""Close the connection for the current thread.
|
||||||
|
|
||||||
|
Cleans up the thread-local storage. Safe to call even if
|
||||||
|
no connection exists for this thread.
|
||||||
|
"""
|
||||||
|
if hasattr(self._local, "conn") and self._local.conn is not None:
|
||||||
|
self._local.conn.close()
|
||||||
|
self._local.conn = None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def connection(self) -> Generator[sqlite3.Connection, None, None]:
|
||||||
|
"""Context manager for getting and automatically closing a connection.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
A sqlite3 Connection object.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with pool.connection() as conn:
|
||||||
|
cursor = conn.execute("SELECT 1")
|
||||||
|
result = cursor.fetchone()
|
||||||
|
"""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
self.close_connection()
|
||||||
|
|
||||||
|
def close_all(self) -> None:
|
||||||
|
"""Close all connections (useful for testing).
|
||||||
|
|
||||||
|
Note: This only closes the connection for the current thread.
|
||||||
|
In a multi-threaded environment, each thread must close its own.
|
||||||
|
"""
|
||||||
|
self.close_connection()
|
||||||
288
tests/infrastructure/test_db_pool.py
Normal file
288
tests/infrastructure/test_db_pool.py
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
"""Tests for infrastructure.db_pool module."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from infrastructure.db_pool import ConnectionPool
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionPoolInit:
|
||||||
|
"""Test ConnectionPool initialization."""
|
||||||
|
|
||||||
|
def test_init_with_string_path(self, tmp_path):
|
||||||
|
"""Pool can be initialized with a string path."""
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
pool = ConnectionPool(db_path)
|
||||||
|
assert pool._db_path == Path(db_path)
|
||||||
|
|
||||||
|
def test_init_with_path_object(self, tmp_path):
|
||||||
|
"""Pool can be initialized with a Path object."""
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
pool = ConnectionPool(db_path)
|
||||||
|
assert pool._db_path == db_path
|
||||||
|
|
||||||
|
def test_init_creates_thread_local(self, tmp_path):
|
||||||
|
"""Pool initializes thread-local storage."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
assert hasattr(pool, "_local")
|
||||||
|
assert isinstance(pool._local, threading.local)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetConnection:
|
||||||
|
"""Test get_connection() method."""
|
||||||
|
|
||||||
|
def test_get_connection_returns_valid_sqlite3_connection(self, tmp_path):
|
||||||
|
"""get_connection() returns a valid sqlite3 connection."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
conn = pool.get_connection()
|
||||||
|
assert isinstance(conn, sqlite3.Connection)
|
||||||
|
# Verify it's a working connection
|
||||||
|
cursor = conn.execute("SELECT 1")
|
||||||
|
assert cursor.fetchone()[0] == 1
|
||||||
|
|
||||||
|
def test_get_connection_creates_db_file(self, tmp_path):
|
||||||
|
"""get_connection() creates the database file if it doesn't exist."""
|
||||||
|
db_path = tmp_path / "subdir" / "test.db"
|
||||||
|
assert not db_path.exists()
|
||||||
|
pool = ConnectionPool(db_path)
|
||||||
|
pool.get_connection()
|
||||||
|
assert db_path.exists()
|
||||||
|
|
||||||
|
def test_get_connection_sets_row_factory(self, tmp_path):
|
||||||
|
"""get_connection() sets row_factory to sqlite3.Row."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
conn = pool.get_connection()
|
||||||
|
assert conn.row_factory is sqlite3.Row
|
||||||
|
|
||||||
|
def test_multiple_calls_same_thread_reuse_connection(self, tmp_path):
|
||||||
|
"""Multiple calls from same thread reuse the same connection."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
conn1 = pool.get_connection()
|
||||||
|
conn2 = pool.get_connection()
|
||||||
|
assert conn1 is conn2
|
||||||
|
|
||||||
|
def test_different_threads_get_different_connections(self, tmp_path):
|
||||||
|
"""Different threads get different connections."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
connections = []
|
||||||
|
|
||||||
|
def get_conn():
|
||||||
|
connections.append(pool.get_connection())
|
||||||
|
|
||||||
|
t1 = threading.Thread(target=get_conn)
|
||||||
|
t2 = threading.Thread(target=get_conn)
|
||||||
|
t1.start()
|
||||||
|
t2.start()
|
||||||
|
t1.join()
|
||||||
|
t2.join()
|
||||||
|
|
||||||
|
assert len(connections) == 2
|
||||||
|
assert connections[0] is not connections[1]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCloseConnection:
|
||||||
|
"""Test close_connection() method."""
|
||||||
|
|
||||||
|
def test_close_connection_closes_sqlite_connection(self, tmp_path):
|
||||||
|
"""close_connection() closes the underlying sqlite connection."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
conn = pool.get_connection()
|
||||||
|
pool.close_connection()
|
||||||
|
# Connection should be closed
|
||||||
|
with pytest.raises(sqlite3.ProgrammingError):
|
||||||
|
conn.execute("SELECT 1")
|
||||||
|
|
||||||
|
def test_close_connection_cleans_up_thread_local(self, tmp_path):
|
||||||
|
"""close_connection() cleans up thread-local storage."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
pool.get_connection()
|
||||||
|
assert hasattr(pool._local, "conn")
|
||||||
|
assert pool._local.conn is not None
|
||||||
|
|
||||||
|
pool.close_connection()
|
||||||
|
|
||||||
|
# Should either not have the attr or it should be None
|
||||||
|
assert not hasattr(pool._local, "conn") or pool._local.conn is None
|
||||||
|
|
||||||
|
def test_close_connection_without_getting_connection_is_safe(self, tmp_path):
|
||||||
|
"""close_connection() is safe to call even without getting a connection first."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
# Should not raise
|
||||||
|
pool.close_connection()
|
||||||
|
|
||||||
|
def test_close_connection_multiple_calls_is_safe(self, tmp_path):
|
||||||
|
"""close_connection() can be called multiple times safely."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
pool.get_connection()
|
||||||
|
pool.close_connection()
|
||||||
|
# Should not raise
|
||||||
|
pool.close_connection()
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextManager:
|
||||||
|
"""Test the connection() context manager."""
|
||||||
|
|
||||||
|
def test_connection_yields_valid_connection(self, tmp_path):
|
||||||
|
"""connection() context manager yields a valid sqlite3 connection."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
with pool.connection() as conn:
|
||||||
|
assert isinstance(conn, sqlite3.Connection)
|
||||||
|
cursor = conn.execute("SELECT 42")
|
||||||
|
assert cursor.fetchone()[0] == 42
|
||||||
|
|
||||||
|
def test_connection_closes_on_exit(self, tmp_path):
|
||||||
|
"""connection() context manager closes connection on exit."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
with pool.connection() as conn:
|
||||||
|
pass
|
||||||
|
# Connection should be closed after context exit
|
||||||
|
with pytest.raises(sqlite3.ProgrammingError):
|
||||||
|
conn.execute("SELECT 1")
|
||||||
|
|
||||||
|
def test_connection_closes_on_exception(self, tmp_path):
|
||||||
|
"""connection() context manager closes connection even on exception."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
conn_ref = None
|
||||||
|
try:
|
||||||
|
with pool.connection() as conn:
|
||||||
|
conn_ref = conn
|
||||||
|
raise ValueError("Test exception")
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
# Connection should still be closed
|
||||||
|
with pytest.raises(sqlite3.ProgrammingError):
|
||||||
|
conn_ref.execute("SELECT 1")
|
||||||
|
|
||||||
|
def test_connection_context_manager_is_reusable(self, tmp_path):
|
||||||
|
"""connection() context manager can be used multiple times."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
|
||||||
|
with pool.connection() as conn1:
|
||||||
|
result1 = conn1.execute("SELECT 1").fetchone()[0]
|
||||||
|
|
||||||
|
with pool.connection() as conn2:
|
||||||
|
result2 = conn2.execute("SELECT 2").fetchone()[0]
|
||||||
|
|
||||||
|
assert result1 == 1
|
||||||
|
assert result2 == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadSafety:
|
||||||
|
"""Test thread-safety of the connection pool."""
|
||||||
|
|
||||||
|
def test_concurrent_access(self, tmp_path):
|
||||||
|
"""Multiple threads can use the pool concurrently."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
results = []
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def worker(worker_id):
|
||||||
|
try:
|
||||||
|
with pool.connection() as conn:
|
||||||
|
conn.execute("CREATE TABLE IF NOT EXISTS test (id INTEGER)")
|
||||||
|
conn.execute("INSERT INTO test VALUES (?)", (worker_id,))
|
||||||
|
conn.commit()
|
||||||
|
time.sleep(0.01) # Small delay to increase contention
|
||||||
|
results.append(worker_id)
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert len(errors) == 0, f"Errors occurred: {errors}"
|
||||||
|
assert len(results) == 5
|
||||||
|
|
||||||
|
def test_thread_isolation(self, tmp_path):
|
||||||
|
"""Each thread has isolated connections (verified by thread-local data)."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
results = []
|
||||||
|
|
||||||
|
def worker(worker_id):
|
||||||
|
# Get connection and write worker-specific data
|
||||||
|
conn = pool.get_connection()
|
||||||
|
conn.execute("CREATE TABLE IF NOT EXISTS isolation_test (thread_id INTEGER)")
|
||||||
|
conn.execute("DELETE FROM isolation_test") # Clear previous data
|
||||||
|
conn.execute("INSERT INTO isolation_test VALUES (?)", (worker_id,))
|
||||||
|
conn.commit()
|
||||||
|
# Read back the data
|
||||||
|
result = conn.execute("SELECT thread_id FROM isolation_test").fetchone()[0]
|
||||||
|
results.append((worker_id, result))
|
||||||
|
pool.close_connection()
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# Each thread should have written and read its own ID
|
||||||
|
assert len(results) == 3
|
||||||
|
for worker_id, read_id in results:
|
||||||
|
assert worker_id == read_id, f"Thread {worker_id} read {read_id} instead"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCloseAll:
|
||||||
|
"""Test close_all() method."""
|
||||||
|
|
||||||
|
def test_close_all_closes_current_thread_connection(self, tmp_path):
|
||||||
|
"""close_all() closes the connection for the current thread."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
conn = pool.get_connection()
|
||||||
|
pool.close_all()
|
||||||
|
# Connection should be closed
|
||||||
|
with pytest.raises(sqlite3.ProgrammingError):
|
||||||
|
conn.execute("SELECT 1")
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegration:
|
||||||
|
"""Integration tests for real-world usage patterns."""
|
||||||
|
|
||||||
|
def test_basic_crud_operations(self, tmp_path):
|
||||||
|
"""Can perform basic CRUD operations through the pool."""
|
||||||
|
pool = ConnectionPool(tmp_path / "test.db")
|
||||||
|
|
||||||
|
with pool.connection() as conn:
|
||||||
|
# Create table
|
||||||
|
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
|
||||||
|
# Insert
|
||||||
|
conn.execute("INSERT INTO users (name) VALUES (?)", ("Alice",))
|
||||||
|
conn.execute("INSERT INTO users (name) VALUES (?)", ("Bob",))
|
||||||
|
conn.commit()
|
||||||
|
# Query
|
||||||
|
cursor = conn.execute("SELECT * FROM users ORDER BY id")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
assert len(rows) == 2
|
||||||
|
assert rows[0]["name"] == "Alice"
|
||||||
|
assert rows[1]["name"] == "Bob"
|
||||||
|
|
||||||
|
def test_multiple_pools_different_databases(self, tmp_path):
|
||||||
|
"""Multiple pools can manage different databases independently."""
|
||||||
|
pool1 = ConnectionPool(tmp_path / "db1.db")
|
||||||
|
pool2 = ConnectionPool(tmp_path / "db2.db")
|
||||||
|
|
||||||
|
with pool1.connection() as conn1:
|
||||||
|
conn1.execute("CREATE TABLE test (val INTEGER)")
|
||||||
|
conn1.execute("INSERT INTO test VALUES (1)")
|
||||||
|
conn1.commit()
|
||||||
|
|
||||||
|
with pool2.connection() as conn2:
|
||||||
|
conn2.execute("CREATE TABLE test (val INTEGER)")
|
||||||
|
conn2.execute("INSERT INTO test VALUES (2)")
|
||||||
|
conn2.commit()
|
||||||
|
|
||||||
|
# Verify isolation
|
||||||
|
with pool1.connection() as conn1:
|
||||||
|
result = conn1.execute("SELECT val FROM test").fetchone()[0]
|
||||||
|
assert result == 1
|
||||||
|
|
||||||
|
with pool2.connection() as conn2:
|
||||||
|
result = conn2.execute("SELECT val FROM test").fetchone()[0]
|
||||||
|
assert result == 2
|
||||||
Reference in New Issue
Block a user