Files
timmy-config/training/training_pair_provenance.py

234 lines
8.8 KiB
Python

#!/usr/bin/env python3
"""Training pair provenance tracking.
Every training pair in the pipeline must carry provenance metadata:
- source_session_id: origin session identifier
- source: trajectory|curated|backfill
- model: model that produced the assistant turn
- timestamp: ISO 8601 when the pair was captured
- excluded: bool + reason if filtered out during quality screening
- approved: bool (default True for accepted pairs)
Usage:
from training_pair_provenance import ProvenanceTracker
tracker = ProvenanceTracker()
for pair in pairs:
pair = tracker.annotate(pair, source="trajectory", model="hermes4:14b")
tracker.report()
"""
import json
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
REQUIRED_FIELDS = ("source_session_id", "source", "model", "timestamp", "approved")
@dataclass
class ProvenanceMeta:
    """Typed record of the provenance metadata attached to a training pair.

    Field names match the keys used in a pair's ``"provenance"`` dict.
    """

    # Identifier of the session the pair originated from.
    source_session_id: str = field(default="")
    # Origin of the pair: trajectory | curated | backfill.
    source: str = field(default="")
    # Model that produced the assistant turn.
    model: str = field(default="")
    # ISO 8601 capture time.
    timestamp: str = field(default="")
    # True when the pair was filtered out during quality screening.
    excluded: bool = field(default=False)
    # Why the pair was excluded (empty when it was not).
    exclusion_reason: str = field(default="")
    # Whether the pair is accepted.
    approved: bool = field(default=True)
class ProvenanceTracker:
    """Annotate and validate provenance metadata on training pairs.

    A pair is a plain dict; its provenance lives in a nested dict under the
    ``"provenance"`` key. Every pair passed through :meth:`annotate` (which
    :meth:`backfill` uses) is counted in running statistics that
    :meth:`report` and :meth:`report_text` summarize.
    """

    def __init__(self):
        self._stats = {
            "total": 0,
            "approved": 0,
            "excluded": 0,
            "missing_provenance": 0,
            "by_source": {},
            "by_model": {},
        }

    # ── annotation ─────────────────────────────────────────────

    def annotate(
        self,
        pair: dict,
        *,
        source: str,
        model: str,
        session_id: str = "",
        timestamp: str = "",
    ) -> dict:
        """Attach provenance metadata to a training pair dict.

        Existing provenance values are never overwritten: ``source``,
        ``model``, ``session_id`` and ``timestamp`` only fill keys that are
        absent. If no timestamp is present or supplied, the current UTC time
        is used. ``approved`` defaults to True. The pair is modified in place,
        counted in the tracker's statistics, and returned.

        Note: when ``session_id`` is empty, ``source_session_id`` is left
        unset, so :meth:`validate` will report it as a missing field.
        """
        # `or {}` also covers an explicit ``"provenance": null`` read from
        # JSON, which pair.get("provenance", {}) would return as None.
        meta = pair.get("provenance") or {}
        meta.setdefault("source", source)
        meta.setdefault("model", model)
        if session_id:
            meta.setdefault("source_session_id", session_id)
        if timestamp:
            meta.setdefault("timestamp", timestamp)
        if "timestamp" not in meta:
            meta["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        meta.setdefault("approved", True)
        pair["provenance"] = meta
        self._track(pair)
        return pair

    def exclude(self, pair: dict, reason: str) -> dict:
        """Mark a pair as excluded (and not approved) with a reason.

        The pair is modified in place and returned. This does not update the
        tracker's statistics. Call it before :meth:`annotate` if the
        exclusion should show up in :meth:`report`; annotate keeps the
        existing ``approved=False``.
        """
        meta = pair.get("provenance") or {}
        meta["excluded"] = True
        meta["exclusion_reason"] = reason
        meta["approved"] = False
        pair["provenance"] = meta
        return pair

    # ── backfill ───────────────────────────────────────────────

    @staticmethod
    def _has_core_provenance(pair: dict) -> bool:
        """Return True when the pair already has a non-empty source and model."""
        meta = pair.get("provenance")
        return isinstance(meta, dict) and bool(meta.get("source")) and bool(meta.get("model"))

    def backfill(
        self,
        pair: dict,
        *,
        source: str,
        model: str,
    ) -> dict:
        """Add provenance to a pair that has none (historical data).

        A pair that already has a non-empty ``source`` and ``model`` is
        returned untouched and is not counted in the statistics. Otherwise
        the pair is annotated with ``session_id="backfill"``; any keys it
        already has are kept.
        """
        if self._has_core_provenance(pair):
            return pair
        return self.annotate(pair, source=source, model=model, session_id="backfill")

    def backfill_file(self, path: Path, *, source: str, model: str) -> int:
        """Backfill provenance on every pair in a JSONL file, rewriting it in place.

        Blank lines are dropped. If any line is not valid JSON, the exception
        propagates before anything is written. The output is first written to
        a sibling ``<name>.tmp`` file, which then replaces the original. A
        failure while writing therefore leaves the original file intact.

        Returns:
            The number of pairs that actually received provenance. Pairs that
            already had a source and model are not counted.
        """
        path = Path(path)
        pairs = []
        updated = 0
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                pair = json.loads(line)
                if not self._has_core_provenance(pair):
                    updated += 1
                pairs.append(self.backfill(pair, source=source, model=model))
        tmp = path.with_name(path.name + ".tmp")
        try:
            with open(tmp, "w", encoding="utf-8") as f:
                f.writelines(json.dumps(pair) + "\n" for pair in pairs)
            tmp.replace(path)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise
        return updated

    # ── validation ─────────────────────────────────────────────

    def validate(self, pair: dict) -> list[str]:
        """Return a list of validation errors for a pair; an empty list means valid.

        The checks are:
        - provenance is present and is an object;
        - it contains every key in REQUIRED_FIELDS;
        - an excluded pair records an exclusion_reason.
        """
        meta = pair.get("provenance")
        if not meta:
            return ["missing provenance metadata"]
        if not isinstance(meta, dict):
            # Without this, `name in meta` would do substring/element checks.
            return ["provenance metadata is not an object"]
        errors = [
            f"provenance missing field: {field_name}"
            for field_name in REQUIRED_FIELDS
            if field_name not in meta
        ]
        if meta.get("excluded") and not meta.get("exclusion_reason"):
            errors.append("excluded pair missing exclusion_reason")
        return errors

    def validate_file(self, path: Path) -> dict:
        """Validate all pairs in a JSONL file.

        Returns ``{"valid": int, "invalid": int, "errors": [...]}``. Each
        entry in ``errors`` is ``{"line": n, "errors": [messages]}`` with a
        1-based line number. Blank lines are skipped. A line that is not
        valid JSON, or is not a JSON object, is counted as invalid; it does
        not abort the run.
        """
        results = {"valid": 0, "invalid": 0, "errors": []}
        with open(path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    pair = json.loads(line)
                except json.JSONDecodeError as err:
                    errs = [f"invalid JSON: {err.msg}"]
                else:
                    if isinstance(pair, dict):
                        errs = self.validate(pair)
                    else:
                        errs = ["pair is not a JSON object"]
                if errs:
                    results["invalid"] += 1
                    results["errors"].append({"line": lineno, "errors": errs})
                else:
                    results["valid"] += 1
        return results

    # ── reporting ──────────────────────────────────────────────

    def _track(self, pair: dict):
        """Add one pair to the running statistics.

        A pair whose provenance is absent, empty or not an object counts
        toward ``missing_provenance``. An absent or empty source/model is
        grouped under ``"unknown"``. Keys are stringified so the report can
        always sort and format them.
        """
        stats = self._stats
        meta = pair.get("provenance")
        if not meta or not isinstance(meta, dict):
            stats["missing_provenance"] += 1
            meta = {}
        stats["total"] += 1
        if meta.get("approved"):
            stats["approved"] += 1
        if meta.get("excluded"):
            stats["excluded"] += 1
        source = str(meta.get("source") or "unknown")
        stats["by_source"][source] = stats["by_source"].get(source, 0) + 1
        model = str(meta.get("model") or "unknown")
        stats["by_model"][model] = stats["by_model"].get(model, 0) + 1

    def report(self) -> dict:
        """Return a snapshot copy of the provenance statistics.

        The nested per-source and per-model dicts are copied too. Mutating
        the result therefore cannot corrupt the tracker's counters.
        """
        stats = self._stats
        return {
            **stats,
            "by_source": dict(stats["by_source"]),
            "by_model": dict(stats["by_model"]),
        }

    def report_text(self) -> str:
        """Return a human-readable provenance report."""
        s = self._stats
        lines = [
            "Provenance Report",
            "=" * 40,
            f"  {'Total pairs:':<20}{s['total']}",
            f"  {'Approved:':<20}{s['approved']}",
            f"  {'Excluded:':<20}{s['excluded']}",
            f"  {'Missing provenance:':<20}{s['missing_provenance']}",
            "",
            "  By source:",
        ]
        lines.extend(f"    {source:20s} {count}" for source, count in sorted(s["by_source"].items()))
        lines.extend(["", "  By model:"])
        lines.extend(f"    {model:20s} {count}" for model, count in sorted(s["by_model"].items()))
        return "\n".join(lines)
def load_jsonl(path: Path) -> list[dict]:
    """Read a JSONL file (UTF-8) into a list of decoded objects.

    Blank and whitespace-only lines are skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The
            message is prefixed with the 1-based line number of the bad line.
    """
    entries = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type so existing handlers still
                # match, but say which line of the file was bad.
                raise json.JSONDecodeError(f"line {lineno}: {err.msg}", err.doc, err.pos) from err
    return entries
def write_jsonl(path: Path, pairs: list[dict]):
    """Write ``pairs`` to ``path`` as JSON Lines, one object per line.

    Missing parent directories are created and an existing file is
    overwritten. ``path`` may be a ``Path`` or a plain string.
    """
    path = Path(path)  # accept str too, matching load_jsonl
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(pair) + "\n" for pair in pairs)
def provenance_dashboard(path: Path) -> str:
    """Render a boxed text dashboard of provenance statistics for a JSONL file.

    Every pair in ``path`` is counted with a fresh ``ProvenanceTracker``. The
    dashboard shows the following:
    - totals: pairs, approved, excluded;
    - provenance coverage, i.e. the share of pairs carrying a non-empty
      provenance object, rounded down to a whole percent (``n/a`` for an
      empty file);
    - per-source and per-model counts.
    """
    pairs = load_jsonl(path)
    tracker = ProvenanceTracker()
    for pair in pairs:
        tracker._track(pair)
    report = tracker.report()

    total = report["total"]
    # Count coverage directly from the pairs. The approved + excluded sum
    # could exceed 100% when a pair carries both flags.
    covered = sum(
        1 for pair in pairs if isinstance(pair.get("provenance"), dict) and pair["provenance"]
    )
    coverage = f"{covered * 100 // total}%" if total else "n/a"

    width = 38  # inner width of the box, between the ║ borders
    rule = "═" * width

    def row(text: str) -> str:
        # Pad every row to the box width so the right border lines up.
        return f"║{text:<{width}}║"

    def counts(section: dict) -> list:
        # str() the keys: values read from disk may be null or non-string,
        # which would break both sorting and the ":<20" format spec.
        items = sorted(section.items(), key=lambda kv: str(kv[0]))
        return [row(f"   {str(name):<20} {count:>8}") for name, count in items]

    lines = [
        f"╔{rule}╗",
        row(f"{'Training Provenance Dashboard':^{width}}"),
        f"╠{rule}╣",
        row(f" {'Total pairs:':<21}{total:>12}"),
        row(f" {'Approved:':<21}{report['approved']:>12}"),
        row(f" {'Excluded:':<21}{report['excluded']:>12}"),
        row(f" {'Provenance coverage:':<21}{coverage:>12}"),
        f"╠{rule}╣",
        row(" By Source"),
        *counts(report["by_source"]),
        f"╠{rule}╣",
        row(" By Model"),
        *counts(report["by_model"]),
        f"╚{rule}╝",
    ]
    return "\n".join(lines)