Files
timmy-config/training/training_pair_provenance.py

116 lines
3.8 KiB
Python

#!/usr/bin/env python3
"""
Training Pair Provenance Tracking
Tracks the origin, model, and quality metadata for each training pair.
Integrates with ingest_trajectories.py and build_curated.py.
"""
import json
import time
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional
@dataclass
class ProvenanceMetadata:
    """Provenance record for one training pair: origin, model, and quality/exclusion status.

    Serialize with :meth:`to_dict` before storing it on a pair.
    """

    # Identifier of the session the pair was derived from.
    source_session_id: str
    # Origin category: "trajectory", "curated", or "augmented".
    source_type: str
    # Model identifier associated with the pair.
    model: str
    # Timestamp string for this record; the format is chosen by whoever builds it.
    timestamp: str
    # Optional quality score; dropped from to_dict() output while it is None.
    quality_score: Optional[float] = None
    # Exclusion flag; never None, so to_dict() always includes it.
    excluded: bool = False
    # Optional reason for the exclusion; dropped from to_dict() output while it is None.
    exclusion_reason: Optional[str] = None

    def to_dict(self) -> dict:
        """Return the fields as a dict in declaration order, omitting ``None`` values."""
        return asdict(
            self,
            dict_factory=lambda pairs: {name: value for name, value in pairs if value is not None},
        )
def add_provenance(pair, source_session_id, source_type, model, quality_score=None):
    """Merge a fresh provenance record into *pair* in place and return *pair*.

    The record is stamped with the current local time (``%Y-%m-%dT%H:%M:%S``, no
    timezone) and serialized with ``ProvenanceMetadata.to_dict``. ``None``-valued
    fields (e.g. an unset ``quality_score``) are therefore left out and do not
    overwrite values already stored on the pair.

    Args:
        pair: Training-pair dict; mutated in place.
        source_session_id: Identifier of the originating session.
        source_type: Expected to be ``"trajectory"``, ``"curated"`` or
            ``"augmented"`` (not checked here).
        model: Model identifier to record.
        quality_score: Optional quality score; omitted when ``None``.

    Returns:
        The same ``pair`` object, for chaining.
    """
    fresh = ProvenanceMetadata(
        source_session_id=source_session_id,
        source_type=source_type,
        model=model,
        timestamp=time.strftime("%Y-%m-%dT%H:%M:%S"),
        quality_score=quality_score,
    ).to_dict()

    existing = pair.get("provenance")
    if existing is None:
        # Covers both an absent key and an explicit ``"provenance": None`` slot.
        existing = {}
        pair["provenance"] = existing

    # The fresh record always carries excluded=False because nothing here can set it.
    # Merging that over an earlier exclusion would silently re-admit the pair and
    # leave a stale exclusion_reason behind, so keep the existing flag.
    if existing.get("excluded"):
        fresh.pop("excluded", None)

    existing.update(fresh)
    return pair
def extract_provenance_from_trajectory(trajectory):
    """Build provenance fields for a pair derived from *trajectory*.

    Args:
        trajectory: Mapping loaded from a trajectory file. The optional keys
            ``"id"``, ``"model"`` and ``"started_at"`` are read.

    Returns:
        Dict with ``source_session_id`` (``"unknown"`` if unavailable),
        ``source_type`` (always ``"trajectory"``), ``model`` (``"unknown"`` if
        unavailable) and ``timestamp``. The timestamp falls back to the current
        local time as ``%Y-%m-%dT%H:%M:%S`` when ``started_at`` is unavailable.
        A key counts as unavailable when it is absent, ``None`` or ``""``.
    """

    def pick(key):
        # Treat explicit null/empty values like missing keys, so the defaults below
        # apply instead of leaking None/"" into fields that validation requires.
        value = trajectory.get(key)
        return None if value is None or value == "" else value

    session_id = pick("id")
    model = pick("model")
    timestamp = pick("started_at")
    if timestamp is None:
        # Computed only when needed (the original formatted it on every call).
        timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")

    return {
        "source_session_id": "unknown" if session_id is None else session_id,
        "source_type": "trajectory",
        "model": "unknown" if model is None else model,
        "timestamp": timestamp,
    }
def validate_provenance(pair):
    """Check that *pair* carries complete, well-formed provenance metadata.

    Args:
        pair: Training-pair dict; its ``"provenance"`` entry is inspected.

    Returns:
        ``(is_valid, errors)``. ``errors`` holds one human-readable message per
        problem found, and ``is_valid`` is ``True`` exactly when it is empty.
        Each problem is reported once: a missing or empty ``source_type`` is not
        also reported as invalid.
    """
    prov = pair.get("provenance")
    if prov is None:
        return False, ["Missing provenance metadata"]
    if not isinstance(prov, dict):
        # Without this guard, `in` would do substring/element checks on a str/list
        # and `.get` would raise on anything that is not a mapping.
        return False, [f"Malformed provenance metadata: expected dict, got {type(prov).__name__}"]

    errors = []
    for name in ("source_session_id", "source_type", "model", "timestamp"):
        if name not in prov:
            errors.append(f"Missing required field: {name}")
        elif not prov[name]:
            errors.append(f"Empty required field: {name}")

    source_type = prov.get("source_type")
    # Judge the value only when one is present; absence/emptiness was reported above.
    # A tuple membership test uses ==, so unhashable values cannot raise TypeError.
    if source_type and source_type not in ("trajectory", "curated", "augmented"):
        errors.append(f"Invalid source_type: {source_type}")

    return not errors, errors
def get_provenance_stats(pairs):
    """Compute statistics about provenance coverage across *pairs*.

    Args:
        pairs: Iterable of training-pair dicts. Any iterable works, including
            one-shot iterables such as generators.

    Returns:
        Dict with these keys:
            ``total_pairs``: number of pairs seen.
            ``with_provenance``: pairs whose ``"provenance"`` value is a dict.
                ``None`` or non-dict values count as uncovered.
            ``by_source_type`` / ``by_model``: counts per value. Missing,
                ``None`` or ``""`` values are bucketed as ``"unknown"``.
            ``excluded``: covered pairs whose provenance has a truthy ``"excluded"``.
            ``coverage_pct``: ``with_provenance`` as a percentage of
                ``total_pairs``, rounded to one decimal; 0.0 when there are no pairs.
    """
    total = 0
    covered = 0
    excluded = 0
    by_source_type = {}
    by_model = {}

    for pair in pairs:
        total += 1
        prov = pair.get("provenance")
        # A None/non-dict provenance carries no usable metadata; count it as
        # uncovered instead of crashing on prov.get below.
        if not isinstance(prov, dict):
            continue
        covered += 1

        # `or "unknown"` also folds explicit None/"" into one bucket, so a None key
        # never ends up next to str keys (which would break sorting when reporting).
        source_type = prov.get("source_type") or "unknown"
        by_source_type[source_type] = by_source_type.get(source_type, 0) + 1
        model = prov.get("model") or "unknown"
        by_model[model] = by_model.get(model, 0) + 1

        if prov.get("excluded"):
            excluded += 1

    return {
        "total_pairs": total,
        "with_provenance": covered,
        "by_source_type": by_source_type,
        "by_model": by_model,
        "excluded": excluded,
        "coverage_pct": round(covered / total * 100, 1) if total else 0.0,
    }
def print_provenance_report(stats):
    """Print a human-readable provenance report to stdout.

    Args:
        stats: Dict shaped like the result of ``get_provenance_stats``. The keys
            ``total_pairs``, ``with_provenance``, ``coverage_pct``, ``excluded``,
            ``by_source_type`` and ``by_model`` are read.
    """
    print("Provenance Report")
    print("=" * 50)
    print(f"Total pairs: {stats['total_pairs']}")
    print(f"With provenance: {stats['with_provenance']}")
    print(f"Coverage: {stats['coverage_pct']}%")
    print(f"Excluded: {stats['excluded']}")
    for title, key in (("By source type:", "by_source_type"), ("By model:", "by_model")):
        print()
        print(title)
        # Sort on the key's string form: identical order for all-str keys, but mixed
        # key types (e.g. a None bucket next to str ones) no longer raise TypeError.
        for label, count in sorted(stats[key].items(), key=lambda item: str(item[0])):
            print(f" {label}: {count}")