116 lines
3.8 KiB
Python
116 lines
3.8 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Training Pair Provenance Tracking
|
|
|
|
Tracks the origin, model, and quality metadata for each training pair.
|
|
Integrates with ingest_trajectories.py and build_curated.py.
|
|
"""
|
|
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
from dataclasses import dataclass, asdict
|
|
from typing import Optional
|
|
|
|
|
|
@dataclass
class ProvenanceMetadata:
    """Origin and quality record attached to a single training pair.

    Each attribute becomes a key under ``pair["provenance"]`` when the record
    is serialised with :meth:`to_dict`.
    """

    # Identifier of the session the pair was derived from.
    source_session_id: str
    # Origin category: "trajectory", "curated" or "augmented".
    source_type: str
    # Name of the model associated with the pair.
    model: str
    # Timestamp string recorded for this provenance entry.
    timestamp: str
    # Optional quality score; omitted from to_dict() output when None.
    quality_score: Optional[float] = None
    # Whether the pair has been excluded from training data.
    excluded: bool = False
    # Optional explanation for the exclusion; omitted from to_dict() when None.
    exclusion_reason: Optional[str] = None

    def to_dict(self) -> dict:
        """Return the record as a plain dict, skipping fields whose value is None."""
        serialised = {}
        for key, value in asdict(self).items():
            if value is None:
                continue
            serialised[key] = value
        return serialised
|
|
|
|
|
|
def add_provenance(pair, source_session_id, source_type, model, quality_score=None,
                   timestamp=None):
    """Attach provenance metadata to a training pair, in place.

    Args:
        pair: Training-pair dict. Its ``"provenance"`` sub-dict is created when
            absent (or null) and then updated with the new metadata.
        source_session_id: Identifier of the originating session.
        source_type: Origin category ("trajectory", "curated", "augmented").
        model: Name of the model associated with the pair.
        quality_score: Optional score. When None, any existing score under
            ``pair["provenance"]`` is left untouched.
        timestamp: Optional timestamp string. When omitted (or empty), the
            current local time formatted as ``%Y-%m-%dT%H:%M:%S`` is used.
            Accepting it lets the dict returned by
            ``extract_provenance_from_trajectory`` be passed in as keyword
            arguments.

    Returns:
        The same ``pair`` object, mutated in place.
    """
    provenance = ProvenanceMetadata(
        source_session_id=source_session_id,
        source_type=source_type,
        model=model,
        timestamp=timestamp or time.strftime("%Y-%m-%dT%H:%M:%S"),
        quality_score=quality_score,
    )

    # Create the container when it is missing, and also when it is present but
    # null; otherwise .update() below would raise AttributeError.
    if pair.get("provenance") is None:
        pair["provenance"] = {}
    existing = pair["provenance"]

    updates = provenance.to_dict()
    # A fresh ProvenanceMetadata always carries excluded=False, and this
    # function offers no way to request exclusion. Writing that default over an
    # earlier excluded=True would silently re-include the pair while leaving
    # its exclusion_reason behind. Keep the earlier decision instead, just as a
    # None quality_score keeps the earlier score.
    if existing.get("excluded"):
        updates.pop("excluded", None)
    existing.update(updates)
    return pair
|
|
|
|
|
|
def extract_provenance_from_trajectory(trajectory):
    """Build provenance fields from a parsed trajectory record.

    Args:
        trajectory: Trajectory mapping. The optional keys ``"id"``,
            ``"model"`` and ``"started_at"`` are read.

    Returns:
        Dict with ``source_session_id``, ``source_type`` (always
        ``"trajectory"``), ``model`` and ``timestamp``. A key that is missing,
        None or an empty string falls back to ``"unknown"`` for the id and
        model, and to the current local time (``%Y-%m-%dT%H:%M:%S``) for the
        timestamp.
    """

    def _get(key, default):
        # Treat an explicit null / empty string like a missing key, so a JSON
        # ``null`` does not leak through as an empty required field.
        value = trajectory.get(key)
        return default if value is None or value == "" else value

    timestamp = _get("started_at", None)
    if timestamp is None:
        # Format the current time only when it is actually needed.
        timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")

    return {
        "source_session_id": _get("id", "unknown"),
        "source_type": "trajectory",
        "model": _get("model", "unknown"),
        "timestamp": timestamp,
    }
|
|
|
|
|
|
def validate_provenance(pair):
    """Check that a training pair carries complete provenance metadata.

    Args:
        pair: Training-pair dict expected to hold a ``"provenance"`` dict.

    Returns:
        Tuple ``(is_valid, errors)``, where ``errors`` is a list of
        human-readable problem descriptions (empty when valid).
    """
    if "provenance" not in pair:
        return False, ["Missing provenance metadata"]

    prov = pair["provenance"]
    # Guard against e.g. a JSON null stored under "provenance"; the field
    # checks below would otherwise raise TypeError.
    if not isinstance(prov, dict):
        return False, [
            f"Invalid provenance metadata: expected dict, got {type(prov).__name__}"
        ]

    errors = []
    for name in ("source_session_id", "source_type", "model", "timestamp"):
        if name not in prov:
            errors.append(f"Missing required field: {name}")
        elif not prov[name]:
            errors.append(f"Empty required field: {name}")

    # The loop above already reports a missing/empty source_type; only judge
    # the value when one is actually present, so it is not reported twice.
    # Tuple membership (==) avoids a TypeError for unhashable values.
    source_type = prov.get("source_type")
    if source_type and source_type not in ("trajectory", "curated", "augmented"):
        errors.append(f"Invalid source_type: {source_type}")

    return not errors, errors
|
|
|
|
|
|
def get_provenance_stats(pairs):
    """Summarise provenance coverage across a collection of training pairs.

    Args:
        pairs: Sized collection of training-pair dicts (``len()`` is taken,
            then it is iterated once).

    Returns:
        Dict with ``total_pairs``, ``with_provenance``, ``by_source_type`` and
        ``by_model`` (counts keyed by value, ``"unknown"`` when the value is
        missing/empty), ``excluded``, and ``coverage_pct`` (percentage of pairs
        with a provenance dict, rounded to one decimal place; 0.0 when empty).
    """
    stats = {
        "total_pairs": len(pairs),
        "with_provenance": 0,
        "by_source_type": {},
        "by_model": {},
        "excluded": 0,
        "coverage_pct": 0.0,
    }
    by_source_type = stats["by_source_type"]
    by_model = stats["by_model"]

    for pair in pairs:
        if "provenance" not in pair:
            continue
        prov = pair["provenance"]
        # Malformed (e.g. null) provenance would crash on .get(); it also
        # carries no information, so it does not count towards coverage.
        if not isinstance(prov, dict):
            continue

        stats["with_provenance"] += 1
        # Fall back on null/empty values too, not just missing keys: a None key
        # cannot be sorted alongside string keys (print_provenance_report
        # sorts these dicts), which would raise TypeError.
        source_type = prov.get("source_type") or "unknown"
        model = prov.get("model") or "unknown"
        by_source_type[source_type] = by_source_type.get(source_type, 0) + 1
        by_model[model] = by_model.get(model, 0) + 1
        if prov.get("excluded"):
            stats["excluded"] += 1

    if stats["total_pairs"] > 0:
        stats["coverage_pct"] = round(stats["with_provenance"] / stats["total_pairs"] * 100, 1)

    return stats
|
|
|
|
|
|
def print_provenance_report(stats):
    """Write a plain-text summary of provenance statistics to stdout.

    Args:
        stats: Dict shaped like the output of ``get_provenance_stats``; the
            ``total_pairs``, ``with_provenance``, ``coverage_pct``,
            ``excluded``, ``by_source_type`` and ``by_model`` keys are read.
    """
    print("Provenance Report")
    print("=" * 50)

    # (label, stats key, suffix) for the scalar summary lines, in print order.
    summary_rows = (
        ("Total pairs", "total_pairs", ""),
        ("With provenance", "with_provenance", ""),
        ("Coverage", "coverage_pct", "%"),
        ("Excluded", "excluded", ""),
    )
    for label, key, suffix in summary_rows:
        print(f"{label}: {stats[key]}{suffix}")

    # Each breakdown section is preceded by a blank line; entries are sorted.
    sections = (("By source type:", "by_source_type"), ("By model:", "by_model"))
    for heading, key in sections:
        print()
        print(heading)
        for name, tally in sorted(stats[key].items()):
            print(f" {name}: {tally}")
|