diff --git a/tests/unit/test_retrain_loop.py b/tests/unit/test_retrain_loop.py new file mode 100644 index 00000000..313b50cd --- /dev/null +++ b/tests/unit/test_retrain_loop.py @@ -0,0 +1,550 @@ +"""Unit tests for the AutoLoRA continuous improvement loop. + +Covers trajectory extraction, quality filtering, dataset management, +and the retrain orchestrator. + +Refs: #1105 +""" + +from __future__ import annotations + +import json +import tempfile +from datetime import UTC, datetime, timedelta +from pathlib import Path + +import pytest + +from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality +from timmy_automations.retrain.retrain import RetrainOrchestrator +from timmy_automations.retrain.training_dataset import TrainingDataset +from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog +from timmy_automations.retrain.trajectory_exporter import Trajectory, TrajectoryExporter + + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +def _ts(offset_minutes: int = 0) -> str: + """Return an ISO timestamp offset from now.""" + return (datetime.now(tz=UTC) + timedelta(minutes=offset_minutes)).isoformat() + + +def _make_session_log(entries: list[dict], date_str: str, tmp_path: Path) -> Path: + """Write session JSONL entries to a temp log file.""" + log_dir = tmp_path / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / f"session_{date_str}.jsonl" + with open(log_file, "w") as f: + for entry in entries: + f.write(json.dumps(entry) + "\n") + return log_file + + +def _user_msg(content: str, offset: int = 0) -> dict: + return {"type": "message", "role": "user", "content": content, "timestamp": _ts(offset)} + + +def _timmy_msg(content: str, confidence: float | None = None, offset: int = 0) -> dict: + entry = {"type": "message", "role": "timmy", "content": content, "timestamp": _ts(offset)} + if confidence is not None: + entry["confidence"] = confidence + return entry + + +def 
_tool_call(tool: str = "bash", result: str = "ok", offset: int = 0) -> dict: + return { + "type": "tool_call", + "tool": tool, + "args": {}, + "result": result, + "timestamp": _ts(offset), + } + + +def _error_entry(msg: str = "Something failed", offset: int = 0) -> dict: + return {"type": "error", "error": msg, "timestamp": _ts(offset)} + + +def _decision_entry(decision: str = "Use approach A", offset: int = 0) -> dict: + return {"type": "decision", "decision": decision, "timestamp": _ts(offset)} + + +# ── Trajectory dataclass tests ──────────────────────────────────────────────── + + +class TestTrajectory: + def test_message_count(self): + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_user_msg("hi"), _timmy_msg("hello")], + ) + assert t.message_count == 2 + + def test_tool_call_count(self): + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + tool_calls=[_tool_call(), _tool_call()], + ) + assert t.tool_call_count == 2 + + def test_has_successful_tool_call_when_no_errors(self): + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + tool_calls=[_tool_call()], + errors=[], + ) + assert t.has_successful_tool_call is True + + def test_has_successful_tool_call_false_when_errors(self): + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + tool_calls=[_tool_call()], + errors=[_error_entry()], + ) + assert t.has_successful_tool_call is False + + def test_is_multi_step(self): + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_user_msg("do it"), _timmy_msg("done")], + tool_calls=[_tool_call()], + ) + assert t.is_multi_step is True + + def test_is_not_multi_step_single_message(self): + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_timmy_msg("hello")], + tool_calls=[], + ) + assert t.is_multi_step is False + + def 
test_to_chat_format_ordering(self): + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_user_msg("question", offset=0), _timmy_msg("answer", offset=2)], + tool_calls=[_tool_call(offset=1)], + ) + chat = t.to_chat_format() + roles = [m["role"] for m in chat] + assert "user" in roles + assert "assistant" in roles + + def test_to_chat_format_empty_content_skipped(self): + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_user_msg(""), _timmy_msg("response")], + ) + chat = t.to_chat_format() + # Empty user message should be skipped + assert all(m["content"] for m in chat) + + +# ── TrajectoryExporter tests ────────────────────────────────────────────────── + + +class TestTrajectoryExporter: + def test_export_empty_logs_dir(self, tmp_path): + (tmp_path / "logs").mkdir() + exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path) + result = exporter.export_week(weeks_ago=0) + assert result == [] + + def test_export_reads_session_files(self, tmp_path): + # Write a session file for this week + today = datetime.now(tz=UTC) + date_str = today.strftime("%Y-%m-%d") + entries = [ + _user_msg("tell me about Python"), + _timmy_msg("Python is great"), + ] + _make_session_log(entries, date_str, tmp_path) + + exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path) + result = exporter.export_week(weeks_ago=0) + assert len(result) >= 1 + + def test_export_skips_old_sessions(self, tmp_path): + # Write a session file for 3 weeks ago + three_weeks_ago = datetime.now(tz=UTC) - timedelta(weeks=3) + date_str = three_weeks_ago.strftime("%Y-%m-%d") + entries = [_user_msg("old message"), _timmy_msg("old response")] + _make_session_log(entries, date_str, tmp_path) + + exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path) + # Request current week — should not include 3-week-old data + result = exporter.export_week(weeks_ago=0) + assert 
result == [] + + def test_export_segments_by_gap(self, tmp_path): + today = datetime.now(tz=UTC) + date_str = today.strftime("%Y-%m-%d") + + # Two conversations separated by 10 minutes + t1 = (today - timedelta(minutes=15)).isoformat() + t2 = (today - timedelta(minutes=14)).isoformat() + t3 = (today - timedelta(minutes=2)).isoformat() + t4 = (today - timedelta(minutes=1)).isoformat() + + entries = [ + {"type": "message", "role": "user", "content": "first q", "timestamp": t1}, + {"type": "message", "role": "timmy", "content": "first a", "timestamp": t2}, + {"type": "message", "role": "user", "content": "second q", "timestamp": t3}, + {"type": "message", "role": "timmy", "content": "second a", "timestamp": t4}, + ] + _make_session_log(entries, date_str, tmp_path) + + exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path) + result = exporter.export_week(weeks_ago=0) + # Should have at least 1 trajectory (may be 1 or 2 depending on segmentation) + assert len(result) >= 1 + + def test_handles_malformed_log_file(self, tmp_path): + log_dir = tmp_path / "logs" + log_dir.mkdir() + today = datetime.now(tz=UTC).strftime("%Y-%m-%d") + (log_dir / f"session_{today}.jsonl").write_text("not json\n{}\n") + + exporter = TrajectoryExporter(logs_dir=log_dir, repo_root=tmp_path) + # Should not raise, just return empty or partial results + result = exporter.export_week(weeks_ago=0) + assert isinstance(result, list) + + +# ── QualityFilter tests ─────────────────────────────────────────────────────── + + +class TestQualityFilter: + def _make_high_quality(self) -> Trajectory: + return Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_user_msg("do task"), _timmy_msg("done", confidence=0.9)], + tool_calls=[_tool_call(), _tool_call()], + errors=[], + decisions=[_decision_entry()], + ) + + def _make_medium_quality(self) -> Trajectory: + return Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + 
messages=[_user_msg("hello"), _timmy_msg("hi")], + tool_calls=[], + errors=[], + ) + + def _make_low_quality(self) -> Trajectory: + return Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_timmy_msg("oops")], # No user message + errors=[_error_entry()], + ) + + def test_high_quality_classification(self): + qf = QualityFilter() + result = qf.assess(self._make_high_quality()) + assert result.quality == TrajectoryQuality.HIGH + assert result.score >= 4.0 + assert result.is_trainable + + def test_medium_quality_classification(self): + qf = QualityFilter() + result = qf.assess(self._make_medium_quality()) + assert result.quality == TrajectoryQuality.MEDIUM + assert result.is_trainable + + def test_low_quality_no_user_message(self): + qf = QualityFilter() + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_timmy_msg("random")], + ) + result = qf.assess(t) + assert result.quality == TrajectoryQuality.LOW + assert not result.is_trainable + + def test_error_penalizes_score(self): + qf = QualityFilter() + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_user_msg("go"), _timmy_msg("fail")], + tool_calls=[_tool_call()], + errors=[_error_entry(), _error_entry()], + ) + result = qf.assess(t) + assert result.score < qf.assess(self._make_high_quality()).score + + def test_low_confidence_penalizes_score(self): + qf = QualityFilter() + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(), + ended_at=_ts(), + messages=[_user_msg("q"), _timmy_msg("a", confidence=0.2)], + ) + result = qf.assess(t) + assert result.score < 1.0 + + def test_filter_returns_stats(self): + qf = QualityFilter() + trajectories = [ + self._make_high_quality(), + self._make_medium_quality(), + self._make_low_quality(), + ] + trainable, stats = qf.filter(trajectories) + assert stats["total"] == 3 + assert stats["accepted"] == len(trainable) + assert stats["high"] + 
stats["medium"] + stats["low"] == 3 + + def test_filter_empty_list(self): + qf = QualityFilter() + trainable, stats = qf.filter([]) + assert trainable == [] + assert stats["total"] == 0 + assert stats["accepted"] == 0 + + +# ── TrainingDataset tests ───────────────────────────────────────────────────── + + +class TestTrainingDataset: + def _make_result(self, quality=TrajectoryQuality.HIGH, score=5.0) -> object: + from timmy_automations.retrain.quality_filter import QualityResult + + t = Trajectory( + session_date="2026-03-17", + started_at=_ts(-5), + ended_at=_ts(), + messages=[_user_msg("do it"), _timmy_msg("done")], + tool_calls=[_tool_call()], + ) + return QualityResult(trajectory=t, quality=quality, score=score, reasons=[]) + + def test_count_empty_dataset(self, tmp_path): + ds = TrainingDataset( + dataset_path=".loop/retrain/training_data.jsonl", + repo_root=tmp_path, + ) + assert ds.count() == 0 + + def test_append_adds_examples(self, tmp_path): + ds = TrainingDataset(repo_root=tmp_path) + result = ds.append([self._make_result()], "2026-W12") + assert result.new_examples == 1 + assert result.total_examples == 1 + assert ds.count() == 1 + + def test_append_idempotent(self, tmp_path): + ds = TrainingDataset(repo_root=tmp_path) + r = self._make_result() + ds.append([r], "2026-W12") + result2 = ds.append([r], "2026-W12") + # Same trajectory shouldn't be added twice + assert result2.new_examples == 0 + assert ds.count() == 1 + + def test_append_different_weeks(self, tmp_path): + ds = TrainingDataset(repo_root=tmp_path) + r1 = self._make_result() + ds.append([r1], "2026-W11") + ds.append([r1], "2026-W12") + # Different week tags = different records + assert ds.count() == 2 + + def test_dataset_file_is_valid_jsonl(self, tmp_path): + ds = TrainingDataset(repo_root=tmp_path) + ds.append([self._make_result()], "2026-W12") + with open(ds.dataset_path) as f: + lines = [l.strip() for l in f if l.strip()] + assert len(lines) == 1 + record = json.loads(lines[0]) + assert 
"messages" in record + assert "week" in record + assert "quality" in record + + def test_index_updated_after_append(self, tmp_path): + ds = TrainingDataset(repo_root=tmp_path) + ds.append([self._make_result()], "2026-W12") + index_path = tmp_path / ".loop" / "retrain" / "dataset_index.json" + assert index_path.exists() + index = json.loads(index_path.read_text()) + assert index["total_examples"] == 1 + assert "2026-W12" in index["weeks"] + + +# ── TrainingLog tests ───────────────────────────────────────────────────────── + + +class TestTrainingLog: + def _make_metrics(self, iteration: int = 1) -> CycleMetrics: + return CycleMetrics( + iteration=iteration, + week="2026-W12", + ran_at=datetime.now(tz=UTC).isoformat(), + trajectories_total=10, + trajectories_high=5, + trajectories_medium=3, + trajectories_low=2, + trajectories_accepted=8, + examples_added=5, + dataset_total=5, + train_status="completed", + train_loss=1.2345, + train_duration_seconds=120.5, + adapter_path=".loop/retrain/adapters/iter_0001/adapters.npz", + model_name="hermes4-14b-ft-0001", + notes="First fine-tune cycle complete", + ) + + def test_next_iteration_starts_at_1(self, tmp_path): + log = TrainingLog(repo_root=tmp_path) + assert log.next_iteration() == 1 + + def test_next_iteration_increments(self, tmp_path): + log = TrainingLog(repo_root=tmp_path) + log.record(self._make_metrics(iteration=1)) + assert log.next_iteration() == 2 + + def test_record_creates_log_file(self, tmp_path): + log = TrainingLog(repo_root=tmp_path) + log.record(self._make_metrics()) + assert log.log_path.exists() + + def test_load_all_returns_records(self, tmp_path): + log = TrainingLog(repo_root=tmp_path) + log.record(self._make_metrics(iteration=1)) + log.record(self._make_metrics(iteration=2)) + entries = log.load_all() + assert len(entries) == 2 + assert entries[0]["iteration"] == 1 + + def test_latest_returns_last_entry(self, tmp_path): + log = TrainingLog(repo_root=tmp_path) + 
log.record(self._make_metrics(iteration=1)) + log.record(self._make_metrics(iteration=2)) + latest = log.latest() + assert latest is not None + assert latest["iteration"] == 2 + + def test_latest_returns_none_when_empty(self, tmp_path): + log = TrainingLog(repo_root=tmp_path) + assert log.latest() is None + + def test_summary_markdown_written(self, tmp_path): + log = TrainingLog(repo_root=tmp_path) + log.record(self._make_metrics()) + summary_path = tmp_path / ".loop" / "retrain" / "training_log.md" + assert summary_path.exists() + content = summary_path.read_text() + assert "AutoLoRA Training Log" in content + assert "2026-W12" in content + assert "completed" in content + + def test_skill_accuracy_in_summary(self, tmp_path): + log = TrainingLog(repo_root=tmp_path) + m = self._make_metrics() + m.skill_accuracy = {"tool_calling": 0.85, "reasoning": 0.72} + log.record(m) + content = (tmp_path / ".loop" / "retrain" / "training_log.md").read_text() + assert "tool_calling" in content + assert "reasoning" in content + + +# ── RetrainOrchestrator integration tests ───────────────────────────────────── + + +class TestRetrainOrchestrator: + def test_run_dry_run_no_data(self, tmp_path): + """Dry run with no session logs should complete without errors.""" + (tmp_path / "logs").mkdir(parents=True) + orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True) + result = orc.run(weeks_ago=0) + assert result.train_status in ("skipped",) + assert result.examples_added == 0 + assert result.iteration == 1 + + def test_run_creates_log_entry(self, tmp_path): + (tmp_path / "logs").mkdir(parents=True) + orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True) + orc.run(weeks_ago=0) + log = TrainingLog(repo_root=tmp_path) + entries = log.load_all() + assert len(entries) == 1 + + def test_run_with_session_data(self, tmp_path): + """Run with actual session data — should export, filter, and log.""" + today = datetime.now(tz=UTC) + date_str = today.strftime("%Y-%m-%d") + entries = [ + 
_user_msg("deploy the service", offset=-10), + _tool_call("bash", "deployed successfully", offset=-9), + _tool_call("bash", "health check ok", offset=-8), + _timmy_msg("Service deployed and healthy", confidence=0.92, offset=-7), + _user_msg("run the tests", offset=-6), + _tool_call("bash", "All tests passed", offset=-5), + _timmy_msg("All 42 tests passed", confidence=0.95, offset=-4), + ] + _make_session_log(entries, date_str, tmp_path) + + orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True) + result = orc.run(weeks_ago=0) + + assert result.trajectories_exported >= 1 + assert result.iteration == 1 + # In dry_run mode, fine-tune is skipped but trajectories should be processed + assert result.train_status == "skipped" + + def test_iteration_increments_on_second_run(self, tmp_path): + (tmp_path / "logs").mkdir(parents=True) + orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True) + r1 = orc.run(weeks_ago=0) + r2 = orc.run(weeks_ago=0) + assert r2.iteration == r1.iteration + 1 + + def test_automations_json_has_retrain_entry(self): + """Verify the retrain automation is registered in automations.json.""" + config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json" + assert config_path.exists() + manifest = json.loads(config_path.read_text()) + ids = [a["id"] for a in manifest.get("automations", [])] + assert "retrain" in ids + + def test_retrain_automation_config(self): + """Verify retrain automation has correct schedule and config.""" + config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json" + manifest = json.loads(config_path.read_text()) + retrain = next(a for a in manifest["automations"] if a["id"] == "retrain") + assert retrain["schedule"] == "weekly_sunday" + assert retrain["trigger"] == "scheduled" + assert retrain["config"]["base_model"] == "hermes4-14b" + assert retrain["config"]["weeks_ago"] == 1 + + +_REPO_ROOT = Path(__file__).resolve().parent.parent.parent diff --git 
a/timmy_automations/config/automations.json b/timmy_automations/config/automations.json index 8478c05b..cdca59f7 100644 --- a/timmy_automations/config/automations.json +++ b/timmy_automations/config/automations.json @@ -4,7 +4,7 @@ "_health_snapshot": { "note": "Quick health check before coding — CI, P0/P1 issues, flakiness" }, - "last_updated": "2026-03-21", + "last_updated": "2026-03-23", "automations": [ { "id": "cycle_retro", @@ -268,6 +268,36 @@ "ci_timeout_seconds": 5 }, "outputs": [] + }, + { + "id": "retrain", + "name": "AutoLoRA Continuous Improvement Loop", + "description": "Weekly sovereignty loop — exports trajectories, filters quality, appends to training dataset, triggers LoRA fine-tune, loads new adapter, and logs iteration metrics", + "script": "timmy_automations/retrain/retrain.py", + "category": "autolora", + "enabled": true, + "trigger": "scheduled", + "schedule": "weekly_sunday", + "executable": "python3", + "epic": "#1091", + "pipeline": "AutoLoRA Sovereignty Loop (Step 6 of 7)", + "config": { + "weeks_ago": 1, + "base_model": "hermes4-14b", + "dry_run": false, + "logs_dir": "logs", + "dataset_path": ".loop/retrain/training_data.jsonl", + "adapter_dir": ".loop/retrain/adapters", + "training_log_path": ".loop/retrain/training_log.jsonl", + "training_summary_path": ".loop/retrain/training_log.md" + }, + "outputs": [ + ".loop/retrain/training_data.jsonl", + ".loop/retrain/dataset_index.json", + ".loop/retrain/training_log.jsonl", + ".loop/retrain/training_log.md", + ".loop/retrain/adapters/" + ] } ] } diff --git a/timmy_automations/retrain/__init__.py b/timmy_automations/retrain/__init__.py new file mode 100644 index 00000000..228f54eb --- /dev/null +++ b/timmy_automations/retrain/__init__.py @@ -0,0 +1,26 @@ +"""AutoLoRA continuous improvement loop — sovereignty engine for Timmy. 
+ +Implements the weekly retrain cycle: + Work → Record trajectories → Export weekly → Filter quality + → LoRA fine-tune → Load adapter → Model improves → Repeat + +Epic: #1091 — Project Bannerlord +Pipeline: AutoLoRA Sovereignty Loop (Step 6 of 7) +Refs: #1105 +""" + +from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality +from timmy_automations.retrain.retrain import RetrainOrchestrator, RetrainResult +from timmy_automations.retrain.training_dataset import TrainingDataset +from timmy_automations.retrain.training_log import TrainingLog +from timmy_automations.retrain.trajectory_exporter import TrajectoryExporter + +__all__ = [ + "QualityFilter", + "RetrainOrchestrator", + "RetrainResult", + "TrainingDataset", + "TrainingLog", + "TrajectoryExporter", + "TrajectoryQuality", +] diff --git a/timmy_automations/retrain/lora_trainer.py b/timmy_automations/retrain/lora_trainer.py new file mode 100644 index 00000000..85c0a3fb --- /dev/null +++ b/timmy_automations/retrain/lora_trainer.py @@ -0,0 +1,262 @@ +"""LoRA trainer — triggers fine-tune job and loads the resulting adapter. + +Supports two backends: +1. mlx-lm (default, Apple Silicon) — `mlx_lm.lora` CLI +2. Ollama create (adapter packaging into a new Ollama model) + +Graceful degradation: if neither backend is available, logs a warning +and returns a skipped result — the rest of the loop continues. 
+ +Refs: #1105 +""" + +from __future__ import annotations + +import json +import logging +import os +import shutil +import subprocess +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path + +logger = logging.getLogger(__name__) + +_DEFAULT_BASE_MODEL = "hermes4-14b" +_DEFAULT_ADAPTER_DIR = ".loop/retrain/adapters" +_MLX_LM_BIN = "mlx_lm.lora" +_OLLAMA_BIN = "ollama" + + +@dataclass +class TrainResult: + """Result of a LoRA fine-tune run.""" + + status: str # "completed" | "skipped" | "failed" + adapter_path: str | None + model_name: str | None + iteration: int + duration_seconds: float + message: str + train_loss: float | None = None + + +class LoRATrainer: + """Orchestrates LoRA fine-tuning and adapter loading. + + Workflow: + 1. Run mlx_lm.lora fine-tune on the training dataset + 2. Save the resulting adapter to .loop/retrain/adapters// + 3. Create (or update) an Ollama model that uses the new adapter + """ + + def __init__( + self, + base_model: str = _DEFAULT_BASE_MODEL, + adapter_dir: str | Path | None = None, + repo_root: str | Path | None = None, + dry_run: bool = False, + ): + if repo_root is None: + repo_root = Path(__file__).resolve().parent.parent.parent + self._repo_root = Path(repo_root) + + self._base_model = base_model + self._adapter_dir = self._repo_root / (adapter_dir or _DEFAULT_ADAPTER_DIR) + self._adapter_dir.mkdir(parents=True, exist_ok=True) + self._dry_run = dry_run + + def train(self, dataset_path: Path, iteration: int) -> TrainResult: + """Run LoRA fine-tuning on the dataset. + + Args: + dataset_path: Path to the JSONL training dataset. + iteration: Current fine-tune iteration number (used for naming). + + Returns: + TrainResult with status, adapter path, and metrics. 
+ """ + started = datetime.now(tz=UTC) + + if not dataset_path.exists() or dataset_path.stat().st_size == 0: + return TrainResult( + status="skipped", + adapter_path=None, + model_name=None, + iteration=iteration, + duration_seconds=0.0, + message="Training dataset is empty — skipping fine-tune", + ) + + if self._dry_run: + logger.info("[dry-run] Would fine-tune %s on %s", self._base_model, dataset_path) + adapter_path = self._adapter_dir / f"iter_{iteration:04d}" / "adapters.npz" + return TrainResult( + status="skipped", + adapter_path=str(adapter_path), + model_name=f"{self._base_model}-ft-{iteration:04d}", + iteration=iteration, + duration_seconds=0.0, + message="dry-run mode — no training performed", + ) + + # Determine which backend is available + if shutil.which(_MLX_LM_BIN): + return self._train_mlx(dataset_path, iteration, started) + else: + logger.warning( + "%s not found — skipping LoRA fine-tune (install mlx-lm to enable)", + _MLX_LM_BIN, + ) + return TrainResult( + status="skipped", + adapter_path=None, + model_name=None, + iteration=iteration, + duration_seconds=0.0, + message=( + f"{_MLX_LM_BIN} not available. " + "Install mlx-lm on Apple Silicon to enable LoRA fine-tuning." 
+ ), + ) + + def _train_mlx( + self, dataset_path: Path, iteration: int, started: datetime + ) -> TrainResult: + """Run mlx_lm.lora fine-tune.""" + adapter_out = self._adapter_dir / f"iter_{iteration:04d}" + adapter_out.mkdir(parents=True, exist_ok=True) + + cmd = [ + _MLX_LM_BIN, + "--model", self._base_model, + "--data", str(dataset_path), + "--adapter-path", str(adapter_out), + "--train", + "--iters", "100", + "--batch-size", "1", + "--learning-rate", "1e-5", + ] + + logger.info("Starting mlx-lm LoRA fine-tune: iteration %d", iteration) + logger.info("Command: %s", " ".join(cmd)) + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=3600, # 1 hour max + env={**os.environ, "PYTHONUNBUFFERED": "1"}, + ) + except subprocess.TimeoutExpired: + duration = (datetime.now(tz=UTC) - started).total_seconds() + return TrainResult( + status="failed", + adapter_path=None, + model_name=None, + iteration=iteration, + duration_seconds=duration, + message="Fine-tune timed out after 1 hour", + ) + except Exception as exc: + duration = (datetime.now(tz=UTC) - started).total_seconds() + return TrainResult( + status="failed", + adapter_path=None, + model_name=None, + iteration=iteration, + duration_seconds=duration, + message=f"Fine-tune subprocess error: {exc}", + ) + + duration = (datetime.now(tz=UTC) - started).total_seconds() + + if result.returncode != 0: + logger.error("mlx-lm fine-tune failed: %s", result.stderr[:500]) + return TrainResult( + status="failed", + adapter_path=None, + model_name=None, + iteration=iteration, + duration_seconds=duration, + message=f"mlx_lm.lora exited {result.returncode}: {result.stderr[:300]}", + ) + + # Parse final train loss from stdout if available + train_loss = _parse_train_loss(result.stdout) + + adapter_file = adapter_out / "adapters.npz" + model_name = f"{self._base_model}-ft-{iteration:04d}" + + # Attempt to register with Ollama + ollama_ok = self._register_ollama_adapter(adapter_out, model_name) + if 
not ollama_ok: + logger.warning("Ollama adapter registration failed — adapter saved locally") + + logger.info( + "Fine-tune complete: iteration=%d loss=%.4f duration=%.1fs adapter=%s", + iteration, + train_loss or 0.0, + duration, + adapter_file, + ) + + return TrainResult( + status="completed", + adapter_path=str(adapter_file), + model_name=model_name, + iteration=iteration, + duration_seconds=duration, + message=f"LoRA fine-tune completed successfully in {duration:.0f}s", + train_loss=train_loss, + ) + + def _register_ollama_adapter(self, adapter_dir: Path, model_name: str) -> bool: + """Create an Ollama model entry for the new adapter. + + Writes a minimal Modelfile and runs `ollama create`. + """ + if not shutil.which(_OLLAMA_BIN): + logger.debug("Ollama not found — skipping adapter registration") + return False + + modelfile_content = ( + f"FROM {self._base_model}\n" + f"ADAPTER {adapter_dir}\n" + ) + modelfile_path = adapter_dir / "Modelfile" + try: + modelfile_path.write_text(modelfile_content) + result = subprocess.run( + [_OLLAMA_BIN, "create", model_name, "-f", str(modelfile_path)], + capture_output=True, + text=True, + timeout=300, + ) + if result.returncode == 0: + logger.info("Ollama model registered: %s", model_name) + return True + else: + logger.warning("ollama create failed: %s", result.stderr[:200]) + return False + except Exception as exc: + logger.warning("Ollama adapter registration error: %s", exc) + return False + + +def _parse_train_loss(stdout: str) -> float | None: + """Extract the final training loss from mlx-lm stdout.""" + loss: float | None = None + for line in stdout.splitlines(): + line_lower = line.lower() + if "train loss" in line_lower or "loss:" in line_lower: + parts = line.split() + for i, part in enumerate(parts): + if "loss" in part.lower() and i + 1 < len(parts): + try: + loss = float(parts[i + 1].strip(",:")) + except ValueError: + pass + return loss diff --git a/timmy_automations/retrain/quality_filter.py 
b/timmy_automations/retrain/quality_filter.py new file mode 100644 index 00000000..4d493a00 --- /dev/null +++ b/timmy_automations/retrain/quality_filter.py @@ -0,0 +1,172 @@ +"""Quality filter — keeps only high-value trajectories for LoRA training. + +Criteria for a high-quality training example: +1. Tool calls succeeded (tool calls present, no error entries) +2. Multi-step tasks completed (≥2 messages + ≥1 tool call) +3. No low-confidence signals (confidence < 0.5 on any Timmy message) +4. Minimum meaningful exchange (≥1 user message + ≥1 Timmy message) + +Refs: #1105 +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from enum import StrEnum + +from timmy_automations.retrain.trajectory_exporter import Trajectory + +logger = logging.getLogger(__name__) + +_MIN_CONFIDENCE = 0.5 + + +class TrajectoryQuality(StrEnum): + """Quality classification for a trajectory.""" + + HIGH = "high" # Multi-step + tool success — ideal training data + MEDIUM = "medium" # Single exchange, no errors — acceptable + LOW = "low" # Error-prone or trivial — skip + + +@dataclass +class QualityResult: + """Result of quality assessment for a single trajectory.""" + + trajectory: Trajectory + quality: TrajectoryQuality + score: float + reasons: list[str] + + @property + def is_trainable(self) -> bool: + return self.quality in (TrajectoryQuality.HIGH, TrajectoryQuality.MEDIUM) + + +class QualityFilter: + """Filters trajectories to keep only those worth training on. 
+ + Scoring: + - +1 pt: base score for any valid clean exchange (no errors) + - +3 pts: multi-step task (≥2 messages + ≥1 tool call) + - +2 pts: tool calls present and no errors + - +1 pt: decision recorded (deliberate choice made) + - -2 pts: any error entry + - -1 pt: any low-confidence response (confidence < 0.5) + + HIGH ≥ 4, MEDIUM 1–3, LOW ≤ 0 + """ + + def __init__(self, min_confidence: float = _MIN_CONFIDENCE): + self._min_confidence = min_confidence + + def assess(self, trajectory: Trajectory) -> QualityResult: + """Score and classify a single trajectory.""" + score = 0.0 + reasons: list[str] = [] + + # Minimum viable exchange check + user_msgs = [m for m in trajectory.messages if m.get("role") == "user"] + timmy_msgs = [m for m in trajectory.messages if m.get("role") == "timmy"] + + if not user_msgs or not timmy_msgs: + return QualityResult( + trajectory=trajectory, + quality=TrajectoryQuality.LOW, + score=0.0, + reasons=["Missing user or assistant messages — not a valid exchange"], + ) + + # Multi-step bonus + if trajectory.is_multi_step: + score += 3.0 + reasons.append( + f"Multi-step task: {trajectory.message_count} messages, " + f"{trajectory.tool_call_count} tool calls" + ) + + # Base score for any clean exchange (user + timmy, no tool call required) + if trajectory.error_count == 0: + score += 1.0 + reasons.append("Clean exchange (no errors)") + + # Tool call quality + if trajectory.tool_call_count > 0: + if trajectory.error_count == 0: + score += 2.0 + reasons.append( + f"All {trajectory.tool_call_count} tool call(s) succeeded" + ) + else: + score -= 2.0 + reasons.append( + f"{trajectory.error_count} error(s) during {trajectory.tool_call_count} tool call(s)" + ) + elif trajectory.error_count > 0: + score -= 2.0 + reasons.append(f"{trajectory.error_count} error(s) with no tool calls") + + # Decision bonus + if trajectory.decisions: + score += 1.0 + reasons.append(f"Decisions recorded: {len(trajectory.decisions)}") + + # Confidence penalty + low_conf 
= [ + m + for m in timmy_msgs + if m.get("confidence") is not None + and m["confidence"] < self._min_confidence + ] + if low_conf: + score -= len(low_conf) + reasons.append( + f"{len(low_conf)} low-confidence response(s) (threshold={self._min_confidence})" + ) + + # Classify + if score >= 4.0: + quality = TrajectoryQuality.HIGH + elif score >= 1.0: + quality = TrajectoryQuality.MEDIUM + else: + quality = TrajectoryQuality.LOW + + return QualityResult( + trajectory=trajectory, + quality=quality, + score=score, + reasons=reasons, + ) + + def filter( + self, trajectories: list[Trajectory] + ) -> tuple[list[QualityResult], dict[str, int]]: + """Assess all trajectories and return trainable ones with stats. + + Returns: + (trainable_results, stats_dict) where stats_dict has keys + 'total', 'high', 'medium', 'low', 'accepted'. + """ + results = [self.assess(t) for t in trajectories] + trainable = [r for r in results if r.is_trainable] + + stats = { + "total": len(results), + "high": sum(1 for r in results if r.quality == TrajectoryQuality.HIGH), + "medium": sum(1 for r in results if r.quality == TrajectoryQuality.MEDIUM), + "low": sum(1 for r in results if r.quality == TrajectoryQuality.LOW), + "accepted": len(trainable), + } + + logger.info( + "Quality filter: %d/%d accepted (high=%d medium=%d low=%d)", + stats["accepted"], + stats["total"], + stats["high"], + stats["medium"], + stats["low"], + ) + + return trainable, stats diff --git a/timmy_automations/retrain/retrain.py b/timmy_automations/retrain/retrain.py new file mode 100644 index 00000000..f7843b6e --- /dev/null +++ b/timmy_automations/retrain/retrain.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +"""AutoLoRA continuous improvement loop — the sovereignty retrain script. 
"""(retrain.py module docstring, continued)

Implements the weekly retrain cycle end-to-end:
    Work → Record trajectories → Export weekly → Filter quality
    → LoRA fine-tune → Load adapter → Model improves → Repeat forever

Run:
    python3 timmy_automations/retrain/retrain.py
    python3 timmy_automations/retrain/retrain.py --dry-run
    python3 timmy_automations/retrain/retrain.py --weeks-ago 1

Epic: #1091 — Project Bannerlord
Pipeline: AutoLoRA Sovereignty Loop (Step 6 of 7)
Refs: #1105
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path

# Allow running directly from repo root.
# NOTE: mutates sys.path at import time so the package resolves when this file
# is invoked as a script rather than via `python -m`.
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

# Project imports must follow the sys.path fix-up above.
from timmy_automations.retrain.lora_trainer import LoRATrainer
from timmy_automations.retrain.quality_filter import QualityFilter
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
from timmy_automations.retrain.trajectory_exporter import TrajectoryExporter

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S",
)
logger = logging.getLogger("retrain")


@dataclass
class RetrainResult:
    """Result of a complete retrain cycle."""

    iteration: int  # 1-indexed cycle counter, allocated by TrainingLog
    week: str  # ISO week tag, e.g. "2026-W12"
    trajectories_exported: int  # raw trajectories pulled from session logs
    trajectories_accepted: int  # those that passed the quality filter
    examples_added: int  # new examples appended to the dataset this cycle
    dataset_total: int  # dataset size after the append
    train_status: str  # LoRATrainer status string — presumably "skipped"/"failed"/success; confirm against lora_trainer.py
    adapter_path: str | None
    model_name: str | None
    train_loss: float | None
    duration_seconds: float  # wall-clock time for the whole cycle
    notes: str  # free-form message propagated from the trainer


class RetrainOrchestrator:
    """Orchestrates the complete AutoLoRA continuous improvement loop.

    Step 1: Export this week's conversation trajectories from session logs
    Step 2: Filter for high-quality exchanges
    Step 3: Append to the training dataset
    Step 4: Trigger LoRA fine-tune
    Step 5: Load the new adapter (via Ollama)
    Step 6: Log iteration, loss, skill accuracy
    """

    def __init__(
        self,
        base_model: str = "hermes4-14b",
        repo_root: str | Path | None = None,
        dry_run: bool = False,
    ):
        # Default to the repository that contains this package.
        if repo_root is None:
            repo_root = _REPO_ROOT
        self._repo_root = Path(repo_root)
        self._dry_run = dry_run

        # One collaborator per pipeline step. dry_run is threaded into the
        # trainer only — export/filter/append always run for real.
        self.exporter = TrajectoryExporter(repo_root=self._repo_root)
        self.quality_filter = QualityFilter()
        self.dataset = TrainingDataset(repo_root=self._repo_root)
        self.trainer = LoRATrainer(
            base_model=base_model,
            repo_root=self._repo_root,
            dry_run=dry_run,
        )
        self.log = TrainingLog(repo_root=self._repo_root)

    # def run(self, weeks_ago: int = 1) -> RetrainResult:
    #     (defined in the following hunk)
+ """ + started = datetime.now(tz=UTC) + iteration = self.log.next_iteration() + + # Determine ISO week tag + from datetime import timedelta + now = datetime.now(tz=UTC) + target_date = now - timedelta(weeks=weeks_ago) + week_tag = f"{target_date.year}-W{target_date.isocalendar().week:02d}" + + logger.info( + "=== AutoLoRA Retrain Cycle %d | Week: %s | dry_run=%s ===", + iteration, + week_tag, + self._dry_run, + ) + + # Step 1: Export trajectories + logger.info("Step 1: Exporting trajectories for %s...", week_tag) + trajectories = self.exporter.export_week(weeks_ago=weeks_ago) + logger.info("Exported %d raw trajectories", len(trajectories)) + + # Step 2: Quality filter + logger.info("Step 2: Applying quality filter...") + trainable, filter_stats = self.quality_filter.filter(trajectories) + logger.info( + "Quality filter: %d/%d accepted (high=%d medium=%d low=%d)", + filter_stats["accepted"], + filter_stats["total"], + filter_stats["high"], + filter_stats["medium"], + filter_stats["low"], + ) + + # Step 3: Append to dataset + logger.info("Step 3: Appending to training dataset...") + append_result = self.dataset.append(trainable, week_tag) + logger.info( + "Dataset: +%d new examples (%d total)", + append_result.new_examples, + append_result.total_examples, + ) + + # Step 4: LoRA fine-tune + logger.info("Step 4: Triggering LoRA fine-tune (iteration=%d)...", iteration) + train_result = self.trainer.train( + dataset_path=self.dataset.dataset_path, + iteration=iteration, + ) + logger.info( + "Train result: status=%s loss=%s duration=%.1fs", + train_result.status, + train_result.train_loss, + train_result.duration_seconds, + ) + + # Step 5 & 6: Log cycle + duration = (datetime.now(tz=UTC) - started).total_seconds() + metrics = CycleMetrics( + iteration=iteration, + week=week_tag, + ran_at=started.isoformat(), + trajectories_total=filter_stats["total"], + trajectories_high=filter_stats["high"], + trajectories_medium=filter_stats["medium"], + 
trajectories_low=filter_stats["low"], + trajectories_accepted=filter_stats["accepted"], + examples_added=append_result.new_examples, + dataset_total=append_result.total_examples, + train_status=train_result.status, + train_loss=train_result.train_loss, + train_duration_seconds=train_result.duration_seconds, + adapter_path=train_result.adapter_path, + model_name=train_result.model_name, + notes=train_result.message, + ) + self.log.record(metrics) + + result = RetrainResult( + iteration=iteration, + week=week_tag, + trajectories_exported=len(trajectories), + trajectories_accepted=filter_stats["accepted"], + examples_added=append_result.new_examples, + dataset_total=append_result.total_examples, + train_status=train_result.status, + adapter_path=train_result.adapter_path, + model_name=train_result.model_name, + train_loss=train_result.train_loss, + duration_seconds=duration, + notes=train_result.message, + ) + + logger.info( + "=== Cycle %d complete: status=%s examples_added=%d total=%.1fs ===", + iteration, + train_result.status, + append_result.new_examples, + duration, + ) + + return result + + +def _print_result(result: RetrainResult, as_json: bool = False) -> None: + """Print cycle result to stdout.""" + if as_json: + print( + json.dumps( + { + "iteration": result.iteration, + "week": result.week, + "trajectories_exported": result.trajectories_exported, + "trajectories_accepted": result.trajectories_accepted, + "examples_added": result.examples_added, + "dataset_total": result.dataset_total, + "train_status": result.train_status, + "adapter_path": result.adapter_path, + "model_name": result.model_name, + "train_loss": result.train_loss, + "duration_seconds": result.duration_seconds, + "notes": result.notes, + }, + indent=2, + ) + ) + return + + print(f"\n{'='*60}") + print(f" AutoLoRA Retrain — Cycle {result.iteration}") + print(f" Week: {result.week}") + print(f"{'='*60}") + print(f" Trajectories: {result.trajectories_exported} exported, 
{result.trajectories_accepted} accepted") + print(f" Dataset: +{result.examples_added} examples ({result.dataset_total} total)") + print(f" Fine-tune: {result.train_status}") + if result.train_loss is not None: + print(f" Train loss: {result.train_loss:.4f}") + if result.model_name: + print(f" New model: {result.model_name}") + if result.adapter_path: + print(f" Adapter: {result.adapter_path}") + print(f" Duration: {result.duration_seconds:.1f}s") + print(f" Notes: {result.notes}") + print(f"{'='*60}\n") + + +def main() -> int: + parser = argparse.ArgumentParser( + description="AutoLoRA continuous improvement loop — sovereignty engine for Timmy" + ) + parser.add_argument( + "--weeks-ago", + type=int, + default=1, + help="Which week to process: 0=current (partial), 1=last week (default)", + ) + parser.add_argument( + "--base-model", + default="hermes4-14b", + help="Ollama base model name (default: hermes4-14b)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Export and filter trajectories but skip actual fine-tuning", + ) + parser.add_argument( + "--json", + action="store_true", + dest="as_json", + help="Output result as JSON", + ) + args = parser.parse_args() + + orchestrator = RetrainOrchestrator( + base_model=args.base_model, + dry_run=args.dry_run, + ) + result = orchestrator.run(weeks_ago=args.weeks_ago) + _print_result(result, as_json=args.as_json) + + # Exit 0 even on skipped/failed training — the loop must continue + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/timmy_automations/retrain/training_dataset.py b/timmy_automations/retrain/training_dataset.py new file mode 100644 index 00000000..d49701b9 --- /dev/null +++ b/timmy_automations/retrain/training_dataset.py @@ -0,0 +1,180 @@ +"""Training dataset manager — appends filtered trajectories to a JSONL training file. 
"""(training_dataset.py module docstring, continued)

Maintains a growing dataset of high-quality conversation examples in the
chat-format expected by mlx-lm / HuggingFace fine-tuning pipelines.

Output format (one JSON object per line):
    {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

Refs: #1105
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path

from timmy_automations.retrain.quality_filter import QualityResult

logger = logging.getLogger(__name__)

# Default file locations, relative to the repo root.
_DEFAULT_DATASET_PATH = ".loop/retrain/training_data.jsonl"
_DEFAULT_INDEX_PATH = ".loop/retrain/dataset_index.json"


@dataclass
class AppendResult:
    """Result of appending trajectories to the training dataset."""

    new_examples: int  # examples written by this call (after dedup)
    total_examples: int  # dataset size after the append
    dataset_path: str  # absolute path of the JSONL file written to
    week_tag: str  # ISO week tag the examples were filed under


class TrainingDataset:
    """Manages the LoRA training dataset file.

    Each entry is a chat-format example:
        {"messages": [...], "week": "2026-W12", "quality": "high", "added_at": "..."}
    """

    def __init__(
        self,
        dataset_path: str | Path | None = None,
        index_path: str | Path | None = None,
        repo_root: str | Path | None = None,
    ):
        # Default repo root: three levels above this file (package root).
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        # NOTE: pathlib's `/` ignores the left side when the right side is
        # absolute, so absolute overrides also work here.
        self._dataset_path = self._repo_root / (
            dataset_path or _DEFAULT_DATASET_PATH
        )
        self._index_path = self._repo_root / (
            index_path or _DEFAULT_INDEX_PATH
        )

        # Ensure .loop/retrain/ exists before any writes.
        self._dataset_path.parent.mkdir(parents=True, exist_ok=True)

    @property
    def dataset_path(self) -> Path:
        # Exposed for the trainer, which reads the JSONL directly.
        return self._dataset_path

    def count(self) -> int:
        """Return the number of examples currently in the dataset.

        Counts non-blank lines; O(file size) on every call.
        """
        if not self._dataset_path.exists():
            return 0
        count = 0
        with open(self._dataset_path) as f:
            for line in f:
                if line.strip():
                    count += 1
        return count

    def append(
        self, quality_results: list[QualityResult], week_tag: str
    ) -> AppendResult:
        """Append high-quality trajectories to the training dataset.

        Deduplicates by (week_tag, session_date, started_at) so re-running
        the export for the same week is idempotent.

        Args:
            quality_results: Filtered, trainable quality results.
            week_tag: ISO week string e.g. "2026-W12".

        Returns:
            AppendResult with counts.
        """
        existing_keys = self._load_existing_keys()
        new_count = 0
        # One shared timestamp for every example written in this call.
        added_at = datetime.now(tz=UTC).isoformat()

        # Opened in append mode unconditionally — the file is created on the
        # first run even when nothing survives the filters.
        with open(self._dataset_path, "a") as f:
            for result in quality_results:
                traj = result.trajectory
                # NOTE(review): started_at is "" when a trajectory's entries
                # carry no timestamps, so two such trajectories from the same
                # session would collide on this key — confirm upstream always
                # stamps entries.
                dedup_key = (
                    f"{week_tag}|{traj.session_date}|{traj.started_at}"
                )
                if dedup_key in existing_keys:
                    logger.debug("Skipping duplicate trajectory: %s", dedup_key)
                    continue

                # A usable example needs at least one user + one assistant turn.
                chat_messages = traj.to_chat_format()
                if len(chat_messages) < 2:
                    logger.debug(
                        "Skipping trajectory with %d chat messages (need ≥2)",
                        len(chat_messages),
                    )
                    continue

                record = {
                    "messages": chat_messages,
                    "week": week_tag,
                    "quality": result.quality.value,
                    "score": result.score,
                    "session_date": traj.session_date,
                    "started_at": traj.started_at,
                    "tool_calls": traj.tool_call_count,
                    "added_at": added_at,
                }
                f.write(json.dumps(record) + "\n")
                # Guard against duplicates within this same batch too.
                existing_keys.add(dedup_key)
                new_count += 1

        total = self.count()
        self._update_index(week_tag, new_count, total)
        logger.info(
            "Dataset: appended %d new examples (total=%d)", new_count, total
        )

        return AppendResult(
            new_examples=new_count,
            total_examples=total,
            dataset_path=str(self._dataset_path),
            week_tag=week_tag,
        )

    def _load_existing_keys(self) -> set[str]:
        """Load deduplication keys from the existing dataset."""
        keys: set[str] = set()
        if not self._dataset_path.exists():
            return keys
        with open(self._dataset_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    week = record.get("week", "")
                    session_date = record.get("session_date", "")
                    started_at = record.get("started_at", "")
                    keys.add(f"{week}|{session_date}|{started_at}")
                except json.JSONDecodeError:
                    # Malformed lines are ignored here (and skipped by count()
                    # only if blank) — best-effort dedup.
                    continue
        return keys

    def _update_index(self, week_tag: str, new_count: int, total: int) -> None:
        """Update the dataset index JSON with latest run metadata."""
        index: dict = {}
        if self._index_path.exists():
            try:
                index = json.loads(self._index_path.read_text())
            except (json.JSONDecodeError, OSError):
                # Corrupt index is rebuilt from scratch rather than crashing.
                index = {}

        index.setdefault("weeks", {})
        index["weeks"][week_tag] = {
            "examples_added": new_count,
            "updated_at": datetime.now(tz=UTC).isoformat(),
        }
        index["total_examples"] = total
        index["last_updated"] = datetime.now(tz=UTC).isoformat()

        self._index_path.write_text(json.dumps(index, indent=2))

# --- next hunk in the patch adds training_log.py ---
# diff --git a/timmy_automations/retrain/training_log.py b/timmy_automations/retrain/training_log.py
# new file mode 100644
# index 00000000..0c9f899b
# --- /dev/null
# +++ b/timmy_automations/retrain/training_log.py
# """Training log — records each fine-tune cycle with metrics and skill deltas.
# Writes to .loop/retrain/training_log.jsonl (one entry per cycle) and
# maintains a human-readable .loop/retrain/training_log.md summary.
+ +Each log entry captures: +- Iteration count +- Week processed +- Quality filter stats +- Examples added to dataset +- LoRA train result (loss, duration, adapter path) +- Skill accuracy deltas (from smoke tests) + +Refs: #1105 +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import asdict, dataclass, field +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + +_DEFAULT_LOG_PATH = ".loop/retrain/training_log.jsonl" +_DEFAULT_SUMMARY_PATH = ".loop/retrain/training_log.md" + + +@dataclass +class CycleMetrics: + """Metrics for a single retrain cycle.""" + + iteration: int + week: str + ran_at: str + + # Quality filter + trajectories_total: int = 0 + trajectories_high: int = 0 + trajectories_medium: int = 0 + trajectories_low: int = 0 + trajectories_accepted: int = 0 + + # Dataset + examples_added: int = 0 + dataset_total: int = 0 + + # Training + train_status: str = "skipped" + train_loss: float | None = None + train_duration_seconds: float = 0.0 + adapter_path: str | None = None + model_name: str | None = None + + # Skill accuracy (optional, from smoke tests) + skill_accuracy: dict[str, float] = field(default_factory=dict) + skill_delta: dict[str, float] = field(default_factory=dict) + + # Human-readable summary + notes: str = "" + + +class TrainingLog: + """Persistent log of all retrain cycles.""" + + def __init__( + self, + log_path: str | Path | None = None, + summary_path: str | Path | None = None, + repo_root: str | Path | None = None, + ): + if repo_root is None: + repo_root = Path(__file__).resolve().parent.parent.parent + self._repo_root = Path(repo_root) + + self._log_path = self._repo_root / (log_path or _DEFAULT_LOG_PATH) + self._summary_path = self._repo_root / (summary_path or _DEFAULT_SUMMARY_PATH) + self._log_path.parent.mkdir(parents=True, exist_ok=True) + + @property + def log_path(self) -> Path: + return self._log_path + + def 
next_iteration(self) -> int: + """Return the next iteration number (1-indexed).""" + entries = self.load_all() + if not entries: + return 1 + return max(e.get("iteration", 0) for e in entries) + 1 + + def record(self, metrics: CycleMetrics) -> None: + """Append a cycle metrics record to the log.""" + entry = asdict(metrics) + with open(self._log_path, "a") as f: + f.write(json.dumps(entry) + "\n") + + self._update_summary(metrics) + logger.info( + "Training log: iteration=%d week=%s status=%s examples_added=%d", + metrics.iteration, + metrics.week, + metrics.train_status, + metrics.examples_added, + ) + + def load_all(self) -> list[dict[str, Any]]: + """Load all cycle records from the log.""" + if not self._log_path.exists(): + return [] + entries: list[dict[str, Any]] = [] + with open(self._log_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + entries.append(json.loads(line)) + except json.JSONDecodeError: + logger.debug("Skipping malformed log entry") + return entries + + def latest(self) -> dict[str, Any] | None: + """Return the most recent cycle record.""" + entries = self.load_all() + return entries[-1] if entries else None + + def _update_summary(self, metrics: CycleMetrics) -> None: + """Rewrite the markdown summary with all cycles.""" + all_entries = self.load_all() + + lines = [ + "# AutoLoRA Training Log\n", + f"*Updated: {datetime.now(tz=UTC).isoformat()}*\n", + f"*Total iterations: {len(all_entries)}*\n", + "", + "## Cycles\n", + "| # | Week | Status | Loss | Examples | Duration |", + "|---|------|--------|------|----------|----------|", + ] + + for entry in reversed(all_entries[-20:]): # Last 20 cycles + loss = f"{entry.get('train_loss', 0.0) or 0.0:.4f}" if entry.get("train_loss") else "—" + lines.append( + f"| {entry.get('iteration', '?')} " + f"| {entry.get('week', '?')} " + f"| {entry.get('train_status', '?')} " + f"| {loss} " + f"| +{entry.get('examples_added', 0)} ({entry.get('dataset_total', 0)} total) " + f"| 
{entry.get('train_duration_seconds', 0.0):.0f}s |" + ) + + lines.append("") + lines.append("## Skill Accuracy Over Time\n") + + # Collect all unique skills + all_skills: set[str] = set() + for entry in all_entries: + all_skills.update(entry.get("skill_accuracy", {}).keys()) + + if all_skills: + skill_header = "| # | Week | " + " | ".join(sorted(all_skills)) + " |" + skill_sep = "|---|------|" + "|".join("---" for _ in all_skills) + "|" + lines.extend([skill_header, skill_sep]) + for entry in reversed(all_entries[-10:]): + acc = entry.get("skill_accuracy", {}) + row = f"| {entry.get('iteration', '?')} | {entry.get('week', '?')} | " + row += " | ".join( + f"{acc.get(s, 0.0):.0%}" if s in acc else "—" + for s in sorted(all_skills) + ) + row += " |" + lines.append(row) + else: + lines.append("*No skill accuracy data yet — run smoke tests after fine-tuning.*") + + lines.append("") + if metrics.notes: + lines.append(f"## Latest Notes\n\n{metrics.notes}\n") + + self._summary_path.write_text("\n".join(lines)) diff --git a/timmy_automations/retrain/trajectory_exporter.py b/timmy_automations/retrain/trajectory_exporter.py new file mode 100644 index 00000000..a1f2fe11 --- /dev/null +++ b/timmy_automations/retrain/trajectory_exporter.py @@ -0,0 +1,255 @@ +"""Trajectory exporter — reads session JSONL logs and extracts conversation trajectories. + +A trajectory is a coherent sequence of messages + tool calls that form +a single task attempt. Each trajectory becomes one training example. 
"""(trajectory_exporter.py module docstring, continued)

Refs: #1105
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

# Session logs live under <repo_root>/logs as session_YYYY-MM-DD.jsonl.
_LOGS_DIR_DEFAULT = "logs"
_SESSION_GLOB = "session_*.jsonl"


@dataclass
class Trajectory:
    """A single conversation trajectory extracted from session logs."""

    session_date: str  # "YYYY-MM-DD", taken from the source log filename
    started_at: str  # min entry timestamp (ISO string; "" if none present)
    ended_at: str  # max entry timestamp (ISO string; "" if none present)
    messages: list[dict[str, Any]] = field(default_factory=list)
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    errors: list[dict[str, Any]] = field(default_factory=list)
    decisions: list[dict[str, Any]] = field(default_factory=list)

    @property
    def message_count(self) -> int:
        return len(self.messages)

    @property
    def tool_call_count(self) -> int:
        return len(self.tool_calls)

    @property
    def error_count(self) -> int:
        return len(self.errors)

    @property
    def has_successful_tool_call(self) -> bool:
        """True if there is at least one tool call and no error entries.

        NOTE(review): this checks the trajectory-wide error count — it does
        not verify whether an error followed a *specific* call.
        """
        return self.tool_call_count > 0 and self.error_count == 0

    @property
    def is_multi_step(self) -> bool:
        """True if this trajectory involved multiple turns with tool use."""
        return self.message_count >= 2 and self.tool_call_count >= 1

    def to_chat_format(self) -> list[dict[str, str]]:
        """Convert trajectory to chat-format messages for training.

        Interleaves messages and tool-call results as assistant/tool turns.
        Entries without a timestamp sort first (the sort key defaults to "").
        """
        chat: list[dict[str, str]] = []
        # Merge all entries by timestamp and emit in order (errors excluded).
        all_entries = sorted(
            self.messages + self.tool_calls + self.decisions,
            key=lambda e: e.get("timestamp", ""),
        )
        for entry in all_entries:
            etype = entry.get("type")
            if etype == "message":
                # Any non-user role (e.g. "timmy") becomes "assistant".
                role = "user" if entry.get("role") == "user" else "assistant"
                content = entry.get("content", "")
                if content:
                    chat.append({"role": role, "content": content})
            elif etype == "tool_call":
                tool = entry.get("tool", "unknown")
                result = entry.get("result", "")
                chat.append(
                    {
                        "role": "assistant",
                        "content": f"[tool:{tool}] {result}",
                    }
                )
            elif etype == "decision":
                decision = entry.get("decision", "")
                if decision:
                    chat.append({"role": "assistant", "content": f"[decided] {decision}"})
        return chat


class TrajectoryExporter:
    """Reads session JSONL logs and yields Trajectory objects for a date range."""

    def __init__(self, logs_dir: str | Path | None = None, repo_root: str | Path | None = None):
        # Default repo root: three levels above this file (package root).
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        if logs_dir is None:
            self._logs_dir = self._repo_root / _LOGS_DIR_DEFAULT
        else:
            self._logs_dir = Path(logs_dir)

    def export_week(self, weeks_ago: int = 0) -> list[Trajectory]:
        """Export all trajectories from the specified week.

        Args:
            weeks_ago: 0 = current week, 1 = last week, etc.

        Returns:
            List of Trajectory objects extracted from session logs.
        """
        now = datetime.now(tz=UTC)
        # Week boundaries: Mon–Sun
        days_since_monday = now.weekday()
        week_start = (now - timedelta(days=days_since_monday + 7 * weeks_ago)).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        week_end = week_start + timedelta(days=7)

        logger.info(
            "Exporting trajectories for week %s–%s",
            week_start.date().isoformat(),
            week_end.date().isoformat(),
        )

        trajectories: list[Trajectory] = []
        log_files = sorted(self._logs_dir.glob(_SESSION_GLOB))

        for log_file in log_files:
            # Parse date from filename: session_YYYY-MM-DD.jsonl
            try:
                date_str = log_file.stem.removeprefix("session_")
                file_date = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=UTC)
            except ValueError:
                logger.debug("Skipping non-date session file: %s", log_file.name)
                continue

            # file_date is midnight UTC, so this half-open range selects
            # exactly the seven files of the target week.
            if not (week_start <= file_date < week_end):
                continue

            file_trajectories = self._extract_from_file(log_file)
            trajectories.extend(file_trajectories)
            logger.info(
                "Extracted %d trajectories from %s", len(file_trajectories), log_file.name
            )

        logger.info("Total trajectories exported: %d", len(trajectories))
        return trajectories

    def _extract_from_file(self, log_file: Path) -> list[Trajectory]:
        """Parse a single session JSONL file into trajectories.

        Groups entries into trajectories by finding natural conversation
        boundaries (gaps of inactivity or topic shifts in the message stream).
        """
        entries: list[dict[str, Any]] = []
        try:
            with open(log_file) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        entries.append(json.loads(line))
                    except json.JSONDecodeError:
                        # Best-effort: corrupt lines are dropped, not fatal.
                        logger.debug("Skipping malformed JSON line in %s", log_file.name)
        except OSError as exc:
            # Unreadable file yields no trajectories rather than aborting
            # the whole export.
            logger.warning("Could not read %s: %s", log_file, exc)
            return []

        if not entries:
            return []

        date_str = log_file.stem.removeprefix("session_")
        return self._segment_trajectories(entries, date_str)

    def _segment_trajectories(
        self, entries: list[dict[str, Any]], session_date: str
    ) -> list[Trajectory]:
        """Split a flat list of session entries into discrete trajectories.

        Segmentation rule: start a new trajectory when:
        - A user message follows a Timmy message (new conversation turn)
        - More than 5 minutes have elapsed between entries

        This produces training examples that are coherent task attempts.
        """
        if not entries:
            return []

        trajectories: list[Trajectory] = []
        current_entries: list[dict[str, Any]] = []
        prev_ts: datetime | None = None
        _SEGMENT_GAP_MINUTES = 5

        def _flush() -> None:
            # Closure reads current_entries from the enclosing scope at call
            # time, so the rebinds below take effect for later flushes.
            # Trajectories with no messages (tool calls only) are dropped.
            if current_entries:
                traj = _build_trajectory(current_entries, session_date)
                if traj.message_count > 0:
                    trajectories.append(traj)

        for entry in entries:
            ts_raw = entry.get("timestamp", "")
            try:
                # "Z" suffix is normalised for fromisoformat (pre-3.11 compat).
                ts = datetime.fromisoformat(ts_raw.replace("Z", "+00:00"))
            except (ValueError, AttributeError):
                ts = None

            # Time-gap segmentation
            if ts and prev_ts and (ts - prev_ts).total_seconds() > _SEGMENT_GAP_MINUTES * 60:
                _flush()
                current_entries = []

            # New-turn segmentation: user message after assistant turn
            etype = entry.get("type")
            erole = entry.get("role")
            if etype == "message" and erole == "user" and current_entries:
                # Scan back to the most recent *message* entry (errors and
                # tool calls in between are skipped); split only when it was
                # a Timmy message.
                for prev in reversed(current_entries):
                    if prev.get("type") == "message":
                        if prev.get("role") == "timmy":
                            _flush()
                            current_entries = []
                        break

            current_entries.append(entry)
            # prev_ts deliberately keeps its last parseable value when an
            # entry has no usable timestamp.
            if ts:
                prev_ts = ts

        _flush()
        return trajectories


def _build_trajectory(entries: list[dict[str, Any]], session_date: str) -> Trajectory:
    """Build a Trajectory from a flat list of entries."""
    messages = [e for e in entries if e.get("type") == "message"]
    tool_calls = [e for e in entries if e.get("type") == "tool_call"]
    errors = [e for e in entries if e.get("type") == "error"]
    decisions = [e for e in entries if e.get("type") == "decision"]

    # Timestamps compare lexicographically — assumes uniform ISO-8601 strings
    # throughout a session (TODO confirm the logger guarantees this).
    timestamps = [e.get("timestamp", "") for e in entries if e.get("timestamp")]
    started_at = min(timestamps) if timestamps else ""
    ended_at = max(timestamps) if timestamps else ""

    return Trajectory(
        session_date=session_date,
        started_at=started_at,
        ended_at=ended_at,
        messages=messages,
        tool_calls=tool_calls,
        errors=errors,
        decisions=decisions,
    )