#!/usr/bin/env python3
|
|
"""Training pair provenance tracking.
|
|
|
|
Every training pair in the pipeline must carry provenance metadata:
|
|
- source_session_id: origin session identifier
|
|
- source: trajectory|curated|backfill
|
|
- model: model that produced the assistant turn
|
|
- timestamp: ISO 8601 when the pair was captured
|
|
- excluded: bool + reason if filtered out during quality screening
|
|
- approved: bool (default True for accepted pairs)
|
|
|
|
Usage:
|
|
from training_pair_provenance import ProvenanceTracker
|
|
|
|
tracker = ProvenanceTracker()
|
|
for pair in pairs:
|
|
pair = tracker.annotate(pair, source="trajectory", model="hermes4:14b")
|
|
tracker.report()
|
|
"""
|
|
|
|
import json
|
|
import time
|
|
from dataclasses import asdict, dataclass, field
|
|
from pathlib import Path
|
|
|
|
|
|
# Fields every pair's "provenance" dict must carry to pass validate().
# Order matters only for the ordering of validate()'s error messages.
REQUIRED_FIELDS = ("source_session_id", "source", "model", "timestamp", "approved")
|
|
|
|
|
|
@dataclass
class ProvenanceMeta:
    """Schema for the per-pair "provenance" dict.

    NOTE(review): never instantiated in this file — pairs carry plain dicts
    annotated by ProvenanceTracker. Kept as the canonical field reference;
    field defaults mirror the defaults annotate() applies.
    """

    source_session_id: str = ""  # origin session identifier
    source: str = ""  # trajectory | curated | backfill
    model: str = ""  # model that produced the assistant turn
    timestamp: str = ""  # ISO 8601 time when the pair was captured
    excluded: bool = False  # True when filtered out during quality screening
    exclusion_reason: str = ""  # why the pair was excluded (required if excluded)
    approved: bool = True  # accepted pairs default to approved
|
|
|
|
|
|
class ProvenanceTracker:
    """Annotate and validate provenance metadata on training pairs.

    Keeps running statistics (totals, approval/exclusion counts, and
    per-source / per-model breakdowns) for every pair passed through
    ``annotate`` or ``_track``.
    """

    def __init__(self):
        # Running counters; "by_source"/"by_model" map label -> pair count.
        self._stats = {
            "total": 0,
            "approved": 0,
            "excluded": 0,
            "missing_provenance": 0,
            "by_source": {},
            "by_model": {},
        }

    # ── annotation ─────────────────────────────────────────────

    def annotate(
        self,
        pair: dict,
        *,
        source: str,
        model: str,
        session_id: str = "",
        timestamp: str = "",
    ) -> dict:
        """Attach provenance metadata to a training pair dict.

        Existing provenance values win (``setdefault``). A missing
        timestamp defaults to the current UTC time and a missing
        ``approved`` flag defaults to True. Mutates and returns *pair*,
        and updates the running statistics.
        """
        meta = pair.get("provenance", {})
        meta.setdefault("source", source)
        meta.setdefault("model", model)
        if session_id:
            meta.setdefault("source_session_id", session_id)
        if timestamp:
            meta.setdefault("timestamp", timestamp)
        if "timestamp" not in meta:
            # ISO 8601 UTC, e.g. "2024-01-01T00:00:00Z".
            meta["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        if "approved" not in meta:
            meta["approved"] = True
        pair["provenance"] = meta
        self._track(pair)
        return pair

    def exclude(self, pair: dict, reason: str) -> dict:
        """Mark pair as excluded with a reason (also revokes approval)."""
        meta = pair.get("provenance", {})
        meta["excluded"] = True
        meta["exclusion_reason"] = reason
        meta["approved"] = False
        pair["provenance"] = meta
        return pair

    # ── backfill ───────────────────────────────────────────────

    def backfill(
        self,
        pair: dict,
        *,
        source: str,
        model: str,
    ) -> dict:
        """Add provenance to a pair that has none (historical data).

        Pairs that already carry both a source and a model are returned
        untouched; everything else is annotated with session_id
        "backfill".
        """
        meta = pair.get("provenance")
        if meta and meta.get("source") and meta.get("model"):
            return pair  # already has provenance
        return self.annotate(pair, source=source, model=model, session_id="backfill")

    def backfill_file(self, path: Path, *, source: str, model: str) -> int:
        """Backfill provenance on an entire JSONL file. Returns count updated."""
        pairs = []
        count = 0
        with open(path) as f:
            for line in f:
                line = line.strip()
                if line:
                    pair = json.loads(line)
                    # Fix: count only pairs actually updated, not every line
                    # (the docstring promises "count updated").
                    had = bool(
                        pair.get("provenance")
                        and pair["provenance"].get("source")
                        and pair["provenance"].get("model")
                    )
                    pair = self.backfill(pair, source=source, model=model)
                    pairs.append(pair)
                    if not had:
                        count += 1
        # Fix: write to a temp file and atomically replace so a crash
        # mid-write cannot truncate the original data.
        tmp = Path(str(path) + ".tmp")
        with open(tmp, "w") as f:
            for pair in pairs:
                f.write(json.dumps(pair) + "\n")
        tmp.replace(path)
        return count

    # ── validation ─────────────────────────────────────────────

    def validate(self, pair: dict) -> list[str]:
        """Return list of validation errors for a pair. Empty = valid.

        Checks that every REQUIRED_FIELDS key is present and that
        excluded pairs carry an exclusion_reason.
        """
        errors = []
        meta = pair.get("provenance")
        if not meta:
            return ["missing provenance metadata"]
        for field_name in REQUIRED_FIELDS:
            if field_name not in meta:
                errors.append(f"provenance missing field: {field_name}")
        if meta.get("excluded") and not meta.get("exclusion_reason"):
            errors.append("excluded pair missing exclusion_reason")
        return errors

    def validate_file(self, path: Path) -> dict:
        """Validate all pairs in a JSONL file. Returns {valid, invalid, errors}.

        "errors" is a list of {"line": 1-based line number, "errors": [...]}.
        """
        results = {"valid": 0, "invalid": 0, "errors": []}
        with open(path) as f:
            for i, line in enumerate(f, 1):
                line = line.strip()
                if line:
                    pair = json.loads(line)
                    errs = self.validate(pair)
                    if errs:
                        results["invalid"] += 1
                        results["errors"].append({"line": i, "errors": errs})
                    else:
                        results["valid"] += 1
        return results

    # ── reporting ──────────────────────────────────────────────

    def _track(self, pair: dict):
        """Fold one pair into the running statistics."""
        meta = pair.get("provenance", {})
        self._stats["total"] += 1
        if not meta:
            # Fix: this counter was declared in __init__ but never
            # incremented anywhere.
            self._stats["missing_provenance"] += 1
        if meta.get("approved"):
            self._stats["approved"] += 1
        if meta.get("excluded"):
            self._stats["excluded"] += 1
        source = meta.get("source", "unknown")
        self._stats["by_source"][source] = self._stats["by_source"].get(source, 0) + 1
        model = meta.get("model", "unknown")
        self._stats["by_model"][model] = self._stats["by_model"].get(model, 0) + 1

    def report(self) -> dict:
        """Return provenance statistics dict.

        Shallow copy: the nested by_source/by_model dicts are shared
        with the tracker's internal state.
        """
        return dict(self._stats)

    def report_text(self) -> str:
        """Return human-readable provenance report."""
        s = self._stats
        lines = [
            "Provenance Report",
            "=" * 40,
            f"  Total pairs:  {s['total']}",
            f"  Approved:     {s['approved']}",
            f"  Excluded:     {s['excluded']}",
            "",
            "  By source:",
        ]
        for source, count in sorted(s["by_source"].items()):
            lines.append(f"    {source:20s} {count}")
        lines.append("  By model:")
        for model, count in sorted(s["by_model"].items()):
            lines.append(f"    {model:20s} {count}")
        return "\n".join(lines)
|
|
|
|
|
|
def load_jsonl(path: Path) -> list[dict]:
    """Read a JSONL file, returning one parsed dict per non-blank line."""
    with open(path) as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]
|
|
|
|
|
|
def write_jsonl(path: Path, pairs: list[dict]):
    """Serialize *pairs* to *path* as JSONL, creating parent dirs as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    lines = [json.dumps(entry) + "\n" for entry in pairs]
    with open(path, "w") as handle:
        handle.writelines(lines)
|
|
|
|
|
|
def provenance_dashboard(path: Path) -> str:
    """Generate a boxed provenance dashboard from a JSONL file.

    Loads every pair in *path*, folds each into a fresh ProvenanceTracker,
    and renders totals plus per-source / per-model breakdowns.
    """
    pairs = load_jsonl(path)
    tracker = ProvenanceTracker()
    for pair in pairs:
        # _track only updates counters; annotate() would mutate the pairs.
        tracker._track(pair)
    report = tracker.report()
    # Coverage = share of pairs carrying an approved/excluded decision.
    # Fix: the original used backslash escapes inside an f-string expression
    # (a SyntaxError before Python 3.12) and rendered the literal ">0%" for
    # empty files instead of "0%".
    decided = report["approved"] + report["excluded"]
    coverage = f"{decided * 100 // report['total']}%" if report["total"] else "0%"
    lines = [
        "╔══════════════════════════════════════╗",
        "║   Training Provenance Dashboard      ║",
        "╠══════════════════════════════════════╣",
        f"║ Total pairs:        {report['total']:>12}     ║",
        f"║ Approved:           {report['approved']:>12}     ║",
        f"║ Excluded:           {report['excluded']:>12}     ║",
        f"║ Provenance coverage:{coverage:>12s}     ║",
        "╠══════════════════════════════════════╣",
        "║ By Source                            ║",
    ]
    for source, count in sorted(report["by_source"].items()):
        lines.append(f"║   {source:20s} {count:>8}     ║")
    lines.append("╠══════════════════════════════════════╣")
    lines.append("║ By Model                             ║")
    for model, count in sorted(report["by_model"].items()):
        lines.append(f"║   {model:20s} {count:>8}     ║")
    lines.append("╚══════════════════════════════════════╝")
    return "\n".join(lines)
|