"""Training dataset manager — appends filtered trajectories to a JSONL training file. Maintains a growing dataset of high-quality conversation examples in the chat-format expected by mlx-lm / HuggingFace fine-tuning pipelines. Output format (one JSON object per line): {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]} Refs: #1105 """ from __future__ import annotations import json import logging from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path from timmy_automations.retrain.quality_filter import QualityResult logger = logging.getLogger(__name__) _DEFAULT_DATASET_PATH = ".loop/retrain/training_data.jsonl" _DEFAULT_INDEX_PATH = ".loop/retrain/dataset_index.json" @dataclass class AppendResult: """Result of appending trajectories to the training dataset.""" new_examples: int total_examples: int dataset_path: str week_tag: str class TrainingDataset: """Manages the LoRA training dataset file. Each entry is a chat-format example: {"messages": [...], "week": "2026-W12", "quality": "high", "added_at": "..."} """ def __init__( self, dataset_path: str | Path | None = None, index_path: str | Path | None = None, repo_root: str | Path | None = None, ): if repo_root is None: repo_root = Path(__file__).resolve().parent.parent.parent self._repo_root = Path(repo_root) self._dataset_path = self._repo_root / ( dataset_path or _DEFAULT_DATASET_PATH ) self._index_path = self._repo_root / ( index_path or _DEFAULT_INDEX_PATH ) self._dataset_path.parent.mkdir(parents=True, exist_ok=True) @property def dataset_path(self) -> Path: return self._dataset_path def count(self) -> int: """Return the number of examples currently in the dataset.""" if not self._dataset_path.exists(): return 0 count = 0 with open(self._dataset_path) as f: for line in f: if line.strip(): count += 1 return count def append( self, quality_results: list[QualityResult], week_tag: str ) -> AppendResult: """Append high-quality trajectories to the training dataset. Deduplicates by (week_tag, session_date, started_at) so re-running the export for the same week is idempotent. Args: quality_results: Filtered, trainable quality results. week_tag: ISO week string e.g. "2026-W12". Returns: AppendResult with counts. """ existing_keys = self._load_existing_keys() new_count = 0 added_at = datetime.now(tz=UTC).isoformat() with open(self._dataset_path, "a") as f: for result in quality_results: traj = result.trajectory dedup_key = ( f"{week_tag}|{traj.session_date}|{traj.started_at}" ) if dedup_key in existing_keys: logger.debug("Skipping duplicate trajectory: %s", dedup_key) continue chat_messages = traj.to_chat_format() if len(chat_messages) < 2: logger.debug( "Skipping trajectory with %d chat messages (need ≥2)", len(chat_messages), ) continue record = { "messages": chat_messages, "week": week_tag, "quality": result.quality.value, "score": result.score, "session_date": traj.session_date, "started_at": traj.started_at, "tool_calls": traj.tool_call_count, "added_at": added_at, } f.write(json.dumps(record) + "\n") existing_keys.add(dedup_key) new_count += 1 total = self.count() self._update_index(week_tag, new_count, total) logger.info( "Dataset: appended %d new examples (total=%d)", new_count, total ) return AppendResult( new_examples=new_count, total_examples=total, dataset_path=str(self._dataset_path), week_tag=week_tag, ) def _load_existing_keys(self) -> set[str]: """Load deduplication keys from the existing dataset.""" keys: set[str] = set() if not self._dataset_path.exists(): return keys with open(self._dataset_path) as f: for line in f: line = line.strip() if not line: continue try: record = json.loads(line) week = record.get("week", "") session_date = record.get("session_date", "") started_at = record.get("started_at", "") keys.add(f"{week}|{session_date}|{started_at}") except json.JSONDecodeError: continue return keys def _update_index(self, week_tag: str, new_count: int, total: int) -> None: """Update the dataset index JSON with latest run metadata.""" index: dict = {} if self._index_path.exists(): try: index = json.loads(self._index_path.read_text()) except (json.JSONDecodeError, OSError): index = {} index.setdefault("weeks", {}) index["weeks"][week_tag] = { "examples_added": new_count, "updated_at": datetime.now(tz=UTC).isoformat(), } index["total_examples"] = total index["last_updated"] = datetime.now(tz=UTC).isoformat() self._index_path.write_text(json.dumps(index, indent=2))