#!/usr/bin/env python3 """ training_pair_provenance.py — Provenance tracking for training data pairs. Every training pair should carry metadata about where it came from: - Which session/trajectory produced it - Which model generated it - When it was created - What source type (curated, trajectory, augmentation) This module provides utilities to: 1. Attach provenance metadata to training pairs 2. Validate that provenance exists 3. Generate provenance statistics/dashboards 4. Backfill provenance on existing pairs Usage: from training_pair_provenance import attach_provenance, validate_provenance, provenance_dashboard # Attach provenance to a pair pair = attach_provenance(pair, source="trajectory", session_id="abc123", model="hermes3:latest") # Validate provenance on a dataset report = validate_provenance("data/curated_dataset.jsonl") # Generate dashboard print(provenance_dashboard("data/merged_training_data.jsonl")) """ import json import time import hashlib from pathlib import Path from typing import Optional from collections import Counter from datetime import datetime, timezone # === Required provenance fields === REQUIRED_FIELDS = ["source", "source_session_id", "model", "timestamp"] # === Valid source types === VALID_SOURCES = {"curated", "trajectory", "augmentation", "backfill", "manual", "unknown"} def make_provenance( source: str, source_session_id: str, model: str, timestamp: Optional[str] = None, extras: Optional[dict] = None, ) -> dict: """Create a provenance metadata dict. Args: source: One of curated, trajectory, augmentation, backfill, manual source_session_id: Unique ID of the source session/trajectory model: Model that generated the content timestamp: ISO8601 timestamp (defaults to now) extras: Optional additional metadata Returns: Provenance dict ready to attach to a training pair """ if source not in VALID_SOURCES: raise ValueError(f"Invalid source '{source}'. Must be one of: {VALID_SOURCES}") prov = { "source": source, "source_session_id": source_session_id, "model": model, "timestamp": timestamp or datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), } if extras: prov.update(extras) return prov def attach_provenance( pair: dict, source: str, source_session_id: str, model: str, timestamp: Optional[str] = None, extras: Optional[dict] = None, ) -> dict: """Attach provenance metadata to a training pair (mutates and returns). The pair dict gets a 'provenance' key added. If provenance already exists, it is NOT overwritten — use force=True in the extras to override. Args: pair: Training pair dict (ShareGPT format) source: Source type source_session_id: Session/trajectory ID model: Model name timestamp: ISO8601 timestamp extras: Additional metadata Returns: The pair dict with provenance attached """ if "provenance" in pair and not (extras and extras.get("force")): return pair # Pop 'force' flag before passing to make_provenance clean_extras = {k: v for k, v in (extras or {}).items() if k != "force"} or None pair["provenance"] = make_provenance( source=source, source_session_id=source_session_id, model=model, timestamp=timestamp, extras=clean_extras, ) return pair def extract_trajectory_provenance(trajectory_entry: dict) -> dict: """Extract provenance metadata from a trajectory JSONL entry. Trajectory entries may have fields like: - id / session_id - model - started_at / timestamp - source file path Returns dict with extracted fields or sensible defaults. """ return { "source_session_id": ( trajectory_entry.get("id") or trajectory_entry.get("session_id") or "unknown" ), "model": trajectory_entry.get("model", "unknown"), "timestamp": ( trajectory_entry.get("started_at") or trajectory_entry.get("timestamp") or trajectory_entry.get("created_at") or datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") ), } def pair_fingerprint(pair: dict) -> str: """Generate a stable fingerprint for a training pair. Used for deduplication and tracking. Based on conversation content, not metadata (so same content = same hash regardless of provenance). """ convos = pair.get("conversations", []) content_parts = [] for c in convos: if c.get("from") != "system": # Skip system prompt for fingerprint content_parts.append(f"{c.get('from', '')}:{c.get('value', '')}") content = "|".join(content_parts) return hashlib.sha256(content.encode()).hexdigest()[:16] def load_jsonl(path) -> list[dict]: """Load a JSONL file.""" path = Path(path) entries = [] with open(path) as f: for line in f: line = line.strip() if line: entries.append(json.loads(line)) return entries def save_jsonl(entries_or_path, path_or_entries=None): """Save entries to a JSONL file. Accepts (path, entries) or (entries, path).""" if isinstance(entries_or_path, (list, tuple)) and path_or_entries is not None: entries = entries_or_path path = Path(path_or_entries) elif isinstance(path_or_entries, (list, tuple)): entries = path_or_entries path = Path(entries_or_path) else: entries = entries_or_path if isinstance(entries_or_path, list) else [] path = Path(path_or_entries) if path_or_entries else Path(entries_or_path) path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) with open(path, "w") as f: for entry in entries: f.write(json.dumps(entry) + "\n") def validate_provenance(path) -> dict: """Validate provenance metadata on all pairs in a JSONL file. Returns a report dict with: - total: total pairs - with_provenance: pairs that have provenance - missing_provenance: pairs without provenance - missing_fields: pairs with provenance but missing required fields - invalid_source: pairs with unrecognized source type - issues: list of specific issue descriptions """ path = Path(path) if not path.exists(): return {"error": f"File not found: {path}", "total": 0} entries = load_jsonl(path) report = { "total": len(entries), "with_provenance": 0, "missing_provenance": 0, "missing_fields": 0, "invalid_source": 0, "issues": [], } for i, entry in enumerate(entries): prov = entry.get("provenance") if not prov: report["missing_provenance"] += 1 report["issues"].append(f"Pair {i} (id={entry.get('id', '?')}): no provenance") continue report["with_provenance"] += 1 # Check required fields missing = [f for f in REQUIRED_FIELDS if f not in prov] if missing: report["missing_fields"] += 1 report["issues"].append( f"Pair {i} (id={entry.get('id', '?')}): missing fields: {missing}" ) # Check source validity source = prov.get("source", "") if source and source not in VALID_SOURCES: report["invalid_source"] += 1 report["issues"].append( f"Pair {i} (id={entry.get('id', '?')}): invalid source '{source}'" ) report["coverage"] = ( report["with_provenance"] / report["total"] * 100 if report["total"] > 0 else 0 ) return report def provenance_dashboard(path) -> str: """Generate a human-readable provenance dashboard for a dataset. Shows: - Pair count by model over time - Pair count by source type - Provenance coverage - Model distribution """ path = Path(path) if not path.exists(): return f"File not found: {path}" entries = load_jsonl(path) if not entries: return "Empty dataset" models = Counter() sources = Counter() timestamps = [] with_prov = 0 for entry in entries: prov = entry.get("provenance") if prov: with_prov += 1 models[prov.get("model", "unknown")] += 1 sources[prov.get("source", "unknown")] += 1 ts = prov.get("timestamp", "") if ts: timestamps.append(ts[:10]) # Date only else: models["(no provenance)"] += 1 sources["(no provenance)"] += 1 coverage = with_prov / len(entries) * 100 if entries else 0 lines = [ "=" * 50, "PROVENANCE DASHBOARD", "=" * 50, f"Total pairs: {len(entries)}", f"Provenance coverage: {coverage:.1f}% ({with_prov}/{len(entries)})", "", "--- By Model ---", ] for model, count in models.most_common(): pct = count / len(entries) * 100 lines.append(f" {model:<30} {count:>6} ({pct:.1f}%)") lines.append("") lines.append("--- By Source ---") for source, count in sources.most_common(): pct = count / len(entries) * 100 lines.append(f" {source:<20} {count:>6} ({pct:.1f}%)") if timestamps: dates = Counter(timestamps) lines.append("") lines.append("--- By Date (top 10) ---") for date, count in dates.most_common(10): lines.append(f" {date:<12} {count:>6}") return "\n".join(lines) def backfill_provenance( path, source: str = "backfill", model: str = "unknown", output_path: Optional[str] = None, ) -> dict: """Add provenance to all pairs missing it. Args: path: Input JSONL file source: Source type to use for backfilled pairs model: Model name to use for backfilled pairs output_path: Output path (defaults to overwriting input) Returns: Stats dict """ entries = load_jsonl(path) stats = {"total": len(entries), "backfilled": 0, "already_had": 0} for entry in entries: if "provenance" not in entry: session_id = entry.get("id", f"backfill-{stats['backfilled']}") entry["provenance"] = make_provenance( source=source, source_session_id=session_id, model=model, ) stats["backfilled"] += 1 else: stats["already_had"] += 1 out = Path(output_path) if output_path else Path(path) save_jsonl(out, entries) stats["output"] = str(out) return stats class ProvenanceTracker: """Track provenance metadata for training pairs.""" def __init__(self): self.stats = { "total_pairs": 0, "pairs_with_provenance": 0, "pairs_without_provenance": 0, } def generate_pair_id(self, pair: dict) -> str: """Generate a deterministic ID for a pair.""" content = json.dumps(pair, sort_keys=True) return hashlib.sha256(content.encode()).hexdigest()[:16] def process_pair(self, pair: dict) -> dict: """Process a pair, adding provenance if missing.""" self.stats["total_pairs"] += 1 if "source_session_id" in pair and pair["source_session_id"]: self.stats["pairs_with_provenance"] += 1 else: self.stats["pairs_without_provenance"] += 1 pair = attach_provenance(pair, source="unknown", source_session_id="unknown", model="unknown") if "pair_id" not in pair: pair["pair_id"] = self.generate_pair_id(pair) return pair def process_file(self, input_path: str, output_path: str = None) -> dict: """Process a JSONL file, adding provenance to all pairs.""" pairs = load_jsonl(input_path) processed = [self.process_pair(p) for p in pairs] if output_path: save_jsonl(processed, output_path) return self.stats def add_provenance(self, pair: dict, source_session_id: str, model: str, source: str = "curated", timestamp: str = None, extras: dict = None) -> dict: """Add provenance metadata to a pair.""" import hashlib as _hl convos = pair.get("conversations", []) content_parts = [] for c in convos: if c.get("from") != "system": content_parts.append(f"{c.get('from', '')}:{c.get('value', '')}") content_hash = _hl.sha256("|".join(content_parts).encode()).hexdigest()[:16] pair["provenance"] = { "source": source, "source_session_id": source_session_id, "model": model, "timestamp": timestamp or datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), "content_hash": content_hash, } if extras: pair["provenance"].update(extras) return pair def extract_provenance_from_existing(self, pair: dict) -> dict: """Extract provenance from existing pair fields.""" return { "source": "curated", "source_session_id": pair.get("id", "unknown"), "model": pair.get("model", "unknown"), "timestamp": pair.get("started_at", pair.get("timestamp", datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"))), "content_hash": pair_fingerprint(pair), } def filter_by_provenance(self, pairs: list, exclude_models: list = None, exclude_sources: list = None) -> list: """Filter pairs by provenance metadata.""" exclude_models = set(exclude_models or []) exclude_sources = set(exclude_sources or []) filtered = [] for pair in pairs: prov = pair.get("provenance", {}) model = prov.get("model", "") source = prov.get("source", "") if model in exclude_models: self.stats["excluded"] = self.stats.get("excluded", 0) + 1 continue if source in exclude_sources: self.stats["excluded"] = self.stats.get("excluded", 0) + 1 continue filtered.append(pair) return filtered def generate_report(self) -> str: """Generate a human-readable report of tracking stats.""" lines = [ "Training Pair Provenance Report", "=" * 40, f"Total pairs: {self.stats['total_pairs']}", f"Pairs with provenance: {self.stats['pairs_with_provenance']}", f"Pairs without provenance: {self.stats['pairs_without_provenance']}", ] by_model = self.stats.get("by_model", {}) if by_model: lines.append("") lines.append("By model:") for model, count in sorted(by_model.items()): lines.append(f" {model}: {count}") by_source = self.stats.get("by_source", {}) if by_source: lines.append("") lines.append("By source:") for source, count in sorted(by_source.items()): lines.append(f" {source}: {count}") excluded = self.stats.get("excluded", 0) if excluded: lines.append(f"\nExcluded: {excluded}") return "\n".join(lines) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Provenance tracking for training data") sub = parser.add_subparsers(dest="command") # validate p_validate = sub.add_parser("validate", help="Validate provenance in a dataset") p_validate.add_argument("input", help="Input JSONL file") p_validate.add_argument("--json", action="store_true", help="Output as JSON") # dashboard p_dash = sub.add_parser("dashboard", help="Show provenance dashboard") p_dash.add_argument("input", help="Input JSONL file") # backfill p_back = sub.add_parser("backfill", help="Add provenance to pairs missing it") p_back.add_argument("input", help="Input JSONL file") p_back.add_argument("--source", default="backfill", help="Source type") p_back.add_argument("--model", default="unknown", help="Model name") p_back.add_argument("--output", "-o", help="Output path (default: overwrite)") args = parser.parse_args() if args.command == "validate": report = validate_provenance(args.input) if args.json: print(json.dumps(report, indent=2)) else: print(f"Provenance Validation: {args.input}") print(f" Total: {report['total']}") print(f" With provenance: {report['with_provenance']}") print(f" Missing provenance: {report['missing_provenance']}") print(f" Missing fields: {report['missing_fields']}") print(f" Invalid source: {report['invalid_source']}") print(f" Coverage: {report.get('coverage', 0):.1f}%") if report["issues"]: print(f"\n Issues ({len(report['issues'])}):") for issue in report["issues"][:20]: print(f" {issue}") elif args.command == "dashboard": print(provenance_dashboard(args.input)) elif args.command == "backfill": stats = backfill_provenance(args.input, args.source, args.model, args.output) print(f"Backfill complete:") print(f" Total: {stats['total']}") print(f" Backfilled: {stats['backfilled']}") print(f" Already had provenance: {stats['already_had']}") print(f" Output: {stats['output']}") else: parser.print_help()