1
0
This repository has been archived on 2026-03-24. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
Timmy-time-dashboard/tests/unit/test_retrain_loop.py

551 lines
20 KiB
Python

"""Unit tests for the AutoLoRA continuous improvement loop.
Covers trajectory extraction, quality filtering, dataset management,
and the retrain orchestrator.
Refs: #1105
"""
from __future__ import annotations
import json
import tempfile
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality
from timmy_automations.retrain.retrain import RetrainOrchestrator
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
from timmy_automations.retrain.trajectory_exporter import Trajectory, TrajectoryExporter
# ── Fixtures ─────────────────────────────────────────────────────────────────
def _ts(offset_minutes: int = 0) -> str:
"""Return an ISO timestamp offset from now."""
return (datetime.now(tz=UTC) + timedelta(minutes=offset_minutes)).isoformat()
def _make_session_log(entries: list[dict], date_str: str, tmp_path: Path) -> Path:
    """Write *entries* as JSON Lines to ``<tmp_path>/logs/session_<date_str>.jsonl``.

    The ``logs`` directory is created when it does not exist yet. Each entry
    becomes one JSON-encoded line. Returns the path of the written file.
    """
    target = tmp_path / "logs" / f"session_{date_str}.jsonl"
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w") as handle:
        handle.writelines(json.dumps(item) + "\n" for item in entries)
    return target
def _user_msg(content: str, offset: int = 0) -> dict:
    """Build a ``message`` log entry with role ``user``.

    ``offset`` is forwarded to ``_ts`` as the timestamp offset in minutes.
    """
    return dict(type="message", role="user", content=content, timestamp=_ts(offset))
def _timmy_msg(content: str, confidence: float | None = None, offset: int = 0) -> dict:
    """Build a ``message`` log entry with role ``timmy``.

    A ``confidence`` key is added only when a value other than ``None`` is
    given. ``offset`` is forwarded to ``_ts`` as the timestamp offset in minutes.
    """
    reply = dict(type="message", role="timmy", content=content, timestamp=_ts(offset))
    if confidence is None:
        return reply
    reply["confidence"] = confidence
    return reply
def _tool_call(tool: str = "bash", result: str = "ok", offset: int = 0) -> dict:
    """Build a ``tool_call`` log entry with empty ``args``.

    ``offset`` is forwarded to ``_ts`` as the timestamp offset in minutes.
    """
    record = {"type": "tool_call", "tool": tool, "args": {}, "result": result}
    record["timestamp"] = _ts(offset)
    return record
def _error_entry(msg: str = "Something failed", offset: int = 0) -> dict:
    """Build an ``error`` log entry carrying *msg* under the ``error`` key."""
    return dict(type="error", error=msg, timestamp=_ts(offset))
def _decision_entry(decision: str = "Use approach A", offset: int = 0) -> dict:
    """Build a ``decision`` log entry carrying *decision* under the ``decision`` key."""
    return dict(type="decision", decision=decision, timestamp=_ts(offset))
# ── Trajectory dataclass tests ────────────────────────────────────────────────
class TestTrajectory:
    """Exercise the derived properties and chat export of ``Trajectory``."""

    @staticmethod
    def _session_fields() -> dict:
        """Return the session date and start/end timestamps every test passes in."""
        return {
            "session_date": "2026-03-17",
            "started_at": _ts(),
            "ended_at": _ts(),
        }

    def test_message_count(self):
        """Two messages give a ``message_count`` of 2."""
        traj = Trajectory(
            **self._session_fields(),
            messages=[_user_msg("hi"), _timmy_msg("hello")],
        )
        assert traj.message_count == 2

    def test_tool_call_count(self):
        """Two tool calls give a ``tool_call_count`` of 2."""
        traj = Trajectory(
            **self._session_fields(),
            tool_calls=[_tool_call(), _tool_call()],
        )
        assert traj.tool_call_count == 2

    def test_has_successful_tool_call_when_no_errors(self):
        """One tool call with an empty error list counts as successful."""
        traj = Trajectory(
            **self._session_fields(),
            tool_calls=[_tool_call()],
            errors=[],
        )
        assert traj.has_successful_tool_call is True

    def test_has_successful_tool_call_false_when_errors(self):
        """A recorded error makes ``has_successful_tool_call`` False."""
        traj = Trajectory(
            **self._session_fields(),
            tool_calls=[_tool_call()],
            errors=[_error_entry()],
        )
        assert traj.has_successful_tool_call is False

    def test_is_multi_step(self):
        """A user/timmy exchange plus a tool call is multi-step."""
        traj = Trajectory(
            **self._session_fields(),
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        assert traj.is_multi_step is True

    def test_is_not_multi_step_single_message(self):
        """A single message without tool calls is not multi-step."""
        traj = Trajectory(
            **self._session_fields(),
            messages=[_timmy_msg("hello")],
            tool_calls=[],
        )
        assert traj.is_multi_step is False

    def test_to_chat_format_ordering(self):
        """The chat export contains both a ``user`` and an ``assistant`` role."""
        traj = Trajectory(
            **self._session_fields(),
            messages=[_user_msg("question", offset=0), _timmy_msg("answer", offset=2)],
            tool_calls=[_tool_call(offset=1)],
        )
        seen_roles = {turn["role"] for turn in traj.to_chat_format()}
        # NOTE(review): despite the test name, only the presence of both roles
        # is asserted here, not their relative order.
        assert "user" in seen_roles
        assert "assistant" in seen_roles

    def test_to_chat_format_empty_content_skipped(self):
        """No exported chat turn has empty content."""
        traj = Trajectory(
            **self._session_fields(),
            messages=[_user_msg(""), _timmy_msg("response")],
        )
        # The empty user message must not appear in the export.
        blank_turns = [turn for turn in traj.to_chat_format() if not turn["content"]]
        assert not blank_turns
# ── TrajectoryExporter tests ──────────────────────────────────────────────────
class TestTrajectoryExporter:
    """Exercise ``TrajectoryExporter.export_week`` against logs under ``<tmp_path>/logs``."""

    @staticmethod
    def _export_current_week(root: Path):
        """Run ``export_week(weeks_ago=0)`` for an exporter rooted at *root*."""
        exporter = TrajectoryExporter(logs_dir=root / "logs", repo_root=root)
        return exporter.export_week(weeks_ago=0)

    def test_export_empty_logs_dir(self, tmp_path):
        """An empty logs directory exports an empty list."""
        (tmp_path / "logs").mkdir()
        assert self._export_current_week(tmp_path) == []

    def test_export_reads_session_files(self, tmp_path):
        """A session file dated today yields at least one trajectory."""
        stamp = datetime.now(tz=UTC).strftime("%Y-%m-%d")
        conversation = [
            _user_msg("tell me about Python"),
            _timmy_msg("Python is great"),
        ]
        _make_session_log(conversation, stamp, tmp_path)
        exported = self._export_current_week(tmp_path)
        assert len(exported) >= 1

    def test_export_skips_old_sessions(self, tmp_path):
        """A session file dated three weeks back is not part of the current week."""
        stale_day = datetime.now(tz=UTC) - timedelta(weeks=3)
        stamp = stale_day.strftime("%Y-%m-%d")
        conversation = [_user_msg("old message"), _timmy_msg("old response")]
        _make_session_log(conversation, stamp, tmp_path)
        # Only the current week is requested, so the old file must be ignored.
        assert self._export_current_week(tmp_path) == []

    def test_export_segments_by_gap(self, tmp_path):
        """Two exchanges separated by a time gap export at least one trajectory."""
        now = datetime.now(tz=UTC)
        stamp = now.strftime("%Y-%m-%d")
        # (role, content, minutes before now): the two exchanges are 12 minutes apart.
        turns = [
            ("user", "first q", 15),
            ("timmy", "first a", 14),
            ("user", "second q", 2),
            ("timmy", "second a", 1),
        ]
        conversation = [
            {
                "type": "message",
                "role": role,
                "content": text,
                "timestamp": (now - timedelta(minutes=age)).isoformat(),
            }
            for role, text, age in turns
        ]
        _make_session_log(conversation, stamp, tmp_path)
        exported = self._export_current_week(tmp_path)
        # Either 1 or 2 trajectories is accepted, depending on segmentation.
        assert len(exported) >= 1

    def test_handles_malformed_log_file(self, tmp_path):
        """A log with a non-JSON line and an empty object still returns a list."""
        logs = tmp_path / "logs"
        logs.mkdir()
        stamp = datetime.now(tz=UTC).strftime("%Y-%m-%d")
        broken_log = logs / f"session_{stamp}.jsonl"
        broken_log.write_text("not json\n{}\n")
        # Must not raise; empty or partial results are both acceptable.
        exported = self._export_current_week(tmp_path)
        assert isinstance(exported, list)
# ── QualityFilter tests ───────────────────────────────────────────────────────
class TestQualityFilter:
    """Exercise ``QualityFilter.assess`` and ``QualityFilter.filter``."""

    @staticmethod
    def _session_fields() -> dict:
        """Return the session date and start/end timestamps every trajectory uses."""
        return {
            "session_date": "2026-03-17",
            "started_at": _ts(),
            "ended_at": _ts(),
        }

    def _make_high_quality(self) -> Trajectory:
        """Trajectory with a confident reply, two tool calls, a decision and no errors."""
        return Trajectory(
            **self._session_fields(),
            messages=[_user_msg("do task"), _timmy_msg("done", confidence=0.9)],
            tool_calls=[_tool_call(), _tool_call()],
            errors=[],
            decisions=[_decision_entry()],
        )

    def _make_medium_quality(self) -> Trajectory:
        """Trajectory with a plain user/timmy exchange, no tool calls and no errors."""
        return Trajectory(
            **self._session_fields(),
            messages=[_user_msg("hello"), _timmy_msg("hi")],
            tool_calls=[],
            errors=[],
        )

    def _make_low_quality(self) -> Trajectory:
        """Trajectory with one timmy message, no user message, and one error."""
        return Trajectory(
            **self._session_fields(),
            messages=[_timmy_msg("oops")],  # deliberately no user message
            errors=[_error_entry()],
        )

    def test_high_quality_classification(self):
        """The high-quality fixture is HIGH, scores at least 4.0 and is trainable."""
        verdict = QualityFilter().assess(self._make_high_quality())
        assert verdict.quality == TrajectoryQuality.HIGH
        assert verdict.score >= 4.0
        assert verdict.is_trainable

    def test_medium_quality_classification(self):
        """The medium-quality fixture is MEDIUM and still trainable."""
        verdict = QualityFilter().assess(self._make_medium_quality())
        assert verdict.quality == TrajectoryQuality.MEDIUM
        assert verdict.is_trainable

    def test_low_quality_no_user_message(self):
        """A trajectory holding only a timmy message is LOW and not trainable."""
        checker = QualityFilter()
        lone_reply = Trajectory(
            **self._session_fields(),
            messages=[_timmy_msg("random")],
        )
        verdict = checker.assess(lone_reply)
        assert verdict.quality == TrajectoryQuality.LOW
        assert not verdict.is_trainable

    def test_error_penalizes_score(self):
        """A trajectory with two errors scores below the high-quality fixture."""
        checker = QualityFilter()
        failing = Trajectory(
            **self._session_fields(),
            messages=[_user_msg("go"), _timmy_msg("fail")],
            tool_calls=[_tool_call()],
            errors=[_error_entry(), _error_entry()],
        )
        penalized = checker.assess(failing).score
        baseline = checker.assess(self._make_high_quality()).score
        assert penalized < baseline

    def test_low_confidence_penalizes_score(self):
        """A reply with confidence 0.2 keeps the score under 1.0."""
        checker = QualityFilter()
        unsure = Trajectory(
            **self._session_fields(),
            messages=[_user_msg("q"), _timmy_msg("a", confidence=0.2)],
        )
        assert checker.assess(unsure).score < 1.0

    def test_filter_returns_stats(self):
        """``filter`` reports totals that are consistent with its returned list."""
        checker = QualityFilter()
        batch = [
            self._make_high_quality(),
            self._make_medium_quality(),
            self._make_low_quality(),
        ]
        kept, summary = checker.filter(batch)
        assert summary["total"] == 3
        assert summary["accepted"] == len(kept)
        assert summary["high"] + summary["medium"] + summary["low"] == 3

    def test_filter_empty_list(self):
        """Filtering nothing returns nothing and zeroed counters."""
        kept, summary = QualityFilter().filter([])
        assert kept == []
        assert summary["total"] == 0
        assert summary["accepted"] == 0
# ── TrainingDataset tests ─────────────────────────────────────────────────────
class TestTrainingDataset:
    """Exercise ``TrainingDataset`` counting, appending and index bookkeeping."""

    def _make_result(self, quality=TrajectoryQuality.HIGH, score=5.0) -> object:
        """Build a ``QualityResult`` wrapping a two-message, one-tool-call trajectory."""
        from timmy_automations.retrain.quality_filter import QualityResult

        sample = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(-5),
            ended_at=_ts(),
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        return QualityResult(trajectory=sample, quality=quality, score=score, reasons=[])

    def test_count_empty_dataset(self, tmp_path):
        """A dataset that was never written to counts zero examples."""
        dataset = TrainingDataset(
            dataset_path=".loop/retrain/training_data.jsonl",
            repo_root=tmp_path,
        )
        assert dataset.count() == 0

    def test_append_adds_examples(self, tmp_path):
        """Appending one result reports one new example and one total."""
        dataset = TrainingDataset(repo_root=tmp_path)
        outcome = dataset.append([self._make_result()], "2026-W12")
        assert outcome.new_examples == 1
        assert outcome.total_examples == 1
        assert dataset.count() == 1

    def test_append_idempotent(self, tmp_path):
        """Appending the same result twice under one week adds it only once."""
        dataset = TrainingDataset(repo_root=tmp_path)
        sample = self._make_result()
        dataset.append([sample], "2026-W12")
        repeat = dataset.append([sample], "2026-W12")
        # The second append of the same trajectory must be a no-op.
        assert repeat.new_examples == 0
        assert dataset.count() == 1

    def test_append_different_weeks(self, tmp_path):
        """The same result under two week tags is stored as two records."""
        dataset = TrainingDataset(repo_root=tmp_path)
        sample = self._make_result()
        for week_tag in ("2026-W11", "2026-W12"):
            dataset.append([sample], week_tag)
        # A different week tag makes a different record.
        assert dataset.count() == 2

    def test_dataset_file_is_valid_jsonl(self, tmp_path):
        """The dataset file holds one JSON line with the expected keys."""
        dataset = TrainingDataset(repo_root=tmp_path)
        dataset.append([self._make_result()], "2026-W12")
        with open(dataset.dataset_path) as handle:
            rows = [text for text in map(str.strip, handle) if text]
        assert len(rows) == 1
        record = json.loads(rows[0])
        for expected_key in ("messages", "week", "quality"):
            assert expected_key in record

    def test_index_updated_after_append(self, tmp_path):
        """Appending writes ``dataset_index.json`` with the total and the week."""
        dataset = TrainingDataset(repo_root=tmp_path)
        dataset.append([self._make_result()], "2026-W12")
        index_file = tmp_path.joinpath(".loop", "retrain", "dataset_index.json")
        assert index_file.exists()
        index = json.loads(index_file.read_text())
        assert index["total_examples"] == 1
        assert "2026-W12" in index["weeks"]
# ── TrainingLog tests ─────────────────────────────────────────────────────────
class TestTrainingLog:
    """Exercise ``TrainingLog`` iteration numbering, persistence and the markdown summary."""

    def _make_metrics(self, iteration: int = 1) -> CycleMetrics:
        """Build a completed-cycle ``CycleMetrics`` record for week 2026-W12."""
        fields = {
            "iteration": iteration,
            "week": "2026-W12",
            "ran_at": datetime.now(tz=UTC).isoformat(),
            "trajectories_total": 10,
            "trajectories_high": 5,
            "trajectories_medium": 3,
            "trajectories_low": 2,
            "trajectories_accepted": 8,
            "examples_added": 5,
            "dataset_total": 5,
            "train_status": "completed",
            "train_loss": 1.2345,
            "train_duration_seconds": 120.5,
            "adapter_path": ".loop/retrain/adapters/iter_0001/adapters.npz",
            "model_name": "hermes4-14b-ft-0001",
            "notes": "First fine-tune cycle complete",
        }
        return CycleMetrics(**fields)

    def test_next_iteration_starts_at_1(self, tmp_path):
        """With nothing recorded, the next iteration is 1."""
        assert TrainingLog(repo_root=tmp_path).next_iteration() == 1

    def test_next_iteration_increments(self, tmp_path):
        """After recording iteration 1, the next iteration is 2."""
        journal = TrainingLog(repo_root=tmp_path)
        journal.record(self._make_metrics(iteration=1))
        assert journal.next_iteration() == 2

    def test_record_creates_log_file(self, tmp_path):
        """Recording a cycle creates the file at ``log_path``."""
        journal = TrainingLog(repo_root=tmp_path)
        journal.record(self._make_metrics())
        assert journal.log_path.exists()

    def test_load_all_returns_records(self, tmp_path):
        """``load_all`` returns both recorded cycles, first one first."""
        journal = TrainingLog(repo_root=tmp_path)
        for number in (1, 2):
            journal.record(self._make_metrics(iteration=number))
        records = journal.load_all()
        assert len(records) == 2
        assert records[0]["iteration"] == 1

    def test_latest_returns_last_entry(self, tmp_path):
        """``latest`` returns the most recently recorded cycle."""
        journal = TrainingLog(repo_root=tmp_path)
        for number in (1, 2):
            journal.record(self._make_metrics(iteration=number))
        newest = journal.latest()
        assert newest is not None
        assert newest["iteration"] == 2

    def test_latest_returns_none_when_empty(self, tmp_path):
        """``latest`` is None when nothing has been recorded."""
        assert TrainingLog(repo_root=tmp_path).latest() is None

    def test_summary_markdown_written(self, tmp_path):
        """Recording a cycle writes ``training_log.md`` with title, week and status."""
        journal = TrainingLog(repo_root=tmp_path)
        journal.record(self._make_metrics())
        summary_file = tmp_path.joinpath(".loop", "retrain", "training_log.md")
        assert summary_file.exists()
        body = summary_file.read_text()
        assert "AutoLoRA Training Log" in body
        assert "2026-W12" in body
        assert "completed" in body

    def test_skill_accuracy_in_summary(self, tmp_path):
        """Skill names set on ``skill_accuracy`` appear in the markdown summary."""
        journal = TrainingLog(repo_root=tmp_path)
        metrics = self._make_metrics()
        metrics.skill_accuracy = {"tool_calling": 0.85, "reasoning": 0.72}
        journal.record(metrics)
        body = tmp_path.joinpath(".loop", "retrain", "training_log.md").read_text()
        assert "tool_calling" in body
        assert "reasoning" in body
# ── RetrainOrchestrator integration tests ─────────────────────────────────────
class TestRetrainOrchestrator:
    """Run ``RetrainOrchestrator`` in dry-run mode and check its automation manifest entry."""

    @staticmethod
    def _dry_run_orchestrator(root: Path):
        """Return a dry-run orchestrator rooted at *root*."""
        return RetrainOrchestrator(repo_root=root, dry_run=True)

    @staticmethod
    def _manifest_path() -> Path:
        """Return the path of ``automations.json`` under the repository root."""
        return _REPO_ROOT / "timmy_automations" / "config" / "automations.json"

    def test_run_dry_run_no_data(self, tmp_path):
        """With no session logs, a dry run is skipped, adds nothing, and is iteration 1."""
        (tmp_path / "logs").mkdir(parents=True)
        outcome = self._dry_run_orchestrator(tmp_path).run(weeks_ago=0)
        assert outcome.train_status == "skipped"
        assert outcome.examples_added == 0
        assert outcome.iteration == 1

    def test_run_creates_log_entry(self, tmp_path):
        """A single run leaves exactly one entry in the training log."""
        (tmp_path / "logs").mkdir(parents=True)
        self._dry_run_orchestrator(tmp_path).run(weeks_ago=0)
        recorded = TrainingLog(repo_root=tmp_path).load_all()
        assert len(recorded) == 1

    def test_run_with_session_data(self, tmp_path):
        """With a session log for today, the run exports trajectories and skips training."""
        stamp = datetime.now(tz=UTC).strftime("%Y-%m-%d")
        session = [
            _user_msg("deploy the service", offset=-10),
            _tool_call("bash", "deployed successfully", offset=-9),
            _tool_call("bash", "health check ok", offset=-8),
            _timmy_msg("Service deployed and healthy", confidence=0.92, offset=-7),
            _user_msg("run the tests", offset=-6),
            _tool_call("bash", "All tests passed", offset=-5),
            _timmy_msg("All 42 tests passed", confidence=0.95, offset=-4),
        ]
        _make_session_log(session, stamp, tmp_path)
        outcome = self._dry_run_orchestrator(tmp_path).run(weeks_ago=0)
        assert outcome.trajectories_exported >= 1
        assert outcome.iteration == 1
        # Dry-run mode skips the fine-tune step; the trajectories are still processed.
        assert outcome.train_status == "skipped"

    def test_iteration_increments_on_second_run(self, tmp_path):
        """Running the same orchestrator twice advances the iteration by one."""
        (tmp_path / "logs").mkdir(parents=True)
        orchestrator = self._dry_run_orchestrator(tmp_path)
        first = orchestrator.run(weeks_ago=0)
        second = orchestrator.run(weeks_ago=0)
        assert second.iteration == first.iteration + 1

    def test_automations_json_has_retrain_entry(self):
        """``automations.json`` exists and lists an automation with id ``retrain``."""
        manifest_file = self._manifest_path()
        assert manifest_file.exists()
        manifest = json.loads(manifest_file.read_text())
        registered = [item["id"] for item in manifest.get("automations", [])]
        assert "retrain" in registered

    def test_retrain_automation_config(self):
        """The ``retrain`` automation has the expected schedule, trigger and config."""
        manifest = json.loads(self._manifest_path().read_text())
        entry = next(item for item in manifest["automations"] if item["id"] == "retrain")
        assert entry["schedule"] == "weekly_sunday"
        assert entry["trigger"] == "scheduled"
        settings = entry["config"]
        assert settings["base_model"] == "hermes4-14b"
        assert settings["weeks_ago"] == 1
# Repository root, found by climbing three directory levels from this file's
# resolved path. The test methods that read _REPO_ROOT look the name up when
# they run, so assigning it after the class definitions still works.
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent