#!/usr/bin/env python3
"""
[PROVENANCE] Training Pair Provenance Tracker
Part of the Timmy Foundation tooling.

Adds, filters, and reports provenance metadata for JSONL training pairs.
Tracks source_session_id, model, and timestamp for quality auditing.

Usage:
    # Tag pairs with provenance
    python3 scripts/training_provenance.py tag input.jsonl -o tagged.jsonl \
        --session abc123 --model nous/hermes-3

    # Filter by model (exclude Anthropic-sourced)
    python3 scripts/training_provenance.py filter input.jsonl -o filtered.jsonl \
        --exclude-model anthropic

    # Report: pair count by source model
    python3 scripts/training_provenance.py report input.jsonl

    # Pipe support
    cat pairs.jsonl | python3 scripts/training_provenance.py report -
"""

import sys
import json
import argparse
from datetime import datetime, timezone
from collections import Counter
from typing import Dict, Any, Optional


# Metadata keys stored under each pair's "_provenance" object.
PROVENANCE_KEYS = ["source_session_id", "source_model", "source_timestamp"]


def _open_source(path: str):
    """Return a readable text handle: stdin for '-', else the opened file."""
    return sys.stdin if path == "-" else open(path, "r", encoding="utf-8")


def _open_sink(path: str):
    """Return a writable text handle: stdout for '-', else the opened file."""
    return sys.stdout if path == "-" else open(path, "w", encoding="utf-8")


def _close(handle) -> None:
    """Close *handle* unless it is a process-wide std stream.

    BUG FIX: the original cleanup for output handles compared against
    ``sys.stdin`` (always true for a write handle), so streaming to '-'
    closed ``sys.stdout`` and broke any later writes to it.
    """
    if handle is not sys.stdin and handle is not sys.stdout:
        handle.close()


def tag_pair(pair: Dict[str, Any], session_id: Optional[str] = None,
             model: Optional[str] = None) -> Dict[str, Any]:
    """Add provenance metadata to a training pair (mutates and returns it).

    session_id / model are only written when provided (truthy), preserving
    any existing values otherwise; the timestamp is always refreshed, so
    every tagged pair ends up with a "_provenance" object.
    """
    meta = pair.get("_provenance", {})

    if session_id:
        meta["source_session_id"] = session_id
    if model:
        meta["source_model"] = model
    # Stamped unconditionally — meta is therefore always non-empty here,
    # which is why the original's `if meta:` guard was dead code.
    meta["source_timestamp"] = datetime.now(timezone.utc).isoformat()

    pair["_provenance"] = meta
    return pair


def filter_pairs(input_path: str, output_path: str,
                 include_models: Optional[list] = None,
                 exclude_models: Optional[list] = None,
                 min_session_age: Optional[str] = None) -> Dict[str, Any]:
    """Filter pairs by provenance metadata.

    Pairs whose ``_provenance.source_model`` fails the include/exclude
    tests are dropped; malformed JSON lines are counted as errors and
    skipped. Kept pairs are written to *output_path* ('-' for stdout).

    NOTE(review): *min_session_age* is accepted for interface
    compatibility but is not yet implemented — confirm intended semantics
    before wiring it up.

    Returns a summary dict: total / kept / filtered_out / errors.
    """
    kept = []
    filtered_out = 0
    errors = 0

    source = _open_source(input_path)
    try:
        for raw in source:
            raw = raw.strip()
            if not raw:
                continue
            try:
                pair = json.loads(raw)
            except json.JSONDecodeError:
                errors += 1
                continue

            model = pair.get("_provenance", {}).get("source_model", "unknown")

            keep = True
            if include_models:
                keep = model in include_models
            if exclude_models:
                keep = keep and model not in exclude_models

            if keep:
                kept.append(pair)
            else:
                # Only the count is reported, so don't hold dropped pairs.
                filtered_out += 1
    finally:
        _close(source)

    if output_path:
        out = _open_sink(output_path)
        try:
            for pair in kept:
                out.write(json.dumps(pair, ensure_ascii=False) + "\n")
        finally:
            _close(out)  # no longer closes sys.stdout for '-'

    return {
        "total": len(kept) + filtered_out,
        "kept": len(kept),
        "filtered_out": filtered_out,
        "errors": errors,
    }


def report(input_path: str) -> Dict[str, Any]:
    """Report pair counts by source model and session.

    Reads JSONL from *input_path* ('-' for stdin) and returns a summary:
    total/tagged/untagged counts, tag rate (%), top-20 models, top-10
    sessions, and the number of malformed lines.
    """
    model_counts = Counter()
    session_counts = Counter()
    tagged = 0
    untagged = 0
    total = 0
    errors = 0

    source = _open_source(input_path)
    try:
        for raw in source:
            raw = raw.strip()
            if not raw:
                continue
            try:
                pair = json.loads(raw)
            except json.JSONDecodeError:
                errors += 1
                continue

            total += 1
            prov = pair.get("_provenance", {})

            if prov:
                tagged += 1
                model_counts[prov.get("source_model", "unknown")] += 1
                session_counts[prov.get("source_session_id", "unknown")] += 1
            else:
                untagged += 1
    finally:
        _close(source)

    return {
        "total": total,
        "tagged": tagged,
        "untagged": untagged,
        # max(total, 1) avoids ZeroDivisionError on an empty input.
        "tag_rate": round(tagged / max(total, 1) * 100, 1),
        "by_model": dict(model_counts.most_common(20)),
        "by_session": dict(session_counts.most_common(10)),
        "errors": errors,
    }


def stamp_command(input_path: str, output_path: str,
                  session_id: Optional[str], model: Optional[str]) -> Dict[str, Any]:
    """Tag all pairs in a file with provenance metadata.

    Pairs already carrying the same session/model are passed through
    verbatim (counted as skipped); malformed lines are dropped and
    counted as errors. Returns {tagged, skipped, errors}.
    """
    tagged = 0
    skipped = 0
    errors = 0

    source = _open_source(input_path)
    out = _open_sink(output_path)

    try:
        for raw in source:
            raw = raw.strip()
            if not raw:
                continue
            try:
                pair = json.loads(raw)
            except json.JSONDecodeError:
                errors += 1
                continue

            # Skip if already tagged with the same session + model: echo the
            # original line so its existing timestamp is preserved.
            existing = pair.get("_provenance", {})
            if (existing.get("source_model") == model
                    and existing.get("source_session_id") == session_id):
                skipped += 1
                out.write(raw + "\n")
                continue

            pair = tag_pair(pair, session_id=session_id, model=model)
            out.write(json.dumps(pair, ensure_ascii=False) + "\n")
            tagged += 1
    finally:
        _close(source)
        _close(out)  # fixed: original closed sys.stdout when output was '-'

    return {"tagged": tagged, "skipped": skipped, "errors": errors}


def main():
    """CLI entry point: dispatch the tag / filter / report subcommands.

    Human-readable summaries go to stderr so stdout stays clean for JSONL
    output (tag/filter) or the JSON report.
    """
    parser = argparse.ArgumentParser(description="Training pair provenance tracking")
    sub = parser.add_subparsers(dest="command", required=True)

    # tag subcommand
    tag_p = sub.add_parser("tag", help="Tag pairs with provenance metadata")
    tag_p.add_argument("input", help="Input JSONL file (use - for stdin)")
    tag_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
    tag_p.add_argument("--session", help="Source session ID")
    tag_p.add_argument("--model", help="Source model name")

    # filter subcommand
    filt_p = sub.add_parser("filter", help="Filter pairs by provenance")
    filt_p.add_argument("input", help="Input JSONL file (use - for stdin)")
    filt_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
    filt_p.add_argument("--include-model", action="append", help="Only include these models")
    filt_p.add_argument("--exclude-model", action="append", help="Exclude these models")

    # report subcommand
    rpt_p = sub.add_parser("report", help="Report provenance statistics")
    rpt_p.add_argument("input", help="Input JSONL file (use - for stdin)")

    args = parser.parse_args()

    if args.command == "tag":
        result = stamp_command(args.input, args.output, args.session, args.model)
        print(f"Tagged: {result['tagged']} Skipped: {result['skipped']} Errors: {result['errors']}", file=sys.stderr)

    elif args.command == "filter":
        result = filter_pairs(
            args.input, args.output,
            include_models=args.include_model,
            exclude_models=args.exclude_model,
        )
        print(f"Total: {result['total']} Kept: {result['kept']} Filtered: {result['filtered_out']}", file=sys.stderr)

    elif args.command == "report":
        result = report(args.input)
        print(f"Training Pair Provenance Report", file=sys.stderr)
        print(f"{'='*40}", file=sys.stderr)
        print(f"Total pairs: {result['total']}", file=sys.stderr)
        print(f"Tagged: {result['tagged']} ({result['tag_rate']}%)", file=sys.stderr)
        print(f"Untagged: {result['untagged']}", file=sys.stderr)

        if result['by_model']:
            print(f"\nBy source model:", file=sys.stderr)
            for model, count in result['by_model'].items():
                print(f"  {model}: {count}", file=sys.stderr)

        if result['by_session']:
            print(f"\nBy source session (top 10):", file=sys.stderr)
            for session, count in result['by_session'].items():
                session_short = session[:12] + "..." if len(session) > 12 else session
                print(f"  {session_short}: {count}", file=sys.stderr)

        # Output JSON to stdout
        print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()