Merge PR #760: training/training_pair_provenance.py (added)

This commit is contained in:
Merge Bot
2026-04-16 05:07:40 +00:00
parent a7ba856524
commit 3d62df6b15

View File

@@ -1,115 +1,397 @@
#!/usr/bin/env python3
"""
training_pair_provenance.py — Provenance tracking for training data pairs.

Tracks the origin, model, and quality metadata for each training pair.
Integrates with ingest_trajectories.py and build_curated.py.
Every training pair should carry metadata about where it came from:
- Which session/trajectory produced it
- Which model generated it
- When it was created
- What source type (curated, trajectory, augmentation)
This module provides utilities to:
1. Attach provenance metadata to training pairs
2. Validate that provenance exists
3. Generate provenance statistics/dashboards
4. Backfill provenance on existing pairs
Usage:
from training_pair_provenance import attach_provenance, validate_provenance, provenance_dashboard
# Attach provenance to a pair
pair = attach_provenance(pair, source="trajectory", session_id="abc123", model="hermes3:latest")
# Validate provenance on a dataset
report = validate_provenance("data/curated_dataset.jsonl")
# Generate dashboard
print(provenance_dashboard("data/merged_training_data.jsonl"))
"""
import json
import time
import hashlib
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional
from collections import Counter
from datetime import datetime, timezone
@dataclass
class ProvenanceMetadata:
    """Legacy metadata record for a training pair's provenance.

    NOTE(review): uses the old ``source_type`` schema; the dict-based
    make_provenance() schema below is the current one. Kept for backward
    compatibility with older callers — confirm whether it can be removed.
    """

    source_session_id: str          # ID of the session/trajectory that produced the pair
    source_type: str                # "trajectory", "curated", "augmented" (legacy schema)
    model: str                      # model that generated the content
    timestamp: str                  # ISO8601 creation time
    quality_score: Optional[float] = None
    excluded: bool = False
    exclusion_reason: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize to a dict, dropping None-valued optional fields."""
        return {k: v for k, v in asdict(self).items() if v is not None}


# === Required provenance fields ===
REQUIRED_FIELDS = ["source", "source_session_id", "model", "timestamp"]

# === Valid source types ===
VALID_SOURCES = {"curated", "trajectory", "augmentation", "backfill", "manual"}
def make_provenance(
    source: str,
    source_session_id: str,
    model: str,
    timestamp: Optional[str] = None,
    extras: Optional[dict] = None,
) -> dict:
    """Build a provenance metadata dict for a training pair.

    Args:
        source: One of curated, trajectory, augmentation, backfill, manual
        source_session_id: Unique ID of the source session/trajectory
        model: Model that generated the content
        timestamp: ISO8601 timestamp (defaults to the current UTC time)
        extras: Optional additional metadata merged into the result

    Returns:
        Provenance dict ready to attach to a training pair

    Raises:
        ValueError: If ``source`` is not a recognized source type.
    """
    if source not in VALID_SOURCES:
        raise ValueError(f"Invalid source '{source}'. Must be one of: {VALID_SOURCES}")
    when = timestamp or datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    record = {
        "source": source,
        "source_session_id": source_session_id,
        "model": model,
        "timestamp": when,
    }
    record.update(extras or {})
    return record
def attach_provenance(
    pair: dict,
    source: str,
    source_session_id: str,
    model: str,
    timestamp: Optional[str] = None,
    extras: Optional[dict] = None,
) -> dict:
    """Attach provenance metadata to a training pair (mutates and returns).

    The pair dict gets a 'provenance' key added. If provenance already exists,
    it is NOT overwritten — pass extras={"force": True} to override.

    Args:
        pair: Training pair dict (ShareGPT format)
        source: Source type
        source_session_id: Session/trajectory ID
        model: Model name
        timestamp: ISO8601 timestamp
        extras: Additional metadata

    Returns:
        The pair dict with provenance attached
    """
    # Respect existing provenance unless the caller explicitly forces overwrite.
    if "provenance" in pair and not (extras and extras.get("force")):
        return pair
    # 'force' is a control flag only — strip it before building the record.
    clean_extras = {k: v for k, v in (extras or {}).items() if k != "force"} or None
    pair["provenance"] = make_provenance(
        source=source,
        source_session_id=source_session_id,
        model=model,
        timestamp=timestamp,
        extras=clean_extras,
    )
    return pair
def extract_trajectory_provenance(trajectory_entry: dict) -> dict:
    """Extract provenance metadata from a trajectory JSONL entry.

    Trajectory entries may carry fields like:
      - id / session_id
      - model
      - started_at / timestamp / created_at

    Returns:
        Dict with source_session_id, model, and timestamp; each field falls
        back to "unknown" (or the current UTC time for the timestamp) when
        the entry does not provide it.
    """
    return {
        "source_session_id": (
            trajectory_entry.get("id")
            or trajectory_entry.get("session_id")
            or "unknown"
        ),
        "model": trajectory_entry.get("model", "unknown"),
        "timestamp": (
            trajectory_entry.get("started_at")
            or trajectory_entry.get("timestamp")
            or trajectory_entry.get("created_at")
            or datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
        ),
    }
def validate_provenance(pair):
    """Check a single pair for complete legacy-schema provenance metadata.

    NOTE(review): legacy per-pair validator (source_type-based schema). It is
    shadowed at import time by the file-level validate_provenance defined
    later in this module — confirm whether it should be removed or renamed.

    Returns:
        (is_valid, errors): bool plus a list of human-readable issue strings.
    """
    if "provenance" not in pair:
        return False, ["Missing provenance metadata"]
    prov = pair["provenance"]
    errors = []
    for field in ("source_session_id", "source_type", "model", "timestamp"):
        if field not in prov:
            errors.append(f"Missing required field: {field}")
        elif not prov[field]:
            errors.append(f"Empty required field: {field}")
    if prov.get("source_type") not in {"trajectory", "curated", "augmented"}:
        errors.append(f"Invalid source_type: {prov.get('source_type')}")
    return not errors, errors
def pair_fingerprint(pair: dict) -> str:
    """Return a stable 16-hex-char fingerprint for a training pair.

    Hashes only the conversation content ("from:value" turns joined with
    "|"), so identical content yields the same fingerprint regardless of
    provenance metadata. Useful for deduplication and tracking.
    """
    parts = [
        f"{turn.get('from', '')}:{turn.get('value', '')}"
        for turn in pair.get("conversations", [])
        if turn.get("from") != "system"  # system prompt deliberately excluded
    ]
    digest = hashlib.sha256("|".join(parts).encode()).hexdigest()
    return digest[:16]
def load_jsonl(path) -> list[dict]:
    """Load a JSONL file into a list of dicts.

    Blank lines are skipped; every other line must be valid JSON.

    Args:
        path: Path or string to the JSONL file.

    Returns:
        List of parsed JSON objects, one per non-empty line.
    """
    entries = []
    with open(Path(path)) as f:
        for line in f:
            line = line.strip()
            if line:
                entries.append(json.loads(line))
    return entries
def save_jsonl(path, entries: list[dict]):
    """Write entries to a JSONL file, creating parent directories as needed."""
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    lines = [json.dumps(entry) + "\n" for entry in entries]
    with open(destination, "w") as f:
        f.writelines(lines)
def validate_provenance(path) -> dict:
    """Validate provenance metadata on all pairs in a JSONL file.

    Args:
        path: Path to a JSONL dataset of training pairs.

    Returns a report dict with:
        - total: total pairs
        - with_provenance: pairs that have provenance
        - missing_provenance: pairs without provenance
        - missing_fields: pairs with provenance but missing required fields
        - invalid_source: pairs with unrecognized source type
        - coverage: percentage of pairs carrying provenance
        - issues: list of specific issue descriptions
    When the file does not exist, returns {"error": ..., "total": 0} instead.
    """
    path = Path(path)
    if not path.exists():
        return {"error": f"File not found: {path}", "total": 0}
    entries = load_jsonl(path)
    report = {
        "total": len(entries),
        "with_provenance": 0,
        "missing_provenance": 0,
        "missing_fields": 0,
        "invalid_source": 0,
        "issues": [],
    }
    for i, entry in enumerate(entries):
        prov = entry.get("provenance")
        if not prov:
            report["missing_provenance"] += 1
            report["issues"].append(f"Pair {i} (id={entry.get('id', '?')}): no provenance")
            continue
        report["with_provenance"] += 1
        # Check required fields
        missing = [f for f in REQUIRED_FIELDS if f not in prov]
        if missing:
            report["missing_fields"] += 1
            report["issues"].append(
                f"Pair {i} (id={entry.get('id', '?')}): missing fields: {missing}"
            )
        # Check source validity (absent source is reported above as a missing field)
        source = prov.get("source", "")
        if source and source not in VALID_SOURCES:
            report["invalid_source"] += 1
            report["issues"].append(
                f"Pair {i} (id={entry.get('id', '?')}): invalid source '{source}'"
            )
    report["coverage"] = (
        report["with_provenance"] / report["total"] * 100 if report["total"] > 0 else 0
    )
    return report
def provenance_dashboard(path) -> str:
    """Render a human-readable provenance dashboard for a JSONL dataset.

    Sections: total/coverage summary, per-model counts, per-source counts,
    and (when any timestamps are present) the top-10 dates by pair count.
    """
    path = Path(path)
    if not path.exists():
        return f"File not found: {path}"
    entries = load_jsonl(path)
    if not entries:
        return "Empty dataset"

    model_counts = Counter()
    source_counts = Counter()
    dates = []
    covered = 0
    for entry in entries:
        prov = entry.get("provenance")
        if not prov:
            model_counts["(no provenance)"] += 1
            source_counts["(no provenance)"] += 1
            continue
        covered += 1
        model_counts[prov.get("model", "unknown")] += 1
        source_counts[prov.get("source", "unknown")] += 1
        ts = prov.get("timestamp", "")
        if ts:
            dates.append(ts[:10])  # keep the date portion only

    total = len(entries)
    coverage = covered / total * 100 if entries else 0
    lines = [
        "=" * 50,
        "PROVENANCE DASHBOARD",
        "=" * 50,
        f"Total pairs: {total}",
        f"Provenance coverage: {coverage:.1f}% ({covered}/{total})",
        "",
        "--- By Model ---",
    ]
    for model, count in model_counts.most_common():
        lines.append(f"  {model:<30} {count:>6} ({count / total * 100:.1f}%)")
    lines.append("")
    lines.append("--- By Source ---")
    for source, count in source_counts.most_common():
        lines.append(f"  {source:<20} {count:>6} ({count / total * 100:.1f}%)")
    if dates:
        lines.append("")
        lines.append("--- By Date (top 10) ---")
        for date, count in Counter(dates).most_common(10):
            lines.append(f"  {date:<12} {count:>6}")
    return "\n".join(lines)
def backfill_provenance(
    path,
    source: str = "backfill",
    model: str = "unknown",
    output_path: Optional[str] = None,
) -> dict:
    """Attach provenance to every pair in a JSONL file that lacks it.

    Args:
        path: Input JSONL file
        source: Source type recorded on backfilled pairs
        model: Model name recorded on backfilled pairs
        output_path: Where to write the result (defaults to overwriting input)

    Returns:
        Stats dict with total / backfilled / already_had / output keys.
    """
    entries = load_jsonl(path)
    stats = {"total": len(entries), "backfilled": 0, "already_had": 0}
    for entry in entries:
        if "provenance" in entry:
            stats["already_had"] += 1
            continue
        # Pairs without an id get a synthetic, position-based session id.
        fallback_id = f"backfill-{stats['backfilled']}"
        entry["provenance"] = make_provenance(
            source=source,
            source_session_id=entry.get("id", fallback_id),
            model=model,
        )
        stats["backfilled"] += 1
    destination = Path(output_path) if output_path else Path(path)
    save_jsonl(destination, entries)
    stats["output"] = str(destination)
    return stats
def print_provenance_report(stats):
    """Pretty-print a legacy provenance stats dict to stdout.

    NOTE(review): expects the legacy stats schema (total_pairs, with_provenance,
    coverage_pct, excluded, by_source_type, by_model); no function in this
    module produces that dict anymore — confirm whether this is still called.
    """
    print("Provenance Report")
    print("=" * 50)
    for line in (
        f"Total pairs: {stats['total_pairs']}",
        f"With provenance: {stats['with_provenance']}",
        f"Coverage: {stats['coverage_pct']}%",
        f"Excluded: {stats['excluded']}",
    ):
        print(line)
    print()
    print("By source type:")
    for source_type, count in sorted(stats["by_source_type"].items()):
        print(f"  {source_type}: {count}")
    print()
    print("By model:")
    for model_name, count in sorted(stats["by_model"].items()):
        print(f"  {model_name}: {count}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Provenance tracking for training data")
    sub = parser.add_subparsers(dest="command")

    # validate: report provenance coverage/issues for a dataset
    p_validate = sub.add_parser("validate", help="Validate provenance in a dataset")
    p_validate.add_argument("input", help="Input JSONL file")
    p_validate.add_argument("--json", action="store_true", help="Output as JSON")

    # dashboard: human-readable summary
    p_dash = sub.add_parser("dashboard", help="Show provenance dashboard")
    p_dash.add_argument("input", help="Input JSONL file")

    # backfill: add provenance to pairs missing it
    p_back = sub.add_parser("backfill", help="Add provenance to pairs missing it")
    p_back.add_argument("input", help="Input JSONL file")
    p_back.add_argument("--source", default="backfill", help="Source type")
    p_back.add_argument("--model", default="unknown", help="Model name")
    p_back.add_argument("--output", "-o", help="Output path (default: overwrite)")

    args = parser.parse_args()

    if args.command == "validate":
        report = validate_provenance(args.input)
        if args.json:
            print(json.dumps(report, indent=2))
        elif "error" in report:
            # Missing-file report only has error/total keys — avoid KeyError.
            print(report["error"])
        else:
            print(f"Provenance Validation: {args.input}")
            print(f"  Total: {report['total']}")
            print(f"  With provenance: {report['with_provenance']}")
            print(f"  Missing provenance: {report['missing_provenance']}")
            print(f"  Missing fields: {report['missing_fields']}")
            print(f"  Invalid source: {report['invalid_source']}")
            print(f"  Coverage: {report.get('coverage', 0):.1f}%")
            if report["issues"]:
                print(f"\n  Issues ({len(report['issues'])}):")
                for issue in report["issues"][:20]:
                    print(f"    {issue}")
    elif args.command == "dashboard":
        print(provenance_dashboard(args.input))
    elif args.command == "backfill":
        stats = backfill_provenance(args.input, args.source, args.model, args.output)
        print("Backfill complete:")
        print(f"  Total: {stats['total']}")
        print(f"  Backfilled: {stats['backfilled']}")
        print(f"  Already had provenance: {stats['already_had']}")
        print(f"  Output: {stats['output']}")
    else:
        parser.print_help()