"""Unit tests for the AutoLoRA continuous improvement loop.

Covers trajectory extraction, quality filtering, dataset management,
and the retrain orchestrator.

Refs: #1105
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import tempfile
|
|
from datetime import UTC, datetime, timedelta
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality
|
|
from timmy_automations.retrain.retrain import RetrainOrchestrator
|
|
from timmy_automations.retrain.training_dataset import TrainingDataset
|
|
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
|
|
from timmy_automations.retrain.trajectory_exporter import Trajectory, TrajectoryExporter
|
|
|
|
|
|
# ── Fixtures ─────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _ts(offset_minutes: int = 0) -> str:
|
|
"""Return an ISO timestamp offset from now."""
|
|
return (datetime.now(tz=UTC) + timedelta(minutes=offset_minutes)).isoformat()
|
|
|
|
|
|
def _make_session_log(entries: list[dict], date_str: str, tmp_path: Path) -> Path:
|
|
"""Write session JSONL entries to a temp log file."""
|
|
log_dir = tmp_path / "logs"
|
|
log_dir.mkdir(parents=True, exist_ok=True)
|
|
log_file = log_dir / f"session_{date_str}.jsonl"
|
|
with open(log_file, "w") as f:
|
|
for entry in entries:
|
|
f.write(json.dumps(entry) + "\n")
|
|
return log_file
|
|
|
|
|
|
def _user_msg(content: str, offset: int = 0) -> dict:
    """Build a user-role message entry stamped *offset* minutes from now."""
    return {
        "type": "message",
        "role": "user",
        "content": content,
        "timestamp": _ts(offset),
    }
|
|
|
|
|
|
def _timmy_msg(content: str, confidence: float | None = None, offset: int = 0) -> dict:
    """Build a timmy-role message entry, optionally carrying a confidence score."""
    msg = {
        "type": "message",
        "role": "timmy",
        "content": content,
        "timestamp": _ts(offset),
    }
    # Only attach the key when a confidence was actually supplied.
    return msg if confidence is None else {**msg, "confidence": confidence}
|
|
|
|
|
|
def _tool_call(tool: str = "bash", result: str = "ok", offset: int = 0) -> dict:
    """Build a tool_call entry with empty args and the given result string."""
    entry = {"type": "tool_call", "tool": tool, "args": {}}
    entry["result"] = result
    entry["timestamp"] = _ts(offset)
    return entry
|
|
|
|
|
|
def _error_entry(msg: str = "Something failed", offset: int = 0) -> dict:
    """Build an error entry stamped *offset* minutes from now."""
    return {
        "type": "error",
        "error": msg,
        "timestamp": _ts(offset),
    }
|
|
|
|
|
|
def _decision_entry(decision: str = "Use approach A", offset: int = 0) -> dict:
    """Build a decision entry stamped *offset* minutes from now."""
    return {
        "type": "decision",
        "decision": decision,
        "timestamp": _ts(offset),
    }
|
|
|
|
|
|
# ── Trajectory dataclass tests ────────────────────────────────────────────────
|
|
|
|
|
|
class TestTrajectory:
    """Behavioral tests for the Trajectory dataclass properties."""

    @staticmethod
    def _traj(**fields) -> Trajectory:
        """Build a Trajectory for 2026-03-17 with the given field overrides."""
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            **fields,
        )

    def test_message_count(self):
        t = self._traj(messages=[_user_msg("hi"), _timmy_msg("hello")])
        assert t.message_count == 2

    def test_tool_call_count(self):
        t = self._traj(tool_calls=[_tool_call(), _tool_call()])
        assert t.tool_call_count == 2

    def test_has_successful_tool_call_when_no_errors(self):
        t = self._traj(tool_calls=[_tool_call()], errors=[])
        assert t.has_successful_tool_call is True

    def test_has_successful_tool_call_false_when_errors(self):
        t = self._traj(tool_calls=[_tool_call()], errors=[_error_entry()])
        assert t.has_successful_tool_call is False

    def test_is_multi_step(self):
        t = self._traj(
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        assert t.is_multi_step is True

    def test_is_not_multi_step_single_message(self):
        t = self._traj(messages=[_timmy_msg("hello")], tool_calls=[])
        assert t.is_multi_step is False

    def test_to_chat_format_ordering(self):
        t = self._traj(
            messages=[_user_msg("question", offset=0), _timmy_msg("answer", offset=2)],
            tool_calls=[_tool_call(offset=1)],
        )
        roles = [m["role"] for m in t.to_chat_format()]
        assert "user" in roles
        assert "assistant" in roles

    def test_to_chat_format_empty_content_skipped(self):
        t = self._traj(messages=[_user_msg(""), _timmy_msg("response")])
        # The empty user message must not survive conversion to chat format.
        assert all(m["content"] for m in t.to_chat_format())
|
|
|
|
|
|
# ── TrajectoryExporter tests ──────────────────────────────────────────────────
|
|
|
|
|
|
class TestTrajectoryExporter:
    """Tests for exporting weekly trajectories from session log files."""

    @staticmethod
    def _exporter(tmp_path) -> TrajectoryExporter:
        """Exporter rooted at the temp dir, reading from its ``logs`` subdir."""
        return TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)

    def test_export_empty_logs_dir(self, tmp_path):
        (tmp_path / "logs").mkdir()
        assert self._exporter(tmp_path).export_week(weeks_ago=0) == []

    def test_export_reads_session_files(self, tmp_path):
        # A session logged today belongs to the current week.
        date_str = datetime.now(tz=UTC).strftime("%Y-%m-%d")
        _make_session_log(
            [_user_msg("tell me about Python"), _timmy_msg("Python is great")],
            date_str,
            tmp_path,
        )
        assert len(self._exporter(tmp_path).export_week(weeks_ago=0)) >= 1

    def test_export_skips_old_sessions(self, tmp_path):
        # Data from 3 weeks back must not leak into the current week's export.
        stale = datetime.now(tz=UTC) - timedelta(weeks=3)
        _make_session_log(
            [_user_msg("old message"), _timmy_msg("old response")],
            stale.strftime("%Y-%m-%d"),
            tmp_path,
        )
        assert self._exporter(tmp_path).export_week(weeks_ago=0) == []

    def test_export_segments_by_gap(self, tmp_path):
        now = datetime.now(tz=UTC)
        date_str = now.strftime("%Y-%m-%d")

        def stamp(minutes_back: int) -> str:
            return (now - timedelta(minutes=minutes_back)).isoformat()

        # Two exchanges separated by a 10-minute silence.
        entries = [
            {"type": "message", "role": "user", "content": "first q", "timestamp": stamp(15)},
            {"type": "message", "role": "timmy", "content": "first a", "timestamp": stamp(14)},
            {"type": "message", "role": "user", "content": "second q", "timestamp": stamp(2)},
            {"type": "message", "role": "timmy", "content": "second a", "timestamp": stamp(1)},
        ]
        _make_session_log(entries, date_str, tmp_path)

        result = self._exporter(tmp_path).export_week(weeks_ago=0)
        # At least one trajectory; the gap may or may not split it into two.
        assert len(result) >= 1

    def test_handles_malformed_log_file(self, tmp_path):
        log_dir = tmp_path / "logs"
        log_dir.mkdir()
        today = datetime.now(tz=UTC).strftime("%Y-%m-%d")
        (log_dir / f"session_{today}.jsonl").write_text("not json\n{}\n")

        exporter = TrajectoryExporter(logs_dir=log_dir, repo_root=tmp_path)
        # Malformed lines must be tolerated, never raised.
        assert isinstance(exporter.export_week(weeks_ago=0), list)
|
|
|
|
|
|
# ── QualityFilter tests ───────────────────────────────────────────────────────
|
|
|
|
|
|
class TestQualityFilter:
    """Tests for trajectory quality classification and the trainable filter."""

    def _make_high_quality(self) -> Trajectory:
        """Trajectory with tool calls, a decision, high confidence, no errors."""
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("do task"), _timmy_msg("done", confidence=0.9)],
            tool_calls=[_tool_call(), _tool_call()],
            errors=[],
            decisions=[_decision_entry()],
        )

    def _make_medium_quality(self) -> Trajectory:
        """Plain chat exchange: no tools, no errors, no decisions."""
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("hello"), _timmy_msg("hi")],
            tool_calls=[],
            errors=[],
        )

    def _make_low_quality(self) -> Trajectory:
        """Trajectory with an error and no user message at all."""
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("oops")],  # No user message
            errors=[_error_entry()],
        )

    def test_high_quality_classification(self):
        result = QualityFilter().assess(self._make_high_quality())
        assert result.quality == TrajectoryQuality.HIGH
        assert result.score >= 4.0
        assert result.is_trainable

    def test_medium_quality_classification(self):
        result = QualityFilter().assess(self._make_medium_quality())
        assert result.quality == TrajectoryQuality.MEDIUM
        assert result.is_trainable

    def test_low_quality_no_user_message(self):
        timmy_only = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("random")],
        )
        result = QualityFilter().assess(timmy_only)
        assert result.quality == TrajectoryQuality.LOW
        assert not result.is_trainable

    def test_error_penalizes_score(self):
        qf = QualityFilter()
        erroring = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("go"), _timmy_msg("fail")],
            tool_calls=[_tool_call()],
            errors=[_error_entry(), _error_entry()],
        )
        # Two errors must score strictly below the clean high-quality case.
        assert qf.assess(erroring).score < qf.assess(self._make_high_quality()).score

    def test_low_confidence_penalizes_score(self):
        shaky = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("q"), _timmy_msg("a", confidence=0.2)],
        )
        assert QualityFilter().assess(shaky).score < 1.0

    def test_filter_returns_stats(self):
        trajectories = [
            self._make_high_quality(),
            self._make_medium_quality(),
            self._make_low_quality(),
        ]
        trainable, stats = QualityFilter().filter(trajectories)
        assert stats["total"] == 3
        assert stats["accepted"] == len(trainable)
        assert stats["high"] + stats["medium"] + stats["low"] == 3

    def test_filter_empty_list(self):
        trainable, stats = QualityFilter().filter([])
        assert trainable == []
        assert stats["total"] == 0
        assert stats["accepted"] == 0
|
|
|
|
|
|
# ── TrainingDataset tests ─────────────────────────────────────────────────────
|
|
|
|
|
|
class TestTrainingDataset:
    """Tests for the deduplicating JSONL training dataset and its index."""

    def _make_result(self, quality=TrajectoryQuality.HIGH, score=5.0):
        """Return a QualityResult wrapping a small two-message trajectory.

        Imported locally to keep QualityResult off the module import list.
        """
        from timmy_automations.retrain.quality_filter import QualityResult

        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(-5),
            ended_at=_ts(),
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        return QualityResult(trajectory=t, quality=quality, score=score, reasons=[])

    def test_count_empty_dataset(self, tmp_path):
        ds = TrainingDataset(
            dataset_path=".loop/retrain/training_data.jsonl",
            repo_root=tmp_path,
        )
        assert ds.count() == 0

    def test_append_adds_examples(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        result = ds.append([self._make_result()], "2026-W12")
        assert result.new_examples == 1
        assert result.total_examples == 1
        assert ds.count() == 1

    def test_append_idempotent(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        r = self._make_result()
        ds.append([r], "2026-W12")
        result2 = ds.append([r], "2026-W12")
        # Same trajectory shouldn't be added twice
        assert result2.new_examples == 0
        assert ds.count() == 1

    def test_append_different_weeks(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        r1 = self._make_result()
        ds.append([r1], "2026-W11")
        ds.append([r1], "2026-W12")
        # Different week tags = different records
        assert ds.count() == 2

    def test_dataset_file_is_valid_jsonl(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        ds.append([self._make_result()], "2026-W12")
        # Explicit UTF-8 read keeps the check platform-independent; `line`
        # replaces the ambiguous single-letter `l` loop variable (E741).
        with open(ds.dataset_path, encoding="utf-8") as f:
            lines = [line.strip() for line in f if line.strip()]
        assert len(lines) == 1
        record = json.loads(lines[0])
        assert "messages" in record
        assert "week" in record
        assert "quality" in record

    def test_index_updated_after_append(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        ds.append([self._make_result()], "2026-W12")
        index_path = tmp_path / ".loop" / "retrain" / "dataset_index.json"
        assert index_path.exists()
        index = json.loads(index_path.read_text())
        assert index["total_examples"] == 1
        assert "2026-W12" in index["weeks"]
|
|
|
|
|
|
# ── TrainingLog tests ─────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestTrainingLog:
    """Tests for cycle-metric recording, iteration numbering, and the markdown summary."""

    def _make_metrics(self, iteration: int = 1) -> CycleMetrics:
        """Fully-populated metrics for a completed fine-tune cycle."""
        return CycleMetrics(
            iteration=iteration,
            week="2026-W12",
            ran_at=datetime.now(tz=UTC).isoformat(),
            trajectories_total=10,
            trajectories_high=5,
            trajectories_medium=3,
            trajectories_low=2,
            trajectories_accepted=8,
            examples_added=5,
            dataset_total=5,
            train_status="completed",
            train_loss=1.2345,
            train_duration_seconds=120.5,
            adapter_path=".loop/retrain/adapters/iter_0001/adapters.npz",
            model_name="hermes4-14b-ft-0001",
            notes="First fine-tune cycle complete",
        )

    def test_next_iteration_starts_at_1(self, tmp_path):
        assert TrainingLog(repo_root=tmp_path).next_iteration() == 1

    def test_next_iteration_increments(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        assert log.next_iteration() == 2

    def test_record_creates_log_file(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics())
        assert log.log_path.exists()

    def test_load_all_returns_records(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        for i in (1, 2):
            log.record(self._make_metrics(iteration=i))
        entries = log.load_all()
        assert len(entries) == 2
        assert entries[0]["iteration"] == 1

    def test_latest_returns_last_entry(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        for i in (1, 2):
            log.record(self._make_metrics(iteration=i))
        latest = log.latest()
        assert latest is not None
        assert latest["iteration"] == 2

    def test_latest_returns_none_when_empty(self, tmp_path):
        assert TrainingLog(repo_root=tmp_path).latest() is None

    def test_summary_markdown_written(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics())
        summary_path = tmp_path / ".loop" / "retrain" / "training_log.md"
        assert summary_path.exists()
        content = summary_path.read_text()
        for expected in ("AutoLoRA Training Log", "2026-W12", "completed"):
            assert expected in content

    def test_skill_accuracy_in_summary(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        metrics = self._make_metrics()
        metrics.skill_accuracy = {"tool_calling": 0.85, "reasoning": 0.72}
        log.record(metrics)
        content = (tmp_path / ".loop" / "retrain" / "training_log.md").read_text()
        assert "tool_calling" in content
        assert "reasoning" in content
|
|
|
|
|
|
# ── RetrainOrchestrator integration tests ─────────────────────────────────────
|
|
|
|
|
|
class TestRetrainOrchestrator:
    """End-to-end dry-run tests plus automation-manifest checks."""

    @staticmethod
    def _dry_orchestrator(tmp_path) -> RetrainOrchestrator:
        """Orchestrator in dry-run mode with a guaranteed ``logs`` directory."""
        (tmp_path / "logs").mkdir(parents=True, exist_ok=True)
        return RetrainOrchestrator(repo_root=tmp_path, dry_run=True)

    def test_run_dry_run_no_data(self, tmp_path):
        """Dry run with no session logs should complete without errors."""
        result = self._dry_orchestrator(tmp_path).run(weeks_ago=0)
        assert result.train_status == "skipped"
        assert result.examples_added == 0
        assert result.iteration == 1

    def test_run_creates_log_entry(self, tmp_path):
        self._dry_orchestrator(tmp_path).run(weeks_ago=0)
        assert len(TrainingLog(repo_root=tmp_path).load_all()) == 1

    def test_run_with_session_data(self, tmp_path):
        """Run with actual session data — should export, filter, and log."""
        date_str = datetime.now(tz=UTC).strftime("%Y-%m-%d")
        entries = [
            _user_msg("deploy the service", offset=-10),
            _tool_call("bash", "deployed successfully", offset=-9),
            _tool_call("bash", "health check ok", offset=-8),
            _timmy_msg("Service deployed and healthy", confidence=0.92, offset=-7),
            _user_msg("run the tests", offset=-6),
            _tool_call("bash", "All tests passed", offset=-5),
            _timmy_msg("All 42 tests passed", confidence=0.95, offset=-4),
        ]
        _make_session_log(entries, date_str, tmp_path)

        result = self._dry_orchestrator(tmp_path).run(weeks_ago=0)

        assert result.trajectories_exported >= 1
        assert result.iteration == 1
        # In dry_run mode, fine-tune is skipped but trajectories should be processed
        assert result.train_status == "skipped"

    def test_iteration_increments_on_second_run(self, tmp_path):
        orc = self._dry_orchestrator(tmp_path)
        first = orc.run(weeks_ago=0)
        second = orc.run(weeks_ago=0)
        assert second.iteration == first.iteration + 1

    def test_automations_json_has_retrain_entry(self):
        """Verify the retrain automation is registered in automations.json."""
        config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json"
        assert config_path.exists()
        manifest = json.loads(config_path.read_text())
        assert "retrain" in [a["id"] for a in manifest.get("automations", [])]

    def test_retrain_automation_config(self):
        """Verify retrain automation has correct schedule and config."""
        config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json"
        manifest = json.loads(config_path.read_text())
        retrain = next(a for a in manifest["automations"] if a["id"] == "retrain")
        assert retrain["trigger"] == "scheduled"
        assert retrain["schedule"] == "weekly_sunday"
        assert retrain["config"]["base_model"] == "hermes4-14b"
        assert retrain["config"]["weeks_ago"] == 1
|
|
|
|
|
|
# Repository root, derived from this test file's location (three parents up).
# NOTE(review): defined at the bottom of the module although it is referenced
# inside TestRetrainOrchestrator above — safe because module-level names are
# resolved at call time, but consider moving it next to the imports.
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|