[kimi] test: add comprehensive tests for spark modules (#659) #695
tests/spark/test_advisor.py (new file, 327 lines)
@@ -0,0 +1,327 @@
"""Comprehensive tests for spark.advisor module.

Covers all advisory-generation helpers:
- _check_failure_patterns (grouped agent failures)
- _check_agent_performance (top / struggling agents)
- _check_bid_patterns (spread + high average)
- _check_prediction_accuracy (low / high accuracy)
- _check_system_activity (idle / tasks-posted-but-no-completions)
- generate_advisories (integration, sorting, min-events guard)
"""

import json

from spark.advisor import (
    _MIN_EVENTS,
    Advisory,
    _check_agent_performance,
    _check_bid_patterns,
    _check_failure_patterns,
    _check_prediction_accuracy,
    _check_system_activity,
    generate_advisories,
)
from spark.memory import record_event
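

# These tests assume every test starts from an empty spark database. A minimal
# sketch of the kind of conftest.py fixture that would provide that isolation
# (the fixture name and the SPARK_DB_PATH variable are assumptions, not part
# of this PR):
#
#     import pytest
#
#     @pytest.fixture(autouse=True)
#     def _isolated_spark_db(tmp_path, monkeypatch):
#         # Point the spark modules at a throwaway SQLite file per test.
#         monkeypatch.setenv("SPARK_DB_PATH", str(tmp_path / "spark.db"))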


# ── Advisory dataclass ─────────────────────────────────────────────────────


class TestAdvisoryDataclass:
    def test_defaults(self):
        a = Advisory(
            category="test",
            priority=0.5,
            title="T",
            detail="D",
            suggested_action="A",
        )
        assert a.subject is None
        assert a.evidence_count == 0

    def test_all_fields(self):
        a = Advisory(
            category="c",
            priority=0.9,
            title="T",
            detail="D",
            suggested_action="A",
            subject="agent-1",
            evidence_count=7,
        )
        assert a.subject == "agent-1"
        assert a.evidence_count == 7


# ── _check_failure_patterns ────────────────────────────────────────────────


class TestCheckFailurePatterns:
    def test_no_failures_returns_empty(self):
        assert _check_failure_patterns() == []

    def test_single_failure_not_enough(self):
        record_event("task_failed", "once", agent_id="a1", task_id="t1")
        assert _check_failure_patterns() == []

    def test_two_failures_triggers_advisory(self):
        for i in range(2):
            record_event("task_failed", f"fail {i}", agent_id="agent-abc", task_id=f"t{i}")
        results = _check_failure_patterns()
        assert len(results) == 1
        assert results[0].category == "failure_prevention"
        assert results[0].subject == "agent-abc"
        assert results[0].evidence_count == 2

    def test_priority_scales_with_count(self):
        for i in range(5):
            record_event("task_failed", f"fail {i}", agent_id="agent-x", task_id=f"f{i}")
        results = _check_failure_patterns()
        assert len(results) == 1
        assert results[0].priority > 0.5

    def test_priority_capped_at_one(self):
        for i in range(20):
            record_event("task_failed", f"fail {i}", agent_id="agent-y", task_id=f"ff{i}")
        results = _check_failure_patterns()
        assert results[0].priority <= 1.0

    def test_multiple_agents_separate_advisories(self):
        for i in range(3):
            record_event("task_failed", f"a fail {i}", agent_id="agent-a", task_id=f"a{i}")
            record_event("task_failed", f"b fail {i}", agent_id="agent-b", task_id=f"b{i}")
        results = _check_failure_patterns()
        assert len(results) == 2
        subjects = {r.subject for r in results}
        assert subjects == {"agent-a", "agent-b"}

    def test_events_without_agent_id_skipped(self):
        for i in range(3):
            record_event("task_failed", f"no-agent {i}", task_id=f"na{i}")
        assert _check_failure_patterns() == []
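

# The three priority assertions above pin down a monotone, capped model. One
# sketch consistent with them (the constants are hypothetical, not taken from
# spark.advisor):
#
#     priority = min(1.0, 0.4 + 0.1 * failure_count)
#
# Two failures clear the reporting threshold, five push priority above 0.5,
# and twenty hit the 1.0 cap.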


# ── _check_agent_performance ───────────────────────────────────────────────


class TestCheckAgentPerformance:
    def test_no_events_returns_empty(self):
        assert _check_agent_performance() == []

    def test_too_few_tasks_skipped(self):
        record_event("task_completed", "done", agent_id="agent-1", task_id="t1")
        assert _check_agent_performance() == []

    def test_high_performer_detected(self):
        for i in range(4):
            record_event("task_completed", f"done {i}", agent_id="agent-star", task_id=f"s{i}")
        results = _check_agent_performance()
        perf = [r for r in results if r.category == "agent_performance"]
        assert len(perf) == 1
        assert "excels" in perf[0].title
        assert perf[0].subject == "agent-star"

    def test_struggling_agent_detected(self):
        # 1 success, 4 failures = 20% success rate
        record_event("task_completed", "ok", agent_id="agent-bad", task_id="ok1")
        for i in range(4):
            record_event("task_failed", f"nope {i}", agent_id="agent-bad", task_id=f"bad{i}")
        results = _check_agent_performance()
        struggling = [r for r in results if "struggling" in r.title]
        assert len(struggling) == 1
        assert struggling[0].priority > 0.5

    def test_middling_agent_no_advisory(self):
        # 50% success rate: neither excelling nor struggling
        for i in range(3):
            record_event("task_completed", f"ok {i}", agent_id="agent-mid", task_id=f"m{i}")
        for i in range(3):
            record_event("task_failed", f"nope {i}", agent_id="agent-mid", task_id=f"mf{i}")
        results = _check_agent_performance()
        mid_advisories = [r for r in results if r.subject == "agent-mid"]
        assert mid_advisories == []

    def test_events_without_agent_id_skipped(self):
        for i in range(5):
            record_event("task_completed", f"done {i}", task_id=f"no-agent-{i}")
        assert _check_agent_performance() == []


# ── _check_bid_patterns ────────────────────────────────────────────────────


class TestCheckBidPatterns:
    def _record_bids(self, amounts):
        for i, sats in enumerate(amounts):
            record_event(
                "bid_submitted",
                f"bid {i}",
                agent_id=f"a{i}",
                task_id=f"bt{i}",
                data=json.dumps({"bid_sats": sats}),
            )

    def test_too_few_bids_returns_empty(self):
        self._record_bids([10, 20, 30])
        assert _check_bid_patterns() == []

    def test_wide_spread_detected(self):
        # avg=50, spread=90 > 50*1.5=75
        self._record_bids([5, 10, 50, 90, 95])
        results = _check_bid_patterns()
        spread_advisories = [r for r in results if "spread" in r.title.lower()]
        assert len(spread_advisories) == 1

    def test_high_average_detected(self):
        self._record_bids([80, 85, 90, 95, 100])
        results = _check_bid_patterns()
        high_avg = [r for r in results if "High average" in r.title]
        assert len(high_avg) == 1

    def test_normal_bids_no_advisory(self):
        # Tight spread, low average
        self._record_bids([30, 32, 28, 31, 29])
        results = _check_bid_patterns()
        assert results == []

    def test_invalid_json_data_skipped(self):
        for i in range(6):
            record_event(
                "bid_submitted",
                f"bid {i}",
                agent_id=f"a{i}",
                task_id=f"inv{i}",
                data="not-json",
            )
        results = _check_bid_patterns()
        assert results == []

    def test_zero_bid_sats_skipped(self):
        for i in range(6):
            record_event(
                "bid_submitted",
                f"bid {i}",
                data=json.dumps({"bid_sats": 0}),
            )
        assert _check_bid_patterns() == []

    def test_both_spread_and_high_avg(self):
        # Wide spread AND high average: avg=86, spread=150 > 86*1.5=129
        self._record_bids([5, 80, 90, 100, 155])
        results = _check_bid_patterns()
        assert len(results) == 2


# ── _check_prediction_accuracy ─────────────────────────────────────────────


class TestCheckPredictionAccuracy:
    def test_too_few_evaluations(self):
        assert _check_prediction_accuracy() == []

    def test_low_accuracy_advisory(self):
        from spark.eidos import evaluate_prediction, predict_task_outcome

        for i in range(4):
            predict_task_outcome(f"pa-{i}", "task", ["agent-a"])
            evaluate_prediction(f"pa-{i}", "agent-wrong", task_succeeded=False, winning_bid=999)
        results = _check_prediction_accuracy()
        low = [r for r in results if "Low prediction" in r.title]
        assert len(low) == 1
        assert low[0].priority > 0.5

    def test_high_accuracy_advisory(self):
        from spark.eidos import evaluate_prediction, predict_task_outcome

        for i in range(4):
            predict_task_outcome(f"ph-{i}", "task", ["agent-a"])
            evaluate_prediction(f"ph-{i}", "agent-a", task_succeeded=True, winning_bid=30)
        results = _check_prediction_accuracy()
        high = [r for r in results if "Strong prediction" in r.title]
        assert len(high) == 1

    def test_middling_accuracy_no_advisory(self):
        from spark.eidos import evaluate_prediction, predict_task_outcome

        # Mix of correct and incorrect evaluations to get ~0.5 accuracy
        for i in range(3):
            predict_task_outcome(f"pm-{i}", "task", ["agent-a"])
            evaluate_prediction(f"pm-{i}", "agent-a", task_succeeded=True, winning_bid=30)
        for i in range(3):
            predict_task_outcome(f"pmx-{i}", "task", ["agent-a"])
            evaluate_prediction(f"pmx-{i}", "agent-wrong", task_succeeded=False, winning_bid=999)
        results = _check_prediction_accuracy()
        # Average accuracy should be middling: neither low nor high advisory
        low = [r for r in results if "Low" in r.title]
        high = [r for r in results if "Strong" in r.title]
        # At least one side must be empty (depends on exact accuracy)
        assert not (low and high)


# ── _check_system_activity ─────────────────────────────────────────────────


class TestCheckSystemActivity:
    def test_no_events_idle_advisory(self):
        results = _check_system_activity()
        assert len(results) == 1
        assert "No swarm activity" in results[0].title

    def test_has_events_no_idle_advisory(self):
        record_event("task_completed", "done", task_id="t1")
        results = _check_system_activity()
        idle = [r for r in results if "No swarm activity" in r.title]
        assert idle == []

    def test_tasks_posted_but_none_completing(self):
        for i in range(5):
            record_event("task_posted", f"posted {i}", task_id=f"tp{i}")
        results = _check_system_activity()
        stalled = [r for r in results if "none completing" in r.title.lower()]
        assert len(stalled) == 1
        assert stalled[0].evidence_count >= 4

    def test_posts_with_completions_no_stalled_advisory(self):
        for i in range(5):
            record_event("task_posted", f"posted {i}", task_id=f"tpx{i}")
        record_event("task_completed", "done", task_id="tpx0")
        results = _check_system_activity()
        stalled = [r for r in results if "none completing" in r.title.lower()]
        assert stalled == []


# ── generate_advisories (integration) ──────────────────────────────────────


class TestGenerateAdvisories:
    def test_below_min_events_returns_insufficient(self):
        advisories = generate_advisories()
        assert len(advisories) >= 1
        assert advisories[0].title == "Insufficient data"
        assert advisories[0].evidence_count == 0

    def test_exactly_at_min_events_proceeds(self):
        for i in range(_MIN_EVENTS):
            record_event("task_posted", f"ev {i}", task_id=f"min{i}")
        advisories = generate_advisories()
        insufficient = [a for a in advisories if a.title == "Insufficient data"]
        assert insufficient == []

    def test_results_sorted_by_priority_descending(self):
        for i in range(5):
            record_event("task_posted", f"posted {i}", task_id=f"sp{i}")
        for i in range(3):
            record_event("task_failed", f"fail {i}", agent_id="agent-fail", task_id=f"sf{i}")
        advisories = generate_advisories()
        if len(advisories) >= 2:
            for i in range(len(advisories) - 1):
                assert advisories[i].priority >= advisories[i + 1].priority

    def test_multiple_categories_produced(self):
        # Create failures + posted-but-no-completions
        for i in range(5):
            record_event("task_failed", f"fail {i}", agent_id="agent-bad", task_id=f"mf{i}")
        for i in range(5):
            record_event("task_posted", f"posted {i}", task_id=f"mp{i}")
        advisories = generate_advisories()
        categories = {a.category for a in advisories}
        assert len(categories) >= 2
tests/spark/test_eidos.py (new file, 299 lines)
@@ -0,0 +1,299 @@
"""Comprehensive tests for spark.eidos module.

Covers:
- Prediction dataclass (nullable fields)
- predict_task_outcome (baseline, with history, edge cases)
- evaluate_prediction (correct, wrong, missing, double-eval)
- _compute_accuracy (all components, edge cases)
- get_predictions (filters: task_id, evaluated_only, limit)
- get_accuracy_stats (empty, after evaluations)
"""

import pytest

from spark.eidos import (
    Prediction,
    _compute_accuracy,
    evaluate_prediction,
    get_accuracy_stats,
    get_predictions,
    predict_task_outcome,
)


# ── Prediction dataclass ──────────────────────────────────────────────────


class TestPredictionDataclass:
    def test_defaults(self):
        p = Prediction(
            id="1",
            task_id="t1",
            prediction_type="outcome",
            predicted_value="{}",
            actual_value=None,
            accuracy=None,
            created_at="2026-01-01",
            evaluated_at=None,
        )
        assert p.actual_value is None
        assert p.accuracy is None


# ── predict_task_outcome ──────────────────────────────────────────────────


class TestPredictTaskOutcome:
    def test_baseline_no_history(self):
        result = predict_task_outcome("t-base", "Do stuff", ["a1", "a2"])
        assert result["likely_winner"] == "a1"
        assert result["success_probability"] == 0.7
        assert result["estimated_bid_range"] == [20, 80]
        assert "baseline" in result["reasoning"]
        assert "prediction_id" in result

    def test_empty_candidates(self):
        result = predict_task_outcome("t-empty", "Nothing", [])
        assert result["likely_winner"] is None

    def test_history_selects_best_agent(self):
        history = {
            "a1": {"success_rate": 0.3, "avg_winning_bid": 40},
            "a2": {"success_rate": 0.95, "avg_winning_bid": 50},
        }
        result = predict_task_outcome("t-hist", "Task", ["a1", "a2"], agent_history=history)
        assert result["likely_winner"] == "a2"
        assert result["success_probability"] > 0.7

    def test_history_agent_not_in_candidates_ignored(self):
        history = {
            "a-outside": {"success_rate": 0.99, "avg_winning_bid": 10},
        }
        result = predict_task_outcome("t-out", "Task", ["a1"], agent_history=history)
        # a-outside is not a candidate, so the prediction falls back to baseline
        assert result["likely_winner"] == "a1"

    def test_history_adjusts_bid_range(self):
        history = {
            "a1": {"success_rate": 0.5, "avg_winning_bid": 100},
            "a2": {"success_rate": 0.8, "avg_winning_bid": 200},
        }
        result = predict_task_outcome("t-bid", "Task", ["a1", "a2"], agent_history=history)
        low, high = result["estimated_bid_range"]
        assert low == max(1, int(100 * 0.8))
        assert high == int(200 * 1.2)

    def test_history_with_zero_avg_bid_skipped(self):
        history = {
            "a1": {"success_rate": 0.8, "avg_winning_bid": 0},
        }
        result = predict_task_outcome("t-zero-bid", "Task", ["a1"], agent_history=history)
        # A zero avg_winning_bid should be skipped, keeping the default range
        assert result["estimated_bid_range"] == [20, 80]

    def test_prediction_stored_in_db(self):
        result = predict_task_outcome("t-db", "Store me", ["a1"])
        preds = get_predictions(task_id="t-db")
        assert len(preds) == 1
        assert preds[0].id == result["prediction_id"]
        assert preds[0].prediction_type == "outcome"

    def test_success_probability_clamped(self):
        history = {
            "a1": {"success_rate": 1.5, "avg_winning_bid": 50},
        }
        result = predict_task_outcome("t-clamp", "Task", ["a1"], agent_history=history)
        assert result["success_probability"] <= 1.0
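

# The bid-range assertions above encode the implied adjustment: with history,
# the range becomes [max(1, int(min_avg_bid * 0.8)), int(max_avg_bid * 1.2)]
# over the candidates' average winning bids, falling back to the [20, 80]
# baseline when history is absent or an average bid is zero (wording inferred
# from the tests, not from spark.eidos itself).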


# ── evaluate_prediction ───────────────────────────────────────────────────


class TestEvaluatePrediction:
    def test_correct_prediction(self):
        predict_task_outcome("t-eval-ok", "Task", ["a1"])
        result = evaluate_prediction("t-eval-ok", "a1", task_succeeded=True, winning_bid=30)
        assert result is not None
        assert 0.0 <= result["accuracy"] <= 1.0
        assert result["actual"]["winner"] == "a1"
        assert result["actual"]["succeeded"] is True

    def test_wrong_prediction(self):
        predict_task_outcome("t-eval-wrong", "Task", ["a1"])
        result = evaluate_prediction("t-eval-wrong", "a2", task_succeeded=False)
        assert result is not None
        assert result["accuracy"] < 1.0

    def test_no_prediction_returns_none(self):
        result = evaluate_prediction("nonexistent", "a1", task_succeeded=True)
        assert result is None

    def test_double_evaluation_returns_none(self):
        predict_task_outcome("t-double", "Task", ["a1"])
        evaluate_prediction("t-double", "a1", task_succeeded=True)
        result = evaluate_prediction("t-double", "a1", task_succeeded=True)
        assert result is None

    def test_evaluation_updates_db(self):
        predict_task_outcome("t-upd", "Task", ["a1"])
        evaluate_prediction("t-upd", "a1", task_succeeded=True, winning_bid=50)
        preds = get_predictions(task_id="t-upd", evaluated_only=True)
        assert len(preds) == 1
        assert preds[0].accuracy is not None
        assert preds[0].actual_value is not None
        assert preds[0].evaluated_at is not None

    def test_winning_bid_none(self):
        predict_task_outcome("t-nobid", "Task", ["a1"])
        result = evaluate_prediction("t-nobid", "a1", task_succeeded=True)
        assert result is not None
        assert result["actual"]["winning_bid"] is None


# ── _compute_accuracy ─────────────────────────────────────────────────────


class TestComputeAccuracy:
    def test_perfect_match(self):
        predicted = {
            "likely_winner": "a1",
            "success_probability": 1.0,
            "estimated_bid_range": [20, 40],
        }
        actual = {"winner": "a1", "succeeded": True, "winning_bid": 30}
        assert _compute_accuracy(predicted, actual) == pytest.approx(1.0, abs=0.01)

    def test_all_wrong(self):
        predicted = {
            "likely_winner": "a1",
            "success_probability": 1.0,
            "estimated_bid_range": [10, 20],
        }
        actual = {"winner": "a2", "succeeded": False, "winning_bid": 100}
        assert _compute_accuracy(predicted, actual) < 0.3

    def test_no_winner_in_predicted(self):
        predicted = {"success_probability": 0.5, "estimated_bid_range": [20, 40]}
        actual = {"winner": "a1", "succeeded": True, "winning_bid": 30}
        acc = _compute_accuracy(predicted, actual)
        # Winner component skipped; success + bid still counted
        assert 0.0 <= acc <= 1.0

    def test_no_winner_in_actual(self):
        predicted = {"likely_winner": "a1", "success_probability": 0.5}
        actual = {"succeeded": True}
        acc = _compute_accuracy(predicted, actual)
        assert 0.0 <= acc <= 1.0

    def test_bid_outside_range_partial_credit(self):
        predicted = {
            "likely_winner": "a1",
            "success_probability": 1.0,
            "estimated_bid_range": [20, 40],
        }
        # Bid just outside the predicted range
        actual = {"winner": "a1", "succeeded": True, "winning_bid": 45}
        acc = _compute_accuracy(predicted, actual)
        assert 0.5 < acc < 1.0

    def test_bid_far_outside_range(self):
        predicted = {
            "likely_winner": "a1",
            "success_probability": 1.0,
            "estimated_bid_range": [20, 40],
        }
        actual = {"winner": "a1", "succeeded": True, "winning_bid": 500}
        acc = _compute_accuracy(predicted, actual)
        assert acc < 1.0

    def test_no_actual_bid(self):
        predicted = {
            "likely_winner": "a1",
            "success_probability": 0.7,
            "estimated_bid_range": [20, 40],
        }
        actual = {"winner": "a1", "succeeded": True, "winning_bid": None}
        acc = _compute_accuracy(predicted, actual)
        # Bid component skipped; only winner + success counted
        assert 0.0 <= acc <= 1.0

    def test_failed_prediction_low_probability(self):
        predicted = {"success_probability": 0.1}
        actual = {"succeeded": False}
        acc = _compute_accuracy(predicted, actual)
        # Predicted low success and the task failed, so accuracy is high
        assert acc > 0.8
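

# Taken together, these cases imply that _compute_accuracy averages up to
# three components, each skipped when its inputs are missing. A sketch
# consistent with the assertions above (inferred from the tests, not copied
# from spark.eidos; the partial-credit formula is hypothetical):
#
#     parts = []
#     if predicted.get("likely_winner") and actual.get("winner"):
#         parts.append(1.0 if predicted["likely_winner"] == actual["winner"] else 0.0)
#     if "succeeded" in actual:
#         target = 1.0 if actual["succeeded"] else 0.0
#         parts.append(1.0 - abs(predicted.get("success_probability", 0.5) - target))
#     rng, bid = predicted.get("estimated_bid_range"), actual.get("winning_bid")
#     if rng and bid is not None:
#         low, high = rng
#         miss = 0 if low <= bid <= high else min(abs(bid - low), abs(bid - high))
#         parts.append(max(0.0, 1.0 - miss / max(high - low, 1)))
#     accuracy = sum(parts) / len(parts) if parts else 0.0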


# ── get_predictions ───────────────────────────────────────────────────────


class TestGetPredictions:
    def test_empty_db(self):
        assert get_predictions() == []

    def test_filter_by_task_id(self):
        predict_task_outcome("t-filter1", "A", ["a1"])
        predict_task_outcome("t-filter2", "B", ["a2"])
        preds = get_predictions(task_id="t-filter1")
        assert len(preds) == 1
        assert preds[0].task_id == "t-filter1"

    def test_evaluated_only(self):
        predict_task_outcome("t-eo1", "A", ["a1"])
        predict_task_outcome("t-eo2", "B", ["a1"])
        evaluate_prediction("t-eo1", "a1", task_succeeded=True)
        preds = get_predictions(evaluated_only=True)
        assert len(preds) == 1
        assert preds[0].task_id == "t-eo1"

    def test_limit(self):
        for i in range(10):
            predict_task_outcome(f"t-lim{i}", "X", ["a1"])
        preds = get_predictions(limit=3)
        assert len(preds) == 3

    def test_combined_filters(self):
        predict_task_outcome("t-combo", "A", ["a1"])
        evaluate_prediction("t-combo", "a1", task_succeeded=True)
        predict_task_outcome("t-combo2", "B", ["a1"])
        preds = get_predictions(task_id="t-combo", evaluated_only=True)
        assert len(preds) == 1

    def test_order_by_created_desc(self):
        for i in range(3):
            predict_task_outcome(f"t-ord{i}", f"Task {i}", ["a1"])
        preds = get_predictions()
        # Most recent first
        assert preds[0].task_id == "t-ord2"


# ── get_accuracy_stats ────────────────────────────────────────────────────


class TestGetAccuracyStats:
    def test_empty(self):
        stats = get_accuracy_stats()
        assert stats["total_predictions"] == 0
        assert stats["evaluated"] == 0
        assert stats["pending"] == 0
        assert stats["avg_accuracy"] == 0.0
        assert stats["min_accuracy"] == 0.0
        assert stats["max_accuracy"] == 0.0

    def test_with_unevaluated(self):
        predict_task_outcome("t-uneval", "X", ["a1"])
        stats = get_accuracy_stats()
        assert stats["total_predictions"] == 1
        assert stats["evaluated"] == 0
        assert stats["pending"] == 1

    def test_with_evaluations(self):
        for i in range(3):
            predict_task_outcome(f"t-stats{i}", "X", ["a1"])
            evaluate_prediction(f"t-stats{i}", "a1", task_succeeded=True, winning_bid=30)
        stats = get_accuracy_stats()
        assert stats["total_predictions"] == 3
        assert stats["evaluated"] == 3
        assert stats["pending"] == 0
        assert stats["avg_accuracy"] > 0.0
        assert stats["min_accuracy"] <= stats["avg_accuracy"] <= stats["max_accuracy"]
tests/spark/test_memory.py (new file, 389 lines)
@@ -0,0 +1,389 @@
"""Comprehensive tests for spark.memory module.

Covers:
- SparkEvent / SparkMemory dataclasses
- _get_conn (schema creation, WAL, busy timeout, idempotent indexes)
- score_importance (all event types, boosts, edge cases)
- record_event (auto-importance, explicit importance, invalid JSON, swarm bridge)
- get_events (all filters, ordering, limit)
- count_events (total, by type)
- store_memory (with/without expiry)
- get_memories (all filters)
- count_memories (total, by type)
"""

import json

import pytest

from spark.memory import (
    IMPORTANCE_HIGH,
    IMPORTANCE_LOW,
    IMPORTANCE_MEDIUM,
    SparkEvent,
    SparkMemory,
    _get_conn,
    count_events,
    count_memories,
    get_events,
    get_memories,
    record_event,
    score_importance,
    store_memory,
)


# ── Constants ─────────────────────────────────────────────────────────────


class TestConstants:
    def test_importance_ordering(self):
        assert IMPORTANCE_LOW < IMPORTANCE_MEDIUM < IMPORTANCE_HIGH


# ── Dataclasses ───────────────────────────────────────────────────────────


class TestSparkEventDataclass:
    def test_all_fields(self):
        ev = SparkEvent(
            id="1",
            event_type="task_posted",
            agent_id="a1",
            task_id="t1",
            description="Test",
            data="{}",
            importance=0.5,
            created_at="2026-01-01",
        )
        assert ev.event_type == "task_posted"
        assert ev.agent_id == "a1"

    def test_nullable_fields(self):
        ev = SparkEvent(
            id="2",
            event_type="task_posted",
            agent_id=None,
            task_id=None,
            description="",
            data="{}",
            importance=0.5,
            created_at="2026-01-01",
        )
        assert ev.agent_id is None
        assert ev.task_id is None


class TestSparkMemoryDataclass:
    def test_all_fields(self):
        mem = SparkMemory(
            id="1",
            memory_type="pattern",
            subject="system",
            content="Test insight",
            confidence=0.8,
            source_events=5,
            created_at="2026-01-01",
            expires_at="2026-12-31",
        )
        assert mem.memory_type == "pattern"
        assert mem.expires_at == "2026-12-31"

    def test_nullable_expires(self):
        mem = SparkMemory(
            id="2",
            memory_type="anomaly",
            subject="agent-1",
            content="Odd behavior",
            confidence=0.6,
            source_events=3,
            created_at="2026-01-01",
            expires_at=None,
        )
        assert mem.expires_at is None


# ── _get_conn ─────────────────────────────────────────────────────────────


class TestGetConn:
    def test_creates_tables(self):
        with _get_conn() as conn:
            tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
            names = {r["name"] for r in tables}
            assert "spark_events" in names
            assert "spark_memories" in names

    def test_wal_mode(self):
        with _get_conn() as conn:
            mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
            assert mode == "wal"

    def test_busy_timeout(self):
        with _get_conn() as conn:
            timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
            assert timeout == 5000

    def test_idempotent(self):
        # Calling _get_conn twice should not raise
        with _get_conn():
            pass
        with _get_conn():
            pass
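

# The three checks above imply _get_conn issues, at minimum, something like
# the following (a sketch; the exact statements live in spark.memory):
#
#     conn = sqlite3.connect(db_path)
#     conn.row_factory = sqlite3.Row            # enables r["name"] access above
#     conn.execute("PRAGMA journal_mode=WAL")
#     conn.execute("PRAGMA busy_timeout=5000")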


# ── score_importance ──────────────────────────────────────────────────────


class TestScoreImportance:
    @pytest.mark.parametrize(
        "event_type,expected_min,expected_max",
        [
            ("task_posted", 0.3, 0.5),
            ("bid_submitted", 0.1, 0.3),
            ("task_assigned", 0.4, 0.6),
            ("task_completed", 0.5, 0.7),
            ("task_failed", 0.9, 1.0),
            ("agent_joined", 0.4, 0.6),
            ("prediction_result", 0.6, 0.8),
        ],
    )
    def test_base_scores(self, event_type, expected_min, expected_max):
        score = score_importance(event_type, {})
        assert expected_min <= score <= expected_max

    def test_unknown_event_default(self):
        assert score_importance("never_heard_of_this", {}) == 0.5

    def test_failure_boost(self):
        score = score_importance("task_failed", {})
        assert score == 1.0

    def test_high_bid_boost(self):
        low = score_importance("bid_submitted", {"bid_sats": 10})
        high = score_importance("bid_submitted", {"bid_sats": 100})
        assert high > low
        assert high <= 1.0

    def test_high_bid_on_failure(self):
        score = score_importance("task_failed", {"bid_sats": 100})
        assert score == 1.0  # capped at 1.0

    def test_score_always_rounded(self):
        score = score_importance("bid_submitted", {"bid_sats": 100})
        assert score == round(score, 2)
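

# A sketch of the scoring behaviour these assertions pin down (inferred from
# the tests, not copied from spark.memory; the threshold and boost size are
# hypothetical): a base score per event type, a boost for large bids, a 1.0
# cap, and rounding to two decimals:
#
#     score = _BASE.get(event_type, 0.5)    # task_failed base is 1.0
#     if data.get("bid_sats", 0) >= 100:    # threshold is hypothetical
#         score += 0.2                      # boost size is hypothetical
#     return round(min(score, 1.0), 2)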


# ── record_event ──────────────────────────────────────────────────────────


class TestRecordEvent:
    def test_basic_record(self):
        eid = record_event("task_posted", "New task", task_id="t1")
        assert isinstance(eid, str)
        assert len(eid) > 0

    def test_auto_importance(self):
        record_event("task_failed", "Failed", task_id="t-auto")
        events = get_events(task_id="t-auto")
        assert events[0].importance >= 0.9

    def test_explicit_importance(self):
        record_event("task_posted", "Custom", task_id="t-expl", importance=0.1)
        events = get_events(task_id="t-expl")
        assert events[0].importance == 0.1

    def test_with_agent_and_data(self):
        data = json.dumps({"bid_sats": 42})
        record_event("bid_submitted", "Bid", agent_id="a1", task_id="t-data", data=data)
        events = get_events(task_id="t-data")
        assert events[0].agent_id == "a1"
        parsed = json.loads(events[0].data)
        assert parsed["bid_sats"] == 42

    def test_invalid_json_data_uses_default_importance(self):
        record_event("task_posted", "Bad data", task_id="t-bad", data="not-json")
        events = get_events(task_id="t-bad")
        assert events[0].importance == 0.4  # base score for task_posted

    def test_returns_unique_ids(self):
        id1 = record_event("task_posted", "A")
        id2 = record_event("task_posted", "B")
        assert id1 != id2


# ── get_events ────────────────────────────────────────────────────────────


class TestGetEvents:
    def test_empty_db(self):
        assert get_events() == []

    def test_filter_by_type(self):
        record_event("task_posted", "A")
        record_event("task_completed", "B")
        events = get_events(event_type="task_posted")
        assert len(events) == 1
        assert events[0].event_type == "task_posted"

    def test_filter_by_agent(self):
        record_event("task_posted", "A", agent_id="a1")
        record_event("task_posted", "B", agent_id="a2")
        events = get_events(agent_id="a1")
        assert len(events) == 1
        assert events[0].agent_id == "a1"

    def test_filter_by_task(self):
        record_event("task_posted", "A", task_id="t1")
        record_event("task_posted", "B", task_id="t2")
        events = get_events(task_id="t1")
        assert len(events) == 1

    def test_filter_by_min_importance(self):
        record_event("task_posted", "Low", importance=0.1)
        record_event("task_failed", "High", importance=0.9)
        events = get_events(min_importance=0.5)
        assert len(events) == 1
        assert events[0].importance >= 0.5

    def test_limit(self):
        for i in range(10):
            record_event("task_posted", f"ev{i}")
        events = get_events(limit=3)
        assert len(events) == 3

    def test_order_by_created_desc(self):
        record_event("task_posted", "first", task_id="ord1")
        record_event("task_posted", "second", task_id="ord2")
        events = get_events()
        # Most recent first
        assert events[0].task_id == "ord2"

    def test_combined_filters(self):
        record_event("task_failed", "A", agent_id="a1", task_id="t1", importance=0.9)
        record_event("task_posted", "B", agent_id="a1", task_id="t2", importance=0.4)
        record_event("task_failed", "C", agent_id="a2", task_id="t3", importance=0.9)
        events = get_events(event_type="task_failed", agent_id="a1", min_importance=0.5)
        assert len(events) == 1
        assert events[0].task_id == "t1"


# ── count_events ──────────────────────────────────────────────────────────


class TestCountEvents:
    def test_empty(self):
        assert count_events() == 0

    def test_total(self):
        record_event("task_posted", "A")
        record_event("task_failed", "B")
        assert count_events() == 2

    def test_by_type(self):
        record_event("task_posted", "A")
        record_event("task_posted", "B")
        record_event("task_failed", "C")
        assert count_events("task_posted") == 2
        assert count_events("task_failed") == 1
        assert count_events("task_completed") == 0


# ── store_memory ──────────────────────────────────────────────────────────


class TestStoreMemory:
    def test_basic_store(self):
        mid = store_memory("pattern", "system", "Test insight")
        assert isinstance(mid, str)
        assert len(mid) > 0

    def test_returns_unique_ids(self):
        id1 = store_memory("pattern", "a", "X")
        id2 = store_memory("pattern", "b", "Y")
        assert id1 != id2

    def test_with_all_params(self):
        store_memory(
            "anomaly",
            "agent-1",
            "Odd pattern",
            confidence=0.9,
            source_events=10,
            expires_at="2026-12-31",
        )
        mems = get_memories(subject="agent-1")
        assert len(mems) == 1
        assert mems[0].confidence == 0.9
        assert mems[0].source_events == 10
        assert mems[0].expires_at == "2026-12-31"

    def test_default_values(self):
        store_memory("insight", "sys", "Default test")
        mems = get_memories(subject="sys")
        assert mems[0].confidence == 0.5
        assert mems[0].source_events == 0
        assert mems[0].expires_at is None


# ── get_memories ──────────────────────────────────────────────────────────


class TestGetMemories:
    def test_empty(self):
        assert get_memories() == []

    def test_filter_by_type(self):
        store_memory("pattern", "a", "X")
        store_memory("anomaly", "a", "Y")
        mems = get_memories(memory_type="pattern")
        assert len(mems) == 1
        assert mems[0].memory_type == "pattern"

    def test_filter_by_subject(self):
        store_memory("pattern", "a", "X")
        store_memory("pattern", "b", "Y")
        mems = get_memories(subject="a")
        assert len(mems) == 1

    def test_filter_by_min_confidence(self):
        store_memory("pattern", "a", "Low", confidence=0.2)
        store_memory("pattern", "b", "High", confidence=0.9)
        mems = get_memories(min_confidence=0.5)
        assert len(mems) == 1
        assert mems[0].content == "High"

    def test_limit(self):
        for i in range(10):
            store_memory("pattern", "a", f"M{i}")
        mems = get_memories(limit=3)
        assert len(mems) == 3

    def test_combined_filters(self):
        store_memory("pattern", "a", "Target", confidence=0.9)
        store_memory("anomaly", "a", "Wrong type", confidence=0.9)
        store_memory("pattern", "b", "Wrong subject", confidence=0.9)
        store_memory("pattern", "a", "Low conf", confidence=0.1)
        mems = get_memories(memory_type="pattern", subject="a", min_confidence=0.5)
        assert len(mems) == 1
        assert mems[0].content == "Target"


# ── count_memories ────────────────────────────────────────────────────────


class TestCountMemories:
    def test_empty(self):
        assert count_memories() == 0

    def test_total(self):
        store_memory("pattern", "a", "X")
        store_memory("anomaly", "b", "Y")
        assert count_memories() == 2

    def test_by_type(self):
        store_memory("pattern", "a", "X")
        store_memory("pattern", "b", "Y")
        store_memory("anomaly", "c", "Z")
        assert count_memories("pattern") == 2
        assert count_memories("anomaly") == 1
        assert count_memories("insight") == 0