Files
timmy-config/scripts/provenance_validate.py
2026-04-21 21:23:48 -04:00

137 lines
4.1 KiB
Python

#!/usr/bin/env python3
"""
provenance_validate.py — Validate provenance metadata on training data.
Checks:
- All pairs have required provenance fields
- source_session_id, model, timestamp present
- Coverage report by model and source
Usage:
python3 provenance_validate.py training-data/*.jsonl
python3 provenance_validate.py --threshold 80 training-data/*.jsonl
"""
import json
import sys
from pathlib import Path
from typing import List
# Provenance fields every training pair must carry (and carry non-empty).
REQUIRED_FIELDS = ["source_session_id", "model", "timestamp"]


def validate_file(filepath: str) -> dict:
    """Validate provenance metadata on a single JSONL file.

    A pair "has provenance" when every field in REQUIRED_FIELDS is present
    AND truthy (empty strings / None count as missing).

    Args:
        filepath: Path to a JSONL file, one JSON object per non-blank line.

    Returns:
        dict with keys: file, total, with_provenance, coverage_pct (rounded
        to 1 decimal), missing_by_field, by_model, by_source.

    Raises:
        OSError: if the file cannot be opened/read.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    pairs = []
    # Explicit encoding: JSONL is UTF-8 by convention; relying on the
    # platform default locale encoding breaks on Windows (cp1252).
    with open(filepath, encoding="utf-8") as f:
        for line in f:
            if line.strip():  # skip blank/whitespace-only lines
                pairs.append(json.loads(line))

    total = len(pairs)
    with_provenance = 0
    missing_by_field = {field: 0 for field in REQUIRED_FIELDS}
    by_model = {}
    by_source = {}
    for pair in pairs:
        has_all = True
        for field in REQUIRED_FIELDS:
            # .get() covers both "key absent" and "key present but falsy"
            # in one lookup (same semantics as the original two-step test).
            if not pair.get(field):
                missing_by_field[field] += 1
                has_all = False
        if has_all:
            with_provenance += 1
        model = pair.get("model", "unknown")
        # Accept either "source_type" or the legacy "source" key.
        source = pair.get("source_type", pair.get("source", "unknown"))
        by_model[model] = by_model.get(model, 0) + 1
        by_source[source] = by_source.get(source, 0) + 1

    coverage = (with_provenance / total * 100) if total > 0 else 0
    return {
        "file": str(filepath),
        "total": total,
        "with_provenance": with_provenance,
        "coverage_pct": round(coverage, 1),
        "missing_by_field": missing_by_field,
        "by_model": by_model,
        "by_source": by_source,
    }
def validate_all(files: List[str], threshold: float = 0) -> dict:
    """Run provenance validation over several files and aggregate totals.

    Args:
        files: Paths of JSONL files to validate.
        threshold: Minimum overall coverage percentage required to pass.

    Returns:
        dict with per-file results plus overall totals, rounded overall
        coverage percentage, and a boolean threshold verdict.
    """
    per_file = [validate_file(path) for path in files]
    pair_count = sum(r["total"] for r in per_file)
    provenance_count = sum(r["with_provenance"] for r in per_file)

    if pair_count > 0:
        overall = provenance_count / pair_count * 100
    else:
        overall = 0  # no pairs at all: report 0% coverage

    return {
        "files": per_file,
        "total_pairs": pair_count,
        "total_with_provenance": provenance_count,
        "overall_coverage_pct": round(overall, 1),
        "passes_threshold": overall >= threshold,
    }
def main():
    """CLI entry point: parse args, validate files, print a report, exit.

    Exit status: 0 on success (or when no threshold is set), 1 when no
    input files are found or the coverage threshold is not met.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Validate training data provenance")
    parser.add_argument("files", nargs="+", help="JSONL files to validate")
    parser.add_argument("--threshold", type=float, default=0,
                        help="Minimum coverage percentage to pass")
    parser.add_argument("--json", action="store_true", help="JSON output")
    args = parser.parse_args()

    # Expand globs for shells (e.g. Windows cmd) that pass patterns through.
    # NOTE(review): patterns are resolved relative to the CWD, so an absolute
    # path containing "*" would not match — confirm callers use relative globs.
    files = []
    for pattern in args.files:
        expanded = list(Path(".").glob(pattern)) if "*" in pattern else [Path(pattern)]
        files.extend(str(f) for f in expanded if f.exists())
    if not files:
        print("No files found", file=sys.stderr)
        sys.exit(1)

    result = validate_all(files, args.threshold)
    if args.json:
        print(json.dumps(result, indent=2))
    else:
        print(f"\n{'='*50}")
        print(" PROVENANCE VALIDATION REPORT")
        print(f"{'='*50}")
        print(f" Total pairs: {result['total_pairs']}")
        print(f" With provenance: {result['total_with_provenance']}")
        print(f" Coverage: {result['overall_coverage_pct']}%")
        if args.threshold > 0:
            status = "PASS" if result["passes_threshold"] else "FAIL"
            print(f" Threshold: {args.threshold}% [{status}]")
        print("\n Per file:")
        for f in result["files"]:
            # BUG FIX: both branches of this conditional were the empty string
            # (pass/fail glyphs apparently lost in an encoding pass), so no
            # per-file marker was ever printed. Use ASCII-safe markers instead.
            icon = "ok" if f["coverage_pct"] >= args.threshold else "!!"
            print(f" {icon} {f['file']}: {f['coverage_pct']}% ({f['with_provenance']}/{f['total']})")
        print(f"{'='*50}\n")

    # Non-zero exit so CI can gate on coverage when a threshold was requested.
    if args.threshold > 0 and not result["passes_threshold"]:
        sys.exit(1)
    sys.exit(0)
# Entry-point guard: run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()