# forked from Rockachopa/Timmy-time-dashboard
"""Training dataset manager — appends filtered trajectories to a JSONL training file.
|
|
|
|
Maintains a growing dataset of high-quality conversation examples in the
|
|
chat-format expected by mlx-lm / HuggingFace fine-tuning pipelines.
|
|
|
|
Output format (one JSON object per line):
|
|
{"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
|
|
|
|
Refs: #1105
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
|
|
from timmy_automations.retrain.quality_filter import QualityResult
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_DEFAULT_DATASET_PATH = ".loop/retrain/training_data.jsonl"
|
|
_DEFAULT_INDEX_PATH = ".loop/retrain/dataset_index.json"
|
|
|
|
|
|
@dataclass
|
|
class AppendResult:
|
|
"""Result of appending trajectories to the training dataset."""
|
|
|
|
new_examples: int
|
|
total_examples: int
|
|
dataset_path: str
|
|
week_tag: str
|
|
|
|
|
|
class TrainingDataset:
|
|
"""Manages the LoRA training dataset file.
|
|
|
|
Each entry is a chat-format example:
|
|
{"messages": [...], "week": "2026-W12", "quality": "high", "added_at": "..."}
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
dataset_path: str | Path | None = None,
|
|
index_path: str | Path | None = None,
|
|
repo_root: str | Path | None = None,
|
|
):
|
|
if repo_root is None:
|
|
repo_root = Path(__file__).resolve().parent.parent.parent
|
|
self._repo_root = Path(repo_root)
|
|
|
|
self._dataset_path = self._repo_root / (
|
|
dataset_path or _DEFAULT_DATASET_PATH
|
|
)
|
|
self._index_path = self._repo_root / (
|
|
index_path or _DEFAULT_INDEX_PATH
|
|
)
|
|
|
|
self._dataset_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
@property
|
|
def dataset_path(self) -> Path:
|
|
return self._dataset_path
|
|
|
|
def count(self) -> int:
|
|
"""Return the number of examples currently in the dataset."""
|
|
if not self._dataset_path.exists():
|
|
return 0
|
|
count = 0
|
|
with open(self._dataset_path) as f:
|
|
for line in f:
|
|
if line.strip():
|
|
count += 1
|
|
return count
|
|
|
|
def append(
|
|
self, quality_results: list[QualityResult], week_tag: str
|
|
) -> AppendResult:
|
|
"""Append high-quality trajectories to the training dataset.
|
|
|
|
Deduplicates by (week_tag, session_date, started_at) so re-running
|
|
the export for the same week is idempotent.
|
|
|
|
Args:
|
|
quality_results: Filtered, trainable quality results.
|
|
week_tag: ISO week string e.g. "2026-W12".
|
|
|
|
Returns:
|
|
AppendResult with counts.
|
|
"""
|
|
existing_keys = self._load_existing_keys()
|
|
new_count = 0
|
|
added_at = datetime.now(tz=UTC).isoformat()
|
|
|
|
with open(self._dataset_path, "a") as f:
|
|
for result in quality_results:
|
|
traj = result.trajectory
|
|
dedup_key = (
|
|
f"{week_tag}|{traj.session_date}|{traj.started_at}"
|
|
)
|
|
if dedup_key in existing_keys:
|
|
logger.debug("Skipping duplicate trajectory: %s", dedup_key)
|
|
continue
|
|
|
|
chat_messages = traj.to_chat_format()
|
|
if len(chat_messages) < 2:
|
|
logger.debug(
|
|
"Skipping trajectory with %d chat messages (need ≥2)",
|
|
len(chat_messages),
|
|
)
|
|
continue
|
|
|
|
record = {
|
|
"messages": chat_messages,
|
|
"week": week_tag,
|
|
"quality": result.quality.value,
|
|
"score": result.score,
|
|
"session_date": traj.session_date,
|
|
"started_at": traj.started_at,
|
|
"tool_calls": traj.tool_call_count,
|
|
"added_at": added_at,
|
|
}
|
|
f.write(json.dumps(record) + "\n")
|
|
existing_keys.add(dedup_key)
|
|
new_count += 1
|
|
|
|
total = self.count()
|
|
self._update_index(week_tag, new_count, total)
|
|
logger.info(
|
|
"Dataset: appended %d new examples (total=%d)", new_count, total
|
|
)
|
|
|
|
return AppendResult(
|
|
new_examples=new_count,
|
|
total_examples=total,
|
|
dataset_path=str(self._dataset_path),
|
|
week_tag=week_tag,
|
|
)
|
|
|
|
def _load_existing_keys(self) -> set[str]:
|
|
"""Load deduplication keys from the existing dataset."""
|
|
keys: set[str] = set()
|
|
if not self._dataset_path.exists():
|
|
return keys
|
|
with open(self._dataset_path) as f:
|
|
for line in f:
|
|
line = line.strip()
|
|
if not line:
|
|
continue
|
|
try:
|
|
record = json.loads(line)
|
|
week = record.get("week", "")
|
|
session_date = record.get("session_date", "")
|
|
started_at = record.get("started_at", "")
|
|
keys.add(f"{week}|{session_date}|{started_at}")
|
|
except json.JSONDecodeError:
|
|
continue
|
|
return keys
|
|
|
|
def _update_index(self, week_tag: str, new_count: int, total: int) -> None:
|
|
"""Update the dataset index JSON with latest run metadata."""
|
|
index: dict = {}
|
|
if self._index_path.exists():
|
|
try:
|
|
index = json.loads(self._index_path.read_text())
|
|
except (json.JSONDecodeError, OSError):
|
|
index = {}
|
|
|
|
index.setdefault("weeks", {})
|
|
index["weeks"][week_tag] = {
|
|
"examples_added": new_count,
|
|
"updated_at": datetime.now(tz=UTC).isoformat(),
|
|
}
|
|
index["total_examples"] = total
|
|
index["last_updated"] = datetime.now(tz=UTC).isoformat()
|
|
|
|
self._index_path.write_text(json.dumps(index, indent=2))
|