This commit was merged in pull request #1118.
550
tests/unit/test_retrain_loop.py
Normal file
@@ -0,0 +1,550 @@
"""Unit tests for the AutoLoRA continuous improvement loop.

Covers trajectory extraction, quality filtering, dataset management,
training-log persistence, and the retrain orchestrator.

Refs: #1105
"""

from __future__ import annotations

import json
from datetime import UTC, datetime, timedelta
from pathlib import Path

from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality
from timmy_automations.retrain.retrain import RetrainOrchestrator
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
from timmy_automations.retrain.trajectory_exporter import Trajectory, TrajectoryExporter

_REPO_ROOT = Path(__file__).resolve().parent.parent.parent


# ── Fixtures ─────────────────────────────────────────────────────────────────


def _ts(offset_minutes: int = 0) -> str:
    """Return an ISO timestamp offset from now."""
    return (datetime.now(tz=UTC) + timedelta(minutes=offset_minutes)).isoformat()


def _make_session_log(entries: list[dict], date_str: str, tmp_path: Path) -> Path:
    """Write session JSONL entries to a temp log file."""
    log_dir = tmp_path / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"session_{date_str}.jsonl"
    with open(log_file, "w") as f:
        for entry in entries:
            f.write(json.dumps(entry) + "\n")
    return log_file


def _user_msg(content: str, offset: int = 0) -> dict:
    return {"type": "message", "role": "user", "content": content, "timestamp": _ts(offset)}


def _timmy_msg(content: str, confidence: float | None = None, offset: int = 0) -> dict:
    entry = {"type": "message", "role": "timmy", "content": content, "timestamp": _ts(offset)}
    if confidence is not None:
        entry["confidence"] = confidence
    return entry


def _tool_call(tool: str = "bash", result: str = "ok", offset: int = 0) -> dict:
    return {
        "type": "tool_call",
        "tool": tool,
        "args": {},
        "result": result,
        "timestamp": _ts(offset),
    }


def _error_entry(msg: str = "Something failed", offset: int = 0) -> dict:
    return {"type": "error", "error": msg, "timestamp": _ts(offset)}


def _decision_entry(decision: str = "Use approach A", offset: int = 0) -> dict:
    return {"type": "decision", "decision": decision, "timestamp": _ts(offset)}


# ── Trajectory dataclass tests ────────────────────────────────────────────────


class TestTrajectory:
    def test_message_count(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("hi"), _timmy_msg("hello")],
        )
        assert t.message_count == 2

    def test_tool_call_count(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            tool_calls=[_tool_call(), _tool_call()],
        )
        assert t.tool_call_count == 2

    def test_has_successful_tool_call_when_no_errors(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            tool_calls=[_tool_call()],
            errors=[],
        )
        assert t.has_successful_tool_call is True

    def test_has_successful_tool_call_false_when_errors(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            tool_calls=[_tool_call()],
            errors=[_error_entry()],
        )
        assert t.has_successful_tool_call is False

    def test_is_multi_step(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        assert t.is_multi_step is True

    def test_is_not_multi_step_single_message(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("hello")],
            tool_calls=[],
        )
        assert t.is_multi_step is False

    def test_to_chat_format_ordering(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("question", offset=0), _timmy_msg("answer", offset=2)],
            tool_calls=[_tool_call(offset=1)],
        )
        chat = t.to_chat_format()
        roles = [m["role"] for m in chat]
        assert "user" in roles
        assert "assistant" in roles

    def test_to_chat_format_empty_content_skipped(self):
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg(""), _timmy_msg("response")],
        )
        chat = t.to_chat_format()
        # Empty user message should be skipped
        assert all(m["content"] for m in chat)


# ── TrajectoryExporter tests ──────────────────────────────────────────────────


class TestTrajectoryExporter:
    def test_export_empty_logs_dir(self, tmp_path):
        (tmp_path / "logs").mkdir()
        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        result = exporter.export_week(weeks_ago=0)
        assert result == []

    def test_export_reads_session_files(self, tmp_path):
        # Write a session file for this week
        today = datetime.now(tz=UTC)
        date_str = today.strftime("%Y-%m-%d")
        entries = [
            _user_msg("tell me about Python"),
            _timmy_msg("Python is great"),
        ]
        _make_session_log(entries, date_str, tmp_path)

        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        result = exporter.export_week(weeks_ago=0)
        assert len(result) >= 1

    def test_export_skips_old_sessions(self, tmp_path):
        # Write a session file for 3 weeks ago
        three_weeks_ago = datetime.now(tz=UTC) - timedelta(weeks=3)
        date_str = three_weeks_ago.strftime("%Y-%m-%d")
        entries = [_user_msg("old message"), _timmy_msg("old response")]
        _make_session_log(entries, date_str, tmp_path)

        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        # Request current week — should not include 3-week-old data
        result = exporter.export_week(weeks_ago=0)
        assert result == []

    def test_export_segments_by_gap(self, tmp_path):
        today = datetime.now(tz=UTC)
        date_str = today.strftime("%Y-%m-%d")

        # Two conversations separated by a 12-minute gap
        t1 = (today - timedelta(minutes=15)).isoformat()
        t2 = (today - timedelta(minutes=14)).isoformat()
        t3 = (today - timedelta(minutes=2)).isoformat()
        t4 = (today - timedelta(minutes=1)).isoformat()

        entries = [
            {"type": "message", "role": "user", "content": "first q", "timestamp": t1},
            {"type": "message", "role": "timmy", "content": "first a", "timestamp": t2},
            {"type": "message", "role": "user", "content": "second q", "timestamp": t3},
            {"type": "message", "role": "timmy", "content": "second a", "timestamp": t4},
        ]
        _make_session_log(entries, date_str, tmp_path)

        exporter = TrajectoryExporter(logs_dir=tmp_path / "logs", repo_root=tmp_path)
        result = exporter.export_week(weeks_ago=0)
        # Should have at least 1 trajectory (may be 1 or 2 depending on segmentation)
        assert len(result) >= 1

    def test_handles_malformed_log_file(self, tmp_path):
        log_dir = tmp_path / "logs"
        log_dir.mkdir()
        today = datetime.now(tz=UTC).strftime("%Y-%m-%d")
        (log_dir / f"session_{today}.jsonl").write_text("not json\n{}\n")

        exporter = TrajectoryExporter(logs_dir=log_dir, repo_root=tmp_path)
        # Should not raise, just return empty or partial results
        result = exporter.export_week(weeks_ago=0)
        assert isinstance(result, list)


# ── QualityFilter tests ───────────────────────────────────────────────────────


class TestQualityFilter:
    def _make_high_quality(self) -> Trajectory:
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("do task"), _timmy_msg("done", confidence=0.9)],
            tool_calls=[_tool_call(), _tool_call()],
            errors=[],
            decisions=[_decision_entry()],
        )

    def _make_medium_quality(self) -> Trajectory:
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("hello"), _timmy_msg("hi")],
            tool_calls=[],
            errors=[],
        )

    def _make_low_quality(self) -> Trajectory:
        return Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("oops")],  # No user message
            errors=[_error_entry()],
        )

    def test_high_quality_classification(self):
        qf = QualityFilter()
        result = qf.assess(self._make_high_quality())
        assert result.quality == TrajectoryQuality.HIGH
        assert result.score >= 4.0
        assert result.is_trainable

    def test_medium_quality_classification(self):
        qf = QualityFilter()
        result = qf.assess(self._make_medium_quality())
        assert result.quality == TrajectoryQuality.MEDIUM
        assert result.is_trainable

    def test_low_quality_no_user_message(self):
        qf = QualityFilter()
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_timmy_msg("random")],
        )
        result = qf.assess(t)
        assert result.quality == TrajectoryQuality.LOW
        assert not result.is_trainable

    def test_error_penalizes_score(self):
        qf = QualityFilter()
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("go"), _timmy_msg("fail")],
            tool_calls=[_tool_call()],
            errors=[_error_entry(), _error_entry()],
        )
        result = qf.assess(t)
        assert result.score < qf.assess(self._make_high_quality()).score

    def test_low_confidence_penalizes_score(self):
        qf = QualityFilter()
        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(),
            ended_at=_ts(),
            messages=[_user_msg("q"), _timmy_msg("a", confidence=0.2)],
        )
        result = qf.assess(t)
        assert result.score < 1.0

    def test_filter_returns_stats(self):
        qf = QualityFilter()
        trajectories = [
            self._make_high_quality(),
            self._make_medium_quality(),
            self._make_low_quality(),
        ]
        trainable, stats = qf.filter(trajectories)
        assert stats["total"] == 3
        assert stats["accepted"] == len(trainable)
        assert stats["high"] + stats["medium"] + stats["low"] == 3

    def test_filter_empty_list(self):
        qf = QualityFilter()
        trainable, stats = qf.filter([])
        assert trainable == []
        assert stats["total"] == 0
        assert stats["accepted"] == 0


# ── TrainingDataset tests ─────────────────────────────────────────────────────


class TestTrainingDataset:
    def _make_result(self, quality=TrajectoryQuality.HIGH, score=5.0) -> object:
        from timmy_automations.retrain.quality_filter import QualityResult

        t = Trajectory(
            session_date="2026-03-17",
            started_at=_ts(-5),
            ended_at=_ts(),
            messages=[_user_msg("do it"), _timmy_msg("done")],
            tool_calls=[_tool_call()],
        )
        return QualityResult(trajectory=t, quality=quality, score=score, reasons=[])

    def test_count_empty_dataset(self, tmp_path):
        ds = TrainingDataset(
            dataset_path=".loop/retrain/training_data.jsonl",
            repo_root=tmp_path,
        )
        assert ds.count() == 0

    def test_append_adds_examples(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        result = ds.append([self._make_result()], "2026-W12")
        assert result.new_examples == 1
        assert result.total_examples == 1
        assert ds.count() == 1

    def test_append_idempotent(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        r = self._make_result()
        ds.append([r], "2026-W12")
        result2 = ds.append([r], "2026-W12")
        # Same trajectory shouldn't be added twice
        assert result2.new_examples == 0
        assert ds.count() == 1

    def test_append_different_weeks(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        r1 = self._make_result()
        ds.append([r1], "2026-W11")
        ds.append([r1], "2026-W12")
        # Different week tags = different records
        assert ds.count() == 2

    def test_dataset_file_is_valid_jsonl(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        ds.append([self._make_result()], "2026-W12")
        with open(ds.dataset_path) as f:
            lines = [line.strip() for line in f if line.strip()]
        assert len(lines) == 1
        record = json.loads(lines[0])
        assert "messages" in record
        assert "week" in record
        assert "quality" in record

    def test_index_updated_after_append(self, tmp_path):
        ds = TrainingDataset(repo_root=tmp_path)
        ds.append([self._make_result()], "2026-W12")
        index_path = tmp_path / ".loop" / "retrain" / "dataset_index.json"
        assert index_path.exists()
        index = json.loads(index_path.read_text())
        assert index["total_examples"] == 1
        assert "2026-W12" in index["weeks"]


# ── TrainingLog tests ─────────────────────────────────────────────────────────


class TestTrainingLog:
    def _make_metrics(self, iteration: int = 1) -> CycleMetrics:
        return CycleMetrics(
            iteration=iteration,
            week="2026-W12",
            ran_at=datetime.now(tz=UTC).isoformat(),
            trajectories_total=10,
            trajectories_high=5,
            trajectories_medium=3,
            trajectories_low=2,
            trajectories_accepted=8,
            examples_added=5,
            dataset_total=5,
            train_status="completed",
            train_loss=1.2345,
            train_duration_seconds=120.5,
            adapter_path=".loop/retrain/adapters/iter_0001/adapters.npz",
            model_name="hermes4-14b-ft-0001",
            notes="First fine-tune cycle complete",
        )

    def test_next_iteration_starts_at_1(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        assert log.next_iteration() == 1

    def test_next_iteration_increments(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        assert log.next_iteration() == 2

    def test_record_creates_log_file(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics())
        assert log.log_path.exists()

    def test_load_all_returns_records(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        log.record(self._make_metrics(iteration=2))
        entries = log.load_all()
        assert len(entries) == 2
        assert entries[0]["iteration"] == 1

    def test_latest_returns_last_entry(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics(iteration=1))
        log.record(self._make_metrics(iteration=2))
        latest = log.latest()
        assert latest is not None
        assert latest["iteration"] == 2

    def test_latest_returns_none_when_empty(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        assert log.latest() is None

    def test_summary_markdown_written(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        log.record(self._make_metrics())
        summary_path = tmp_path / ".loop" / "retrain" / "training_log.md"
        assert summary_path.exists()
        content = summary_path.read_text()
        assert "AutoLoRA Training Log" in content
        assert "2026-W12" in content
        assert "completed" in content

    def test_skill_accuracy_in_summary(self, tmp_path):
        log = TrainingLog(repo_root=tmp_path)
        m = self._make_metrics()
        m.skill_accuracy = {"tool_calling": 0.85, "reasoning": 0.72}
        log.record(m)
        content = (tmp_path / ".loop" / "retrain" / "training_log.md").read_text()
        assert "tool_calling" in content
        assert "reasoning" in content


# ── RetrainOrchestrator integration tests ─────────────────────────────────────


class TestRetrainOrchestrator:
    def test_run_dry_run_no_data(self, tmp_path):
        """Dry run with no session logs should complete without errors."""
        (tmp_path / "logs").mkdir(parents=True)
        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        result = orc.run(weeks_ago=0)
        assert result.train_status == "skipped"
        assert result.examples_added == 0
        assert result.iteration == 1

    def test_run_creates_log_entry(self, tmp_path):
        (tmp_path / "logs").mkdir(parents=True)
        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        orc.run(weeks_ago=0)
        log = TrainingLog(repo_root=tmp_path)
        entries = log.load_all()
        assert len(entries) == 1

    def test_run_with_session_data(self, tmp_path):
        """Run with actual session data — should export, filter, and log."""
        today = datetime.now(tz=UTC)
        date_str = today.strftime("%Y-%m-%d")
        entries = [
            _user_msg("deploy the service", offset=-10),
            _tool_call("bash", "deployed successfully", offset=-9),
            _tool_call("bash", "health check ok", offset=-8),
            _timmy_msg("Service deployed and healthy", confidence=0.92, offset=-7),
            _user_msg("run the tests", offset=-6),
            _tool_call("bash", "All tests passed", offset=-5),
            _timmy_msg("All 42 tests passed", confidence=0.95, offset=-4),
        ]
        _make_session_log(entries, date_str, tmp_path)

        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        result = orc.run(weeks_ago=0)

        assert result.trajectories_exported >= 1
        assert result.iteration == 1
        # In dry_run mode, fine-tune is skipped but trajectories should be processed
        assert result.train_status == "skipped"

    def test_iteration_increments_on_second_run(self, tmp_path):
        (tmp_path / "logs").mkdir(parents=True)
        orc = RetrainOrchestrator(repo_root=tmp_path, dry_run=True)
        r1 = orc.run(weeks_ago=0)
        r2 = orc.run(weeks_ago=0)
        assert r2.iteration == r1.iteration + 1

    def test_automations_json_has_retrain_entry(self):
        """Verify the retrain automation is registered in automations.json."""
        config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json"
        assert config_path.exists()
        manifest = json.loads(config_path.read_text())
        ids = [a["id"] for a in manifest.get("automations", [])]
        assert "retrain" in ids

    def test_retrain_automation_config(self):
        """Verify retrain automation has correct schedule and config."""
        config_path = _REPO_ROOT / "timmy_automations" / "config" / "automations.json"
        manifest = json.loads(config_path.read_text())
        retrain = next(a for a in manifest["automations"] if a["id"] == "retrain")
        assert retrain["schedule"] == "weekly_sunday"
        assert retrain["trigger"] == "scheduled"
        assert retrain["config"]["base_model"] == "hermes4-14b"
        assert retrain["config"]["weeks_ago"] == 1
timmy_automations/config/automations.json
@@ -4,7 +4,7 @@
   "_health_snapshot": {
     "note": "Quick health check before coding — CI, P0/P1 issues, flakiness"
   },
-  "last_updated": "2026-03-21",
+  "last_updated": "2026-03-23",
   "automations": [
     {
       "id": "cycle_retro",
@@ -268,6 +268,36 @@
         "ci_timeout_seconds": 5
       },
       "outputs": []
     },
+    {
+      "id": "retrain",
+      "name": "AutoLoRA Continuous Improvement Loop",
+      "description": "Weekly sovereignty loop — exports trajectories, filters quality, appends to training dataset, triggers LoRA fine-tune, loads new adapter, and logs iteration metrics",
+      "script": "timmy_automations/retrain/retrain.py",
+      "category": "autolora",
+      "enabled": true,
+      "trigger": "scheduled",
+      "schedule": "weekly_sunday",
+      "executable": "python3",
+      "epic": "#1091",
+      "pipeline": "AutoLoRA Sovereignty Loop (Step 6 of 7)",
+      "config": {
+        "weeks_ago": 1,
+        "base_model": "hermes4-14b",
+        "dry_run": false,
+        "logs_dir": "logs",
+        "dataset_path": ".loop/retrain/training_data.jsonl",
+        "adapter_dir": ".loop/retrain/adapters",
+        "training_log_path": ".loop/retrain/training_log.jsonl",
+        "training_summary_path": ".loop/retrain/training_log.md"
+      },
+      "outputs": [
+        ".loop/retrain/training_data.jsonl",
+        ".loop/retrain/dataset_index.json",
+        ".loop/retrain/training_log.jsonl",
+        ".loop/retrain/training_log.md",
+        ".loop/retrain/adapters/"
+      ]
+    }
   ]
 }
26
timmy_automations/retrain/__init__.py
Normal file
@@ -0,0 +1,26 @@
"""AutoLoRA continuous improvement loop — sovereignty engine for Timmy.

Implements the weekly retrain cycle:
    Work → Record trajectories → Export weekly → Filter quality
    → LoRA fine-tune → Load adapter → Model improves → Repeat

Epic: #1091 — Project Bannerlord
Pipeline: AutoLoRA Sovereignty Loop (Step 6 of 7)
Refs: #1105
"""

from timmy_automations.retrain.quality_filter import QualityFilter, TrajectoryQuality
from timmy_automations.retrain.retrain import RetrainOrchestrator, RetrainResult
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import TrainingLog
from timmy_automations.retrain.trajectory_exporter import TrajectoryExporter

__all__ = [
    "QualityFilter",
    "RetrainOrchestrator",
    "RetrainResult",
    "TrainingDataset",
    "TrainingLog",
    "TrajectoryExporter",
    "TrajectoryQuality",
]
262
timmy_automations/retrain/lora_trainer.py
Normal file
@@ -0,0 +1,262 @@
"""LoRA trainer — triggers the fine-tune job and loads the resulting adapter.

Training runs through mlx-lm (Apple Silicon) via the `mlx_lm.lora` CLI;
the finished adapter is then packaged into a new Ollama model with
`ollama create`.

Graceful degradation: if mlx-lm is not available, logs a warning and
returns a skipped result — the rest of the loop continues.

Refs: #1105
"""

from __future__ import annotations

import logging
import os
import shutil
import subprocess
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path

logger = logging.getLogger(__name__)

_DEFAULT_BASE_MODEL = "hermes4-14b"
_DEFAULT_ADAPTER_DIR = ".loop/retrain/adapters"
_MLX_LM_BIN = "mlx_lm.lora"
_OLLAMA_BIN = "ollama"


@dataclass
class TrainResult:
    """Result of a LoRA fine-tune run."""

    status: str  # "completed" | "skipped" | "failed"
    adapter_path: str | None
    model_name: str | None
    iteration: int
    duration_seconds: float
    message: str
    train_loss: float | None = None


class LoRATrainer:
    """Orchestrates LoRA fine-tuning and adapter loading.

    Workflow:
        1. Run mlx_lm.lora fine-tune on the training dataset
        2. Save the resulting adapter to .loop/retrain/adapters/iter_<iteration>/
        3. Create (or update) an Ollama model that uses the new adapter
    """

    def __init__(
        self,
        base_model: str = _DEFAULT_BASE_MODEL,
        adapter_dir: str | Path | None = None,
        repo_root: str | Path | None = None,
        dry_run: bool = False,
    ):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        self._base_model = base_model
        self._adapter_dir = self._repo_root / (adapter_dir or _DEFAULT_ADAPTER_DIR)
        self._adapter_dir.mkdir(parents=True, exist_ok=True)
        self._dry_run = dry_run

    def train(self, dataset_path: Path, iteration: int) -> TrainResult:
        """Run LoRA fine-tuning on the dataset.

        Args:
            dataset_path: Path to the JSONL training dataset.
            iteration: Current fine-tune iteration number (used for naming).

        Returns:
            TrainResult with status, adapter path, and metrics.
        """
        started = datetime.now(tz=UTC)

        if not dataset_path.exists() or dataset_path.stat().st_size == 0:
            return TrainResult(
                status="skipped",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=0.0,
                message="Training dataset is empty — skipping fine-tune",
            )

        if self._dry_run:
            logger.info("[dry-run] Would fine-tune %s on %s", self._base_model, dataset_path)
            adapter_path = self._adapter_dir / f"iter_{iteration:04d}" / "adapters.npz"
            return TrainResult(
                status="skipped",
                adapter_path=str(adapter_path),
                model_name=f"{self._base_model}-ft-{iteration:04d}",
                iteration=iteration,
                duration_seconds=0.0,
                message="dry-run mode — no training performed",
            )

        # Determine which backend is available
        if shutil.which(_MLX_LM_BIN):
            return self._train_mlx(dataset_path, iteration, started)

        logger.warning(
            "%s not found — skipping LoRA fine-tune (install mlx-lm to enable)",
            _MLX_LM_BIN,
        )
        return TrainResult(
            status="skipped",
            adapter_path=None,
            model_name=None,
            iteration=iteration,
            duration_seconds=0.0,
            message=(
                f"{_MLX_LM_BIN} not available. "
                "Install mlx-lm on Apple Silicon to enable LoRA fine-tuning."
            ),
        )

    def _train_mlx(
        self, dataset_path: Path, iteration: int, started: datetime
    ) -> TrainResult:
        """Run mlx_lm.lora fine-tune."""
        adapter_out = self._adapter_dir / f"iter_{iteration:04d}"
        adapter_out.mkdir(parents=True, exist_ok=True)

        cmd = [
            _MLX_LM_BIN,
            "--model", self._base_model,
            "--data", str(dataset_path),
            "--adapter-path", str(adapter_out),
            "--train",
            "--iters", "100",
            "--batch-size", "1",
            "--learning-rate", "1e-5",
        ]

        logger.info("Starting mlx-lm LoRA fine-tune: iteration %d", iteration)
        logger.info("Command: %s", " ".join(cmd))

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=3600,  # 1 hour max
                env={**os.environ, "PYTHONUNBUFFERED": "1"},
            )
        except subprocess.TimeoutExpired:
            duration = (datetime.now(tz=UTC) - started).total_seconds()
            return TrainResult(
                status="failed",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=duration,
                message="Fine-tune timed out after 1 hour",
            )
        except Exception as exc:
            duration = (datetime.now(tz=UTC) - started).total_seconds()
            return TrainResult(
                status="failed",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=duration,
                message=f"Fine-tune subprocess error: {exc}",
            )

        duration = (datetime.now(tz=UTC) - started).total_seconds()

        if result.returncode != 0:
            logger.error("mlx-lm fine-tune failed: %s", result.stderr[:500])
            return TrainResult(
                status="failed",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=duration,
                message=f"mlx_lm.lora exited {result.returncode}: {result.stderr[:300]}",
            )

        # Parse final train loss from stdout if available
        train_loss = _parse_train_loss(result.stdout)

        adapter_file = adapter_out / "adapters.npz"
        model_name = f"{self._base_model}-ft-{iteration:04d}"

        # Attempt to register with Ollama
        ollama_ok = self._register_ollama_adapter(adapter_out, model_name)
        if not ollama_ok:
            logger.warning("Ollama adapter registration failed — adapter saved locally")

        logger.info(
            "Fine-tune complete: iteration=%d loss=%.4f duration=%.1fs adapter=%s",
            iteration,
            train_loss or 0.0,
            duration,
            adapter_file,
        )

        return TrainResult(
            status="completed",
            adapter_path=str(adapter_file),
            model_name=model_name,
            iteration=iteration,
            duration_seconds=duration,
            message=f"LoRA fine-tune completed successfully in {duration:.0f}s",
            train_loss=train_loss,
        )

    def _register_ollama_adapter(self, adapter_dir: Path, model_name: str) -> bool:
        """Create an Ollama model entry for the new adapter.

        Writes a minimal Modelfile and runs `ollama create`.
        """
        if not shutil.which(_OLLAMA_BIN):
            logger.debug("Ollama not found — skipping adapter registration")
            return False

        modelfile_content = (
            f"FROM {self._base_model}\n"
            f"ADAPTER {adapter_dir}\n"
        )
        modelfile_path = adapter_dir / "Modelfile"
        try:
            modelfile_path.write_text(modelfile_content)
            result = subprocess.run(
                [_OLLAMA_BIN, "create", model_name, "-f", str(modelfile_path)],
                capture_output=True,
                text=True,
                timeout=300,
            )
            if result.returncode == 0:
                logger.info("Ollama model registered: %s", model_name)
                return True
            logger.warning("ollama create failed: %s", result.stderr[:200])
            return False
        except Exception as exc:
            logger.warning("Ollama adapter registration error: %s", exc)
            return False


def _parse_train_loss(stdout: str) -> float | None:
    """Extract the final training loss from mlx-lm stdout."""
    loss: float | None = None
    for line in stdout.splitlines():
        line_lower = line.lower()
        if "train loss" in line_lower or "loss:" in line_lower:
            parts = line.split()
            for i, part in enumerate(parts):
                if "loss" in part.lower() and i + 1 < len(parts):
                    try:
                        loss = float(parts[i + 1].strip(",:"))
                    except ValueError:
                        pass
    return loss
172
timmy_automations/retrain/quality_filter.py
Normal file
@@ -0,0 +1,172 @@
"""Quality filter — keeps only high-value trajectories for LoRA training.

Signals of a high-quality training example:
    1. Tool calls succeeded (tool calls present, no error entries)
    2. Multi-step task completed (≥2 messages + ≥1 tool call)
    3. No low-confidence responses (confidence < 0.5 on any Timmy message)
    4. Minimum meaningful exchange (≥1 user message + ≥1 Timmy message)

Refs: #1105
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import StrEnum

from timmy_automations.retrain.trajectory_exporter import Trajectory

logger = logging.getLogger(__name__)

_MIN_CONFIDENCE = 0.5


class TrajectoryQuality(StrEnum):
    """Quality classification for a trajectory."""

    HIGH = "high"      # Multi-step + tool success — ideal training data
    MEDIUM = "medium"  # Single exchange, no errors — acceptable
    LOW = "low"        # Error-prone or trivial — skip


@dataclass
class QualityResult:
    """Result of quality assessment for a single trajectory."""

    trajectory: Trajectory
    quality: TrajectoryQuality
    score: float
    reasons: list[str]

    @property
    def is_trainable(self) -> bool:
        return self.quality in (TrajectoryQuality.HIGH, TrajectoryQuality.MEDIUM)


class QualityFilter:
    """Filters trajectories to keep only those worth training on.

    Scoring:
        +1 pt: base score for any clean exchange (no errors)
        +3 pts: multi-step task (≥2 messages + ≥1 tool call)
        +2 pts: tool calls present and no errors
        +1 pt: decision recorded (deliberate choice made)
        -2 pts: any error entry
        -1 pt per low-confidence response (confidence < 0.5)

    HIGH ≥ 4, MEDIUM 1–3, LOW ≤ 0
    """

    def __init__(self, min_confidence: float = _MIN_CONFIDENCE):
        self._min_confidence = min_confidence

    def assess(self, trajectory: Trajectory) -> QualityResult:
        """Score and classify a single trajectory."""
        score = 0.0
        reasons: list[str] = []

        # Minimum viable exchange check
        user_msgs = [m for m in trajectory.messages if m.get("role") == "user"]
        timmy_msgs = [m for m in trajectory.messages if m.get("role") == "timmy"]

        if not user_msgs or not timmy_msgs:
            return QualityResult(
                trajectory=trajectory,
                quality=TrajectoryQuality.LOW,
                score=0.0,
                reasons=["Missing user or assistant messages — not a valid exchange"],
            )

        # Multi-step bonus
        if trajectory.is_multi_step:
            score += 3.0
            reasons.append(
                f"Multi-step task: {trajectory.message_count} messages, "
                f"{trajectory.tool_call_count} tool calls"
            )

        # Base score for any clean exchange (user + timmy, no tool call required)
        if trajectory.error_count == 0:
            score += 1.0
            reasons.append("Clean exchange (no errors)")

        # Tool call quality
        if trajectory.tool_call_count > 0:
            if trajectory.error_count == 0:
                score += 2.0
                reasons.append(
                    f"All {trajectory.tool_call_count} tool call(s) succeeded"
                )
            else:
                score -= 2.0
                reasons.append(
                    f"{trajectory.error_count} error(s) during "
                    f"{trajectory.tool_call_count} tool call(s)"
                )
        elif trajectory.error_count > 0:
            score -= 2.0
            reasons.append(f"{trajectory.error_count} error(s) with no tool calls")

        # Decision bonus
        if trajectory.decisions:
            score += 1.0
            reasons.append(f"Decisions recorded: {len(trajectory.decisions)}")

        # Confidence penalty
        low_conf = [
            m
            for m in timmy_msgs
            if m.get("confidence") is not None
            and m["confidence"] < self._min_confidence
        ]
        if low_conf:
            score -= len(low_conf)
            reasons.append(
                f"{len(low_conf)} low-confidence response(s) (threshold={self._min_confidence})"
            )

        # Classify
        if score >= 4.0:
            quality = TrajectoryQuality.HIGH
        elif score >= 1.0:
            quality = TrajectoryQuality.MEDIUM
        else:
            quality = TrajectoryQuality.LOW

        return QualityResult(
            trajectory=trajectory,
            quality=quality,
            score=score,
            reasons=reasons,
        )

    def filter(
        self, trajectories: list[Trajectory]
    ) -> tuple[list[QualityResult], dict[str, int]]:
        """Assess all trajectories and return trainable ones with stats.

        Returns:
            (trainable_results, stats_dict) where stats_dict has keys
            'total', 'high', 'medium', 'low', 'accepted'.
        """
        results = [self.assess(t) for t in trajectories]
        trainable = [r for r in results if r.is_trainable]

        stats = {
            "total": len(results),
            "high": sum(1 for r in results if r.quality == TrajectoryQuality.HIGH),
            "medium": sum(1 for r in results if r.quality == TrajectoryQuality.MEDIUM),
            "low": sum(1 for r in results if r.quality == TrajectoryQuality.LOW),
            "accepted": len(trainable),
        }

        logger.info(
            "Quality filter: %d/%d accepted (high=%d medium=%d low=%d)",
            stats["accepted"],
            stats["total"],
            stats["high"],
            stats["medium"],
            stats["low"],
        )

        return trainable, stats
292
timmy_automations/retrain/retrain.py
Normal file
@@ -0,0 +1,292 @@
#!/usr/bin/env python3
"""AutoLoRA continuous improvement loop — the sovereignty retrain script.

Implements the weekly retrain cycle end-to-end:
    Work → Record trajectories → Export weekly → Filter quality
    → LoRA fine-tune → Load adapter → Model improves → Repeat forever

Run:
    python3 timmy_automations/retrain/retrain.py
    python3 timmy_automations/retrain/retrain.py --dry-run
    python3 timmy_automations/retrain/retrain.py --weeks-ago 1

Epic: #1091 — Project Bannerlord
Pipeline: AutoLoRA Sovereignty Loop (Step 6 of 7)
Refs: #1105
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from pathlib import Path

# Allow running directly from repo root
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

from timmy_automations.retrain.lora_trainer import LoRATrainer
from timmy_automations.retrain.quality_filter import QualityFilter
from timmy_automations.retrain.training_dataset import TrainingDataset
from timmy_automations.retrain.training_log import CycleMetrics, TrainingLog
from timmy_automations.retrain.trajectory_exporter import TrajectoryExporter

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S",
)
logger = logging.getLogger("retrain")


@dataclass
class RetrainResult:
    """Result of a complete retrain cycle."""

    iteration: int
    week: str
    trajectories_exported: int
    trajectories_accepted: int
    examples_added: int
    dataset_total: int
    train_status: str
    adapter_path: str | None
    model_name: str | None
    train_loss: float | None
    duration_seconds: float
    notes: str


class RetrainOrchestrator:
    """Orchestrates the complete AutoLoRA continuous improvement loop.

    Step 1: Export this week's conversation trajectories from session logs
    Step 2: Filter for high-quality exchanges
    Step 3: Append to the training dataset
    Step 4: Trigger LoRA fine-tune
    Step 5: Load the new adapter (via Ollama)
    Step 6: Log iteration, loss, skill accuracy
    """

    def __init__(
        self,
        base_model: str = "hermes4-14b",
        repo_root: str | Path | None = None,
        dry_run: bool = False,
    ):
        if repo_root is None:
            repo_root = _REPO_ROOT
        self._repo_root = Path(repo_root)
        self._dry_run = dry_run

        self.exporter = TrajectoryExporter(repo_root=self._repo_root)
        self.quality_filter = QualityFilter()
        self.dataset = TrainingDataset(repo_root=self._repo_root)
        self.trainer = LoRATrainer(
            base_model=base_model,
            repo_root=self._repo_root,
            dry_run=dry_run,
        )
        self.log = TrainingLog(repo_root=self._repo_root)

    def run(self, weeks_ago: int = 1) -> RetrainResult:
        """Execute one complete retrain cycle.

        Args:
            weeks_ago: Which week to process. 0 = current week (partial),
                1 = last week (default, Sunday night run), etc.

        Returns:
            RetrainResult with full cycle summary.
        """
        started = datetime.now(tz=UTC)
        iteration = self.log.next_iteration()

        # Determine ISO week tag (isocalendar handles year boundaries correctly)
        target_date = started - timedelta(weeks=weeks_ago)
        iso = target_date.isocalendar()
        week_tag = f"{iso.year}-W{iso.week:02d}"

        logger.info(
            "=== AutoLoRA Retrain Cycle %d | Week: %s | dry_run=%s ===",
            iteration,
            week_tag,
            self._dry_run,
        )

        # Step 1: Export trajectories
        logger.info("Step 1: Exporting trajectories for %s...", week_tag)
        trajectories = self.exporter.export_week(weeks_ago=weeks_ago)
        logger.info("Exported %d raw trajectories", len(trajectories))

        # Step 2: Quality filter (filter() logs accepted/high/medium/low counts)
        logger.info("Step 2: Applying quality filter...")
        trainable, filter_stats = self.quality_filter.filter(trajectories)

        # Step 3: Append to dataset
        logger.info("Step 3: Appending to training dataset...")
        append_result = self.dataset.append(trainable, week_tag)
        logger.info(
            "Dataset: +%d new examples (%d total)",
            append_result.new_examples,
            append_result.total_examples,
        )

        # Step 4: LoRA fine-tune
        logger.info("Step 4: Triggering LoRA fine-tune (iteration=%d)...", iteration)
        train_result = self.trainer.train(
            dataset_path=self.dataset.dataset_path,
            iteration=iteration,
        )
        logger.info(
            "Train result: status=%s loss=%s duration=%.1fs",
            train_result.status,
            train_result.train_loss,
            train_result.duration_seconds,
        )

        # Steps 5 & 6: adapter load happens inside the trainer — record the cycle
        duration = (datetime.now(tz=UTC) - started).total_seconds()
        metrics = CycleMetrics(
            iteration=iteration,
            week=week_tag,
            ran_at=started.isoformat(),
            trajectories_total=filter_stats["total"],
            trajectories_high=filter_stats["high"],
            trajectories_medium=filter_stats["medium"],
            trajectories_low=filter_stats["low"],
            trajectories_accepted=filter_stats["accepted"],
            examples_added=append_result.new_examples,
            dataset_total=append_result.total_examples,
            train_status=train_result.status,
            train_loss=train_result.train_loss,
            train_duration_seconds=train_result.duration_seconds,
            adapter_path=train_result.adapter_path,
            model_name=train_result.model_name,
            notes=train_result.message,
        )
        self.log.record(metrics)

        result = RetrainResult(
            iteration=iteration,
            week=week_tag,
            trajectories_exported=len(trajectories),
            trajectories_accepted=filter_stats["accepted"],
            examples_added=append_result.new_examples,
            dataset_total=append_result.total_examples,
            train_status=train_result.status,
            adapter_path=train_result.adapter_path,
            model_name=train_result.model_name,
            train_loss=train_result.train_loss,
            duration_seconds=duration,
            notes=train_result.message,
        )

        logger.info(
            "=== Cycle %d complete: status=%s examples_added=%d duration=%.1fs ===",
            iteration,
            train_result.status,
            append_result.new_examples,
            duration,
        )

        return result


def _print_result(result: RetrainResult, as_json: bool = False) -> None:
    """Print cycle result to stdout."""
    if as_json:
        print(
            json.dumps(
                {
                    "iteration": result.iteration,
                    "week": result.week,
                    "trajectories_exported": result.trajectories_exported,
                    "trajectories_accepted": result.trajectories_accepted,
                    "examples_added": result.examples_added,
                    "dataset_total": result.dataset_total,
                    "train_status": result.train_status,
                    "adapter_path": result.adapter_path,
                    "model_name": result.model_name,
                    "train_loss": result.train_loss,
                    "duration_seconds": result.duration_seconds,
                    "notes": result.notes,
                },
                indent=2,
            )
        )
        return

    print(f"\n{'='*60}")
    print(f" AutoLoRA Retrain — Cycle {result.iteration}")
    print(f" Week: {result.week}")
    print(f"{'='*60}")
    print(f" Trajectories: {result.trajectories_exported} exported, {result.trajectories_accepted} accepted")
    print(f" Dataset: +{result.examples_added} examples ({result.dataset_total} total)")
    print(f" Fine-tune: {result.train_status}")
    if result.train_loss is not None:
        print(f" Train loss: {result.train_loss:.4f}")
    if result.model_name:
        print(f" New model: {result.model_name}")
    if result.adapter_path:
        print(f" Adapter: {result.adapter_path}")
    print(f" Duration: {result.duration_seconds:.1f}s")
    print(f" Notes: {result.notes}")
    print(f"{'='*60}\n")


def main() -> int:
    parser = argparse.ArgumentParser(
        description="AutoLoRA continuous improvement loop — sovereignty engine for Timmy"
    )
    parser.add_argument(
        "--weeks-ago",
        type=int,
        default=1,
        help="Which week to process: 0=current (partial), 1=last week (default)",
    )
    parser.add_argument(
        "--base-model",
        default="hermes4-14b",
        help="Ollama base model name (default: hermes4-14b)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Export and filter trajectories but skip actual fine-tuning",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        dest="as_json",
        help="Output result as JSON",
    )
    args = parser.parse_args()

    orchestrator = RetrainOrchestrator(
        base_model=args.base_model,
        dry_run=args.dry_run,
    )
    result = orchestrator.run(weeks_ago=args.weeks_ago)
    _print_result(result, as_json=args.as_json)

    # Exit 0 even on skipped/failed training — the loop must continue
    return 0


if __name__ == "__main__":
    sys.exit(main())
180
timmy_automations/retrain/training_dataset.py
Normal file
@@ -0,0 +1,180 @@
"""Training dataset manager — appends filtered trajectories to a JSONL training file.

Maintains a growing dataset of high-quality conversation examples in the
chat format expected by mlx-lm / HuggingFace fine-tuning pipelines.

Output format (one JSON object per line):
    {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

Refs: #1105
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path

from timmy_automations.retrain.quality_filter import QualityResult

logger = logging.getLogger(__name__)

_DEFAULT_DATASET_PATH = ".loop/retrain/training_data.jsonl"
_DEFAULT_INDEX_PATH = ".loop/retrain/dataset_index.json"


@dataclass
class AppendResult:
    """Result of appending trajectories to the training dataset."""

    new_examples: int
    total_examples: int
    dataset_path: str
    week_tag: str


class TrainingDataset:
    """Manages the LoRA training dataset file.

    Each entry is a chat-format example:
        {"messages": [...], "week": "2026-W12", "quality": "high", "added_at": "..."}
    """

    def __init__(
        self,
        dataset_path: str | Path | None = None,
        index_path: str | Path | None = None,
        repo_root: str | Path | None = None,
    ):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        self._dataset_path = self._repo_root / (dataset_path or _DEFAULT_DATASET_PATH)
        self._index_path = self._repo_root / (index_path or _DEFAULT_INDEX_PATH)

        self._dataset_path.parent.mkdir(parents=True, exist_ok=True)
        self._index_path.parent.mkdir(parents=True, exist_ok=True)

    @property
    def dataset_path(self) -> Path:
        return self._dataset_path

    def count(self) -> int:
        """Return the number of examples currently in the dataset."""
        if not self._dataset_path.exists():
            return 0
        count = 0
        with open(self._dataset_path) as f:
            for line in f:
                if line.strip():
                    count += 1
        return count

    def append(
        self, quality_results: list[QualityResult], week_tag: str
    ) -> AppendResult:
        """Append high-quality trajectories to the training dataset.

        Deduplicates by (week_tag, session_date, started_at) so re-running
        the export for the same week is idempotent.

        Args:
            quality_results: Filtered, trainable quality results.
            week_tag: ISO week string, e.g. "2026-W12".

        Returns:
            AppendResult with counts.
        """
        existing_keys = self._load_existing_keys()
        new_count = 0
        added_at = datetime.now(tz=UTC).isoformat()

        with open(self._dataset_path, "a") as f:
            for result in quality_results:
                traj = result.trajectory
                dedup_key = f"{week_tag}|{traj.session_date}|{traj.started_at}"
                if dedup_key in existing_keys:
                    logger.debug("Skipping duplicate trajectory: %s", dedup_key)
                    continue

                chat_messages = traj.to_chat_format()
                if len(chat_messages) < 2:
                    logger.debug(
                        "Skipping trajectory with %d chat messages (need ≥2)",
                        len(chat_messages),
                    )
                    continue

                record = {
                    "messages": chat_messages,
                    "week": week_tag,
                    "quality": result.quality.value,
                    "score": result.score,
                    "session_date": traj.session_date,
                    "started_at": traj.started_at,
                    "tool_calls": traj.tool_call_count,
                    "added_at": added_at,
                }
                f.write(json.dumps(record) + "\n")
                existing_keys.add(dedup_key)
                new_count += 1

        total = self.count()
        self._update_index(week_tag, new_count, total)
        logger.info("Dataset: appended %d new examples (total=%d)", new_count, total)

        return AppendResult(
            new_examples=new_count,
            total_examples=total,
            dataset_path=str(self._dataset_path),
            week_tag=week_tag,
        )

    def _load_existing_keys(self) -> set[str]:
        """Load deduplication keys from the existing dataset."""
        keys: set[str] = set()
        if not self._dataset_path.exists():
            return keys
        with open(self._dataset_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    week = record.get("week", "")
                    session_date = record.get("session_date", "")
                    started_at = record.get("started_at", "")
                    keys.add(f"{week}|{session_date}|{started_at}")
                except json.JSONDecodeError:
                    continue
        return keys

    def _update_index(self, week_tag: str, new_count: int, total: int) -> None:
        """Update the dataset index JSON with latest run metadata."""
        index: dict = {}
        if self._index_path.exists():
            try:
                index = json.loads(self._index_path.read_text())
            except (json.JSONDecodeError, OSError):
                index = {}

        index.setdefault("weeks", {})
        index["weeks"][week_tag] = {
            "examples_added": new_count,
            "updated_at": datetime.now(tz=UTC).isoformat(),
        }
        index["total_examples"] = total
        index["last_updated"] = datetime.now(tz=UTC).isoformat()

        self._index_path.write_text(json.dumps(index, indent=2))
183
timmy_automations/retrain/training_log.py
Normal file
@@ -0,0 +1,183 @@
"""Training log — records each fine-tune cycle with metrics and skill deltas.

Writes to .loop/retrain/training_log.jsonl (one entry per cycle) and
maintains a human-readable .loop/retrain/training_log.md summary.

Each log entry captures:
    - Iteration count
    - Week processed
    - Quality filter stats
    - Examples added to dataset
    - LoRA train result (loss, duration, adapter path)
    - Skill accuracy deltas (from smoke tests)

Refs: #1105
"""

from __future__ import annotations

import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

_DEFAULT_LOG_PATH = ".loop/retrain/training_log.jsonl"
_DEFAULT_SUMMARY_PATH = ".loop/retrain/training_log.md"


@dataclass
class CycleMetrics:
    """Metrics for a single retrain cycle."""

    iteration: int
    week: str
    ran_at: str

    # Quality filter
    trajectories_total: int = 0
    trajectories_high: int = 0
    trajectories_medium: int = 0
    trajectories_low: int = 0
    trajectories_accepted: int = 0

    # Dataset
    examples_added: int = 0
    dataset_total: int = 0

    # Training
    train_status: str = "skipped"
    train_loss: float | None = None
    train_duration_seconds: float = 0.0
    adapter_path: str | None = None
    model_name: str | None = None

    # Skill accuracy (optional, from smoke tests)
    skill_accuracy: dict[str, float] = field(default_factory=dict)
    skill_delta: dict[str, float] = field(default_factory=dict)

    # Human-readable summary
    notes: str = ""


class TrainingLog:
    """Persistent log of all retrain cycles."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_path: str | Path | None = None,
|
||||
summary_path: str | Path | None = None,
|
||||
repo_root: str | Path | None = None,
|
||||
):
|
||||
if repo_root is None:
|
||||
repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
self._repo_root = Path(repo_root)
|
||||
|
||||
self._log_path = self._repo_root / (log_path or _DEFAULT_LOG_PATH)
|
||||
self._summary_path = self._repo_root / (summary_path or _DEFAULT_SUMMARY_PATH)
|
||||
self._log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def log_path(self) -> Path:
|
||||
return self._log_path
|
||||
|
||||
def next_iteration(self) -> int:
|
||||
"""Return the next iteration number (1-indexed)."""
|
||||
entries = self.load_all()
|
||||
if not entries:
|
||||
return 1
|
||||
return max(e.get("iteration", 0) for e in entries) + 1
|
||||
|
||||
def record(self, metrics: CycleMetrics) -> None:
|
||||
"""Append a cycle metrics record to the log."""
|
||||
entry = asdict(metrics)
|
||||
with open(self._log_path, "a") as f:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
|
||||
self._update_summary(metrics)
|
||||
logger.info(
|
||||
"Training log: iteration=%d week=%s status=%s examples_added=%d",
|
||||
metrics.iteration,
|
||||
metrics.week,
|
||||
metrics.train_status,
|
||||
metrics.examples_added,
|
||||
)
|
||||
|
||||
def load_all(self) -> list[dict[str, Any]]:
|
||||
"""Load all cycle records from the log."""
|
||||
if not self._log_path.exists():
|
||||
return []
|
||||
entries: list[dict[str, Any]] = []
|
||||
with open(self._log_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
entries.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed log entry")
|
||||
return entries
|
||||
|
||||
def latest(self) -> dict[str, Any] | None:
|
||||
"""Return the most recent cycle record."""
|
||||
entries = self.load_all()
|
||||
return entries[-1] if entries else None
|
||||
|
||||
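    # Usage sketch — the repo_root path is hypothetical; a real run writes
    # under the repository's .loop/retrain/ directory:
    #
    #   log = TrainingLog(repo_root="/tmp/demo")
    #   m = CycleMetrics(
    #       iteration=log.next_iteration(),
    #       week="2026-W12",
    #       ran_at=datetime.now(tz=UTC).isoformat(),
    #   )
    #   log.record(m)
    #   assert log.latest()["iteration"] == m.iteration
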
    def _update_summary(self, metrics: CycleMetrics) -> None:
        """Rewrite the markdown summary with all cycles."""
        all_entries = self.load_all()

        lines = [
            "# AutoLoRA Training Log\n",
            f"*Updated: {datetime.now(tz=UTC).isoformat()}*\n",
            f"*Total iterations: {len(all_entries)}*\n",
            "",
            "## Cycles\n",
            "| # | Week | Status | Loss | Examples | Duration |",
            "|---|------|--------|------|----------|----------|",
        ]

        for entry in reversed(all_entries[-20:]):  # Last 20 cycles, newest first
            # Format on None-check, not truthiness, so a legitimate 0.0 loss is shown
            loss_val = entry.get("train_loss")
            loss = f"{loss_val:.4f}" if loss_val is not None else "—"
            lines.append(
                f"| {entry.get('iteration', '?')} "
                f"| {entry.get('week', '?')} "
                f"| {entry.get('train_status', '?')} "
                f"| {loss} "
                f"| +{entry.get('examples_added', 0)} ({entry.get('dataset_total', 0)} total) "
                f"| {entry.get('train_duration_seconds', 0.0):.0f}s |"
            )

        lines.append("")
        lines.append("## Skill Accuracy Over Time\n")

        # Collect all unique skills
        all_skills: set[str] = set()
        for entry in all_entries:
            all_skills.update(entry.get("skill_accuracy", {}).keys())

        if all_skills:
            skill_header = "| # | Week | " + " | ".join(sorted(all_skills)) + " |"
            skill_sep = "|---|------|" + "|".join("---" for _ in all_skills) + "|"
            lines.extend([skill_header, skill_sep])
            for entry in reversed(all_entries[-10:]):
                acc = entry.get("skill_accuracy", {})
                row = f"| {entry.get('iteration', '?')} | {entry.get('week', '?')} | "
                row += " | ".join(
                    f"{acc[s]:.0%}" if s in acc else "—"
                    for s in sorted(all_skills)
                )
                row += " |"
                lines.append(row)
        else:
            lines.append("*No skill accuracy data yet — run smoke tests after fine-tuning.*")

        lines.append("")
        if metrics.notes:
            lines.append(f"## Latest Notes\n\n{metrics.notes}\n")

        self._summary_path.write_text("\n".join(lines))
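
    # Example of the rendered Cycles table (rows are hypothetical):
    #
    #   | # | Week | Status | Loss | Examples | Duration |
    #   |---|------|--------|------|----------|----------|
    #   | 3 | 2026-W12 | completed | 0.8421 | +18 (120 total) | 312s |
    #   | 2 | 2026-W11 | skipped | — | +0 (102 total) | 0s |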
255
timmy_automations/retrain/trajectory_exporter.py
Normal file
@@ -0,0 +1,255 @@
"""Trajectory exporter — reads session JSONL logs and extracts conversation trajectories.
|
||||
|
||||
A trajectory is a coherent sequence of messages + tool calls that form
|
||||
a single task attempt. Each trajectory becomes one training example.
|
||||
|
||||
Refs: #1105
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_LOGS_DIR_DEFAULT = "logs"
|
||||
_SESSION_GLOB = "session_*.jsonl"
|
||||
|
||||
|
||||
@dataclass
class Trajectory:
    """A single conversation trajectory extracted from session logs."""

    session_date: str
    started_at: str
    ended_at: str
    messages: list[dict[str, Any]] = field(default_factory=list)
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    errors: list[dict[str, Any]] = field(default_factory=list)
    decisions: list[dict[str, Any]] = field(default_factory=list)

    @property
    def message_count(self) -> int:
        return len(self.messages)

    @property
    def tool_call_count(self) -> int:
        return len(self.tool_calls)

    @property
    def error_count(self) -> int:
        return len(self.errors)

    @property
    def has_successful_tool_call(self) -> bool:
        """True if the trajectory made at least one tool call and logged no errors."""
        return self.tool_call_count > 0 and self.error_count == 0

    @property
    def is_multi_step(self) -> bool:
        """True if this trajectory involved multiple turns with tool use."""
        return self.message_count >= 2 and self.tool_call_count >= 1

    def to_chat_format(self) -> list[dict[str, str]]:
        """Convert trajectory to chat-format messages for training.

        Interleaves messages, tool-call results, and decisions by timestamp;
        tool results and decisions are emitted as assistant turns.
        """
        chat: list[dict[str, str]] = []
        # Merge all entries by timestamp and emit in order
        all_entries = sorted(
            self.messages + self.tool_calls + self.decisions,
            key=lambda e: e.get("timestamp", ""),
        )
        for entry in all_entries:
            etype = entry.get("type")
            if etype == "message":
                role = "user" if entry.get("role") == "user" else "assistant"
                content = entry.get("content", "")
                if content:
                    chat.append({"role": role, "content": content})
            elif etype == "tool_call":
                tool = entry.get("tool", "unknown")
                result = entry.get("result", "")
                chat.append(
                    {
                        "role": "assistant",
                        "content": f"[tool:{tool}] {result}",
                    }
                )
            elif etype == "decision":
                decision = entry.get("decision", "")
                if decision:
                    chat.append({"role": "assistant", "content": f"[decided] {decision}"})
        return chat


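# Sketch of `to_chat_format` output for a hypothetical two-turn trajectory
# with one bash tool call:
#
#   [
#       {"role": "user", "content": "Run the tests"},
#       {"role": "assistant", "content": "[tool:bash] 12 passed"},
#       {"role": "assistant", "content": "All tests pass."},
#   ]

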
class TrajectoryExporter:
    """Reads session JSONL logs and yields Trajectory objects for a date range."""

    def __init__(self, logs_dir: str | Path | None = None, repo_root: str | Path | None = None):
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        if logs_dir is None:
            self._logs_dir = self._repo_root / _LOGS_DIR_DEFAULT
        else:
            self._logs_dir = Path(logs_dir)

    def export_week(self, weeks_ago: int = 0) -> list[Trajectory]:
        """Export all trajectories from the specified week.

        Args:
            weeks_ago: 0 = current week, 1 = last week, etc.

        Returns:
            List of Trajectory objects extracted from session logs.
        """
        now = datetime.now(tz=UTC)
        # Week boundaries: Monday 00:00 (inclusive) to the next Monday 00:00 (exclusive)
        days_since_monday = now.weekday()
        week_start = (now - timedelta(days=days_since_monday + 7 * weeks_ago)).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        week_end = week_start + timedelta(days=7)

        logger.info(
            "Exporting trajectories for week %s–%s",
            week_start.date().isoformat(),
            week_end.date().isoformat(),
        )

        trajectories: list[Trajectory] = []
        log_files = sorted(self._logs_dir.glob(_SESSION_GLOB))

        for log_file in log_files:
            # Parse date from filename: session_YYYY-MM-DD.jsonl
            try:
                date_str = log_file.stem.removeprefix("session_")
                file_date = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=UTC)
            except ValueError:
                logger.debug("Skipping non-date session file: %s", log_file.name)
                continue

            if not (week_start <= file_date < week_end):
                continue

            file_trajectories = self._extract_from_file(log_file)
            trajectories.extend(file_trajectories)
            logger.info(
                "Extracted %d trajectories from %s", len(file_trajectories), log_file.name
            )

        logger.info("Total trajectories exported: %d", len(trajectories))
        return trajectories

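    # Usage sketch — the logs_dir path is hypothetical:
    #
    #   exporter = TrajectoryExporter(logs_dir="/tmp/demo/logs")
    #   last_week = exporter.export_week(weeks_ago=1)
    #   multi_step = [t for t in last_week if t.is_multi_step]
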
    def _extract_from_file(self, log_file: Path) -> list[Trajectory]:
        """Parse a single session JSONL file into trajectories.

        Groups entries into trajectories by finding natural conversation
        boundaries (inactivity gaps or new user turns in the message stream).
        """
        entries: list[dict[str, Any]] = []
        try:
            with open(log_file) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        entries.append(json.loads(line))
                    except json.JSONDecodeError:
                        logger.debug("Skipping malformed JSON line in %s", log_file.name)
        except OSError as exc:
            logger.warning("Could not read %s: %s", log_file, exc)
            return []

        if not entries:
            return []

        date_str = log_file.stem.removeprefix("session_")
        return self._segment_trajectories(entries, date_str)

    def _segment_trajectories(
        self, entries: list[dict[str, Any]], session_date: str
    ) -> list[Trajectory]:
        """Split a flat list of session entries into discrete trajectories.

        Segmentation rule: start a new trajectory when:
        - A user message follows a Timmy message (new conversation turn)
        - More than 5 minutes have elapsed between entries

        This produces training examples that are coherent task attempts.
        """
        if not entries:
            return []

        trajectories: list[Trajectory] = []
        current_entries: list[dict[str, Any]] = []
        prev_ts: datetime | None = None
        _SEGMENT_GAP_MINUTES = 5

        def _flush() -> None:
            if current_entries:
                traj = _build_trajectory(current_entries, session_date)
                if traj.message_count > 0:
                    trajectories.append(traj)

        for entry in entries:
            ts_raw = entry.get("timestamp", "")
            try:
                ts = datetime.fromisoformat(ts_raw.replace("Z", "+00:00"))
            except (ValueError, AttributeError):
                ts = None

            # Time-gap segmentation
            if ts and prev_ts and (ts - prev_ts).total_seconds() > _SEGMENT_GAP_MINUTES * 60:
                _flush()
                current_entries = []

            # New-turn segmentation: user message after assistant turn
            etype = entry.get("type")
            erole = entry.get("role")
            if etype == "message" and erole == "user" and current_entries:
                # Flush if the most recent message in the current segment was from
                # Timmy (tool calls, errors, and decisions are skipped in this scan)
                for prev in reversed(current_entries):
                    if prev.get("type") == "message":
                        if prev.get("role") == "timmy":
                            _flush()
                            current_entries = []
                        break

            current_entries.append(entry)
            if ts:
                prev_ts = ts

        _flush()
        return trajectories

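# Segmentation sketch under the rules above (timestamps hypothetical): the stream
#   user@10:00 -> timmy@10:01 -> user@10:02 -> timmy@10:03 -> user@10:30
# yields three trajectories: [user@10:00, timmy@10:01] (next user message follows
# a Timmy reply), [user@10:02, timmy@10:03], and [user@10:30] (>5-minute gap).
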
def _build_trajectory(entries: list[dict[str, Any]], session_date: str) -> Trajectory:
    """Build a Trajectory from a flat list of entries."""
    messages = [e for e in entries if e.get("type") == "message"]
    tool_calls = [e for e in entries if e.get("type") == "tool_call"]
    errors = [e for e in entries if e.get("type") == "error"]
    decisions = [e for e in entries if e.get("type") == "decision"]

    timestamps = [e.get("timestamp", "") for e in entries if e.get("timestamp")]
    started_at = min(timestamps) if timestamps else ""
    ended_at = max(timestamps) if timestamps else ""

    return Trajectory(
        session_date=session_date,
        started_at=started_at,
        ended_at=ended_at,
        messages=messages,
        tool_calls=tool_calls,
        errors=errors,
        decisions=decisions,
    )