#!/usr/bin/env python3
"""
training_pair_provenance.py — Provenance tracking for training data pairs.

Every training pair should carry metadata about where it came from:
- Which session/trajectory produced it
- Which model generated it
- When it was created
- What source type (curated, trajectory, augmentation)

This module provides utilities to:
1. Attach provenance metadata to training pairs
2. Validate that provenance exists
3. Generate provenance statistics/dashboards
4. Backfill provenance on existing pairs

Usage:
    from training_pair_provenance import attach_provenance, validate_provenance, provenance_dashboard

    # Attach provenance to a pair
    pair = attach_provenance(pair, source="trajectory", source_session_id="abc123", model="hermes3:latest")

    # Validate provenance on a dataset
    report = validate_provenance("data/curated_dataset.jsonl")

    # Generate dashboard
    print(provenance_dashboard("data/merged_training_data.jsonl"))
"""

import json
import time
import hashlib
from pathlib import Path
from typing import Optional
from collections import Counter
from datetime import datetime, timezone

# === Required provenance fields ===
REQUIRED_FIELDS = ["source", "source_session_id", "model", "timestamp"]

# === Valid source types ===
VALID_SOURCES = {"curated", "trajectory", "augmentation", "backfill", "manual"}


def _utc_now_iso() -> str:
    """Current UTC time as an ISO8601 'Z'-suffixed timestamp (second precision)."""
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


def make_provenance(
    source: str,
    source_session_id: str,
    model: str,
    timestamp: Optional[str] = None,
    extras: Optional[dict] = None,
) -> dict:
    """Create a provenance metadata dict.

    Args:
        source: One of curated, trajectory, augmentation, backfill, manual
        source_session_id: Unique ID of the source session/trajectory
        model: Model that generated the content
        timestamp: ISO8601 timestamp (defaults to now, UTC)
        extras: Optional additional metadata merged into the result

    Returns:
        Provenance dict ready to attach to a training pair

    Raises:
        ValueError: if ``source`` is not one of VALID_SOURCES
    """
    if source not in VALID_SOURCES:
        raise ValueError(f"Invalid source '{source}'. Must be one of: {VALID_SOURCES}")

    prov = {
        "source": source,
        "source_session_id": source_session_id,
        "model": model,
        "timestamp": timestamp or _utc_now_iso(),
    }

    if extras:
        prov.update(extras)

    return prov


def attach_provenance(
    pair: dict,
    source: str,
    source_session_id: str,
    model: str,
    timestamp: Optional[str] = None,
    extras: Optional[dict] = None,
) -> dict:
    """Attach provenance metadata to a training pair (mutates and returns).

    The pair dict gets a 'provenance' key added. If provenance already exists,
    it is NOT overwritten — pass ``extras={"force": True}`` to override.

    Args:
        pair: Training pair dict (ShareGPT format)
        source: Source type
        source_session_id: Session/trajectory ID
        model: Model name
        timestamp: ISO8601 timestamp
        extras: Additional metadata; the 'force' key is consumed here and
            never stored on the pair

    Returns:
        The pair dict with provenance attached
    """
    if "provenance" in pair and not (extras and extras.get("force")):
        return pair

    # Pop 'force' flag before passing to make_provenance so it doesn't
    # leak into the stored metadata.
    clean_extras = {k: v for k, v in (extras or {}).items() if k != "force"} or None

    pair["provenance"] = make_provenance(
        source=source,
        source_session_id=source_session_id,
        model=model,
        timestamp=timestamp,
        extras=clean_extras,
    )
    return pair


def extract_trajectory_provenance(trajectory_entry: dict) -> dict:
    """Extract provenance metadata from a trajectory JSONL entry.

    Trajectory entries may have fields like:
    - id / session_id
    - model
    - started_at / timestamp / created_at

    Returns dict with extracted fields or sensible defaults ("unknown" for
    missing id/model, current UTC time for a missing timestamp).
    """
    return {
        "source_session_id": (
            trajectory_entry.get("id")
            or trajectory_entry.get("session_id")
            or "unknown"
        ),
        "model": trajectory_entry.get("model", "unknown"),
        "timestamp": (
            trajectory_entry.get("started_at")
            or trajectory_entry.get("timestamp")
            or trajectory_entry.get("created_at")
            or _utc_now_iso()
        ),
    }


def pair_fingerprint(pair: dict) -> str:
    """Generate a stable fingerprint for a training pair.

    Used for deduplication and tracking. Based on conversation content,
    not metadata (so same content = same hash regardless of provenance).
    System-role turns are excluded so the same exchange under different
    system prompts still dedupes to one fingerprint.

    Returns:
        First 16 hex chars of the SHA-256 of the non-system turns.
    """
    convos = pair.get("conversations", [])
    content_parts = []
    for c in convos:
        if c.get("from") != "system":  # Skip system prompt for fingerprint
            content_parts.append(f"{c.get('from', '')}:{c.get('value', '')}")
    content = "|".join(content_parts)
    return hashlib.sha256(content.encode()).hexdigest()[:16]


def load_jsonl(path) -> list[dict]:
    """Load a JSONL file, skipping blank lines.

    Uses UTF-8 explicitly so results don't depend on the platform's
    default encoding.
    """
    path = Path(path)
    entries = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                entries.append(json.loads(line))
    return entries


def save_jsonl(path, entries: list[dict]):
    """Save entries to a JSONL file, creating parent directories as needed.

    Writes UTF-8 with ensure_ascii=False so non-ASCII training text is
    stored readably instead of as \\uXXXX escapes.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")


def validate_provenance(path) -> dict:
    """Validate provenance metadata on all pairs in a JSONL file.

    Returns a report dict with:
    - total: total pairs
    - with_provenance: pairs that have provenance
    - missing_provenance: pairs without provenance
    - missing_fields: pairs with provenance but missing required fields
    - invalid_source: pairs with unrecognized source type
    - coverage: percentage of pairs with provenance
    - issues: list of specific issue descriptions

    If the file does not exist, returns {"error": ..., "total": 0} instead.
    """
    path = Path(path)
    if not path.exists():
        return {"error": f"File not found: {path}", "total": 0}

    entries = load_jsonl(path)
    report = {
        "total": len(entries),
        "with_provenance": 0,
        "missing_provenance": 0,
        "missing_fields": 0,
        "invalid_source": 0,
        "issues": [],
    }

    for i, entry in enumerate(entries):
        prov = entry.get("provenance")
        if not prov:
            report["missing_provenance"] += 1
            report["issues"].append(f"Pair {i} (id={entry.get('id', '?')}): no provenance")
            continue

        report["with_provenance"] += 1

        # Check required fields
        missing = [f for f in REQUIRED_FIELDS if f not in prov]
        if missing:
            report["missing_fields"] += 1
            report["issues"].append(
                f"Pair {i} (id={entry.get('id', '?')}): missing fields: {missing}"
            )

        # Check source validity
        source = prov.get("source", "")
        if source and source not in VALID_SOURCES:
            report["invalid_source"] += 1
            report["issues"].append(
                f"Pair {i} (id={entry.get('id', '?')}): invalid source '{source}'"
            )

    report["coverage"] = (
        report["with_provenance"] / report["total"] * 100 if report["total"] > 0 else 0
    )
    return report


def provenance_dashboard(path) -> str:
    """Generate a human-readable provenance dashboard for a dataset.

    Shows:
    - Pair count by model
    - Pair count by source type
    - Provenance coverage
    - Pair count by date (top 10)
    """
    path = Path(path)
    if not path.exists():
        return f"File not found: {path}"

    entries = load_jsonl(path)
    if not entries:
        return "Empty dataset"

    models = Counter()
    sources = Counter()
    timestamps = []
    with_prov = 0

    for entry in entries:
        prov = entry.get("provenance")
        if prov:
            with_prov += 1
            models[prov.get("model", "unknown")] += 1
            sources[prov.get("source", "unknown")] += 1
            ts = prov.get("timestamp", "")
            if ts:
                timestamps.append(ts[:10])  # Date only
        else:
            models["(no provenance)"] += 1
            sources["(no provenance)"] += 1

    coverage = with_prov / len(entries) * 100 if entries else 0

    lines = [
        "=" * 50,
        "PROVENANCE DASHBOARD",
        "=" * 50,
        f"Total pairs: {len(entries)}",
        f"Provenance coverage: {coverage:.1f}% ({with_prov}/{len(entries)})",
        "",
        "--- By Model ---",
    ]
    for model, count in models.most_common():
        pct = count / len(entries) * 100
        lines.append(f"  {model:<30} {count:>6}  ({pct:.1f}%)")

    lines.append("")
    lines.append("--- By Source ---")
    for source, count in sources.most_common():
        pct = count / len(entries) * 100
        lines.append(f"  {source:<20} {count:>6}  ({pct:.1f}%)")

    if timestamps:
        dates = Counter(timestamps)
        lines.append("")
        lines.append("--- By Date (top 10) ---")
        for date, count in dates.most_common(10):
            lines.append(f"  {date:<12} {count:>6}")

    return "\n".join(lines)


def backfill_provenance(
    path,
    source: str = "backfill",
    model: str = "unknown",
    output_path: Optional[str] = None,
) -> dict:
    """Add provenance to all pairs missing it.

    Args:
        path: Input JSONL file
        source: Source type to use for backfilled pairs
        model: Model name to use for backfilled pairs
        output_path: Output path (defaults to overwriting input)

    Returns:
        Stats dict with total/backfilled/already_had counts and output path
    """
    entries = load_jsonl(path)
    stats = {"total": len(entries), "backfilled": 0, "already_had": 0}

    for entry in entries:
        if "provenance" not in entry:
            # Synthesize a session id when the pair has none.
            session_id = entry.get("id", f"backfill-{stats['backfilled']}")
            entry["provenance"] = make_provenance(
                source=source,
                source_session_id=session_id,
                model=model,
            )
            stats["backfilled"] += 1
        else:
            stats["already_had"] += 1

    out = Path(output_path) if output_path else Path(path)
    save_jsonl(out, entries)
    stats["output"] = str(out)
    return stats


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Provenance tracking for training data")
    sub = parser.add_subparsers(dest="command")

    # validate
    p_validate = sub.add_parser("validate", help="Validate provenance in a dataset")
    p_validate.add_argument("input", help="Input JSONL file")
    p_validate.add_argument("--json", action="store_true", help="Output as JSON")

    # dashboard
    p_dash = sub.add_parser("dashboard", help="Show provenance dashboard")
    p_dash.add_argument("input", help="Input JSONL file")

    # backfill
    p_back = sub.add_parser("backfill", help="Add provenance to pairs missing it")
    p_back.add_argument("input", help="Input JSONL file")
    p_back.add_argument("--source", default="backfill", help="Source type")
    p_back.add_argument("--model", default="unknown", help="Model name")
    p_back.add_argument("--output", "-o", help="Output path (default: overwrite)")

    args = parser.parse_args()

    if args.command == "validate":
        report = validate_provenance(args.input)
        if args.json:
            print(json.dumps(report, indent=2))
        elif "error" in report:
            # BUG FIX: the file-not-found report lacks the per-field counters,
            # so the pretty-printer below would raise KeyError.
            print(report["error"])
        else:
            print(f"Provenance Validation: {args.input}")
            print(f"  Total: {report['total']}")
            print(f"  With provenance: {report['with_provenance']}")
            print(f"  Missing provenance: {report['missing_provenance']}")
            print(f"  Missing fields: {report['missing_fields']}")
            print(f"  Invalid source: {report['invalid_source']}")
            print(f"  Coverage: {report.get('coverage', 0):.1f}%")
            if report["issues"]:
                print(f"\n  Issues ({len(report['issues'])}):")
                for issue in report["issues"][:20]:
                    print(f"    {issue}")

    elif args.command == "dashboard":
        print(provenance_dashboard(args.input))

    elif args.command == "backfill":
        stats = backfill_provenance(args.input, args.source, args.model, args.output)
        print("Backfill complete:")
        print(f"  Total: {stats['total']}")
        print(f"  Backfilled: {stats['backfilled']}")
        print(f"  Already had provenance: {stats['already_had']}")
        print(f"  Output: {stats['output']}")

    else:
        parser.print_help()