[loop-cycle-50] refactor: replace bare sqlite3.connect() with context managers batch 2 (#157) (#180)

This commit is contained in:
2026-03-15 11:58:43 -04:00
parent bea2749158
commit bcd6d7e321
16 changed files with 512 additions and 510 deletions

View File

@@ -16,6 +16,8 @@ import json
import logging
import sqlite3
import uuid
from collections.abc import Generator
from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
@@ -39,28 +41,31 @@ class Prediction:
evaluated_at: str | None
def _get_conn() -> sqlite3.Connection:
@contextmanager
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
conn.execute("""
CREATE TABLE IF NOT EXISTS spark_predictions (
id TEXT PRIMARY KEY,
task_id TEXT NOT NULL,
prediction_type TEXT NOT NULL,
predicted_value TEXT NOT NULL,
actual_value TEXT,
accuracy REAL,
created_at TEXT NOT NULL,
evaluated_at TEXT
with closing(sqlite3.connect(str(DB_PATH))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
conn.execute("""
CREATE TABLE IF NOT EXISTS spark_predictions (
id TEXT PRIMARY KEY,
task_id TEXT NOT NULL,
prediction_type TEXT NOT NULL,
predicted_value TEXT NOT NULL,
actual_value TEXT,
accuracy REAL,
created_at TEXT NOT NULL,
evaluated_at TEXT
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)")
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)"
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)")
conn.commit()
return conn
conn.commit()
yield conn
# ── Prediction phase ────────────────────────────────────────────────────────
@@ -119,17 +124,16 @@ def predict_task_outcome(
# Store prediction
pred_id = str(uuid.uuid4())
now = datetime.now(UTC).isoformat()
conn = _get_conn()
conn.execute(
"""
INSERT INTO spark_predictions
(id, task_id, prediction_type, predicted_value, created_at)
VALUES (?, ?, ?, ?, ?)
""",
(pred_id, task_id, "outcome", json.dumps(prediction), now),
)
conn.commit()
conn.close()
with _get_conn() as conn:
conn.execute(
"""
INSERT INTO spark_predictions
(id, task_id, prediction_type, predicted_value, created_at)
VALUES (?, ?, ?, ?, ?)
""",
(pred_id, task_id, "outcome", json.dumps(prediction), now),
)
conn.commit()
prediction["prediction_id"] = pred_id
return prediction
@@ -148,41 +152,39 @@ def evaluate_prediction(
Returns the evaluation result or None if no prediction exists.
"""
conn = _get_conn()
row = conn.execute(
"""
SELECT * FROM spark_predictions
WHERE task_id = ? AND prediction_type = 'outcome' AND evaluated_at IS NULL
ORDER BY created_at DESC LIMIT 1
""",
(task_id,),
).fetchone()
with _get_conn() as conn:
row = conn.execute(
"""
SELECT * FROM spark_predictions
WHERE task_id = ? AND prediction_type = 'outcome' AND evaluated_at IS NULL
ORDER BY created_at DESC LIMIT 1
""",
(task_id,),
).fetchone()
if not row:
conn.close()
return None
if not row:
return None
predicted = json.loads(row["predicted_value"])
actual = {
"winner": actual_winner,
"succeeded": task_succeeded,
"winning_bid": winning_bid,
}
predicted = json.loads(row["predicted_value"])
actual = {
"winner": actual_winner,
"succeeded": task_succeeded,
"winning_bid": winning_bid,
}
# Calculate accuracy
accuracy = _compute_accuracy(predicted, actual)
now = datetime.now(UTC).isoformat()
# Calculate accuracy
accuracy = _compute_accuracy(predicted, actual)
now = datetime.now(UTC).isoformat()
conn.execute(
"""
UPDATE spark_predictions
SET actual_value = ?, accuracy = ?, evaluated_at = ?
WHERE id = ?
""",
(json.dumps(actual), accuracy, now, row["id"]),
)
conn.commit()
conn.close()
conn.execute(
"""
UPDATE spark_predictions
SET actual_value = ?, accuracy = ?, evaluated_at = ?
WHERE id = ?
""",
(json.dumps(actual), accuracy, now, row["id"]),
)
conn.commit()
return {
"prediction_id": row["id"],
@@ -243,7 +245,6 @@ def get_predictions(
limit: int = 50,
) -> list[Prediction]:
"""Query stored predictions."""
conn = _get_conn()
query = "SELECT * FROM spark_predictions WHERE 1=1"
params: list = []
@@ -256,8 +257,8 @@ def get_predictions(
query += " ORDER BY created_at DESC LIMIT ?"
params.append(limit)
rows = conn.execute(query, params).fetchall()
conn.close()
with _get_conn() as conn:
rows = conn.execute(query, params).fetchall()
return [
Prediction(
id=r["id"],
@@ -275,17 +276,16 @@ def get_predictions(
def get_accuracy_stats() -> dict:
"""Return aggregate accuracy statistics for the EIDOS loop."""
conn = _get_conn()
row = conn.execute("""
SELECT
COUNT(*) AS total_predictions,
COUNT(evaluated_at) AS evaluated,
AVG(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS avg_accuracy,
MIN(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS min_accuracy,
MAX(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS max_accuracy
FROM spark_predictions
""").fetchone()
conn.close()
with _get_conn() as conn:
row = conn.execute("""
SELECT
COUNT(*) AS total_predictions,
COUNT(evaluated_at) AS evaluated,
AVG(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS avg_accuracy,
MIN(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS min_accuracy,
MAX(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS max_accuracy
FROM spark_predictions
""").fetchone()
return {
"total_predictions": row["total_predictions"] or 0,