|
|
|
|
@@ -0,0 +1,260 @@
|
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
"""
|
|
|
|
|
[PROVENANCE] Training Pair Provenance Tracker
|
|
|
|
|
Part of the Timmy Foundation tooling.
|
|
|
|
|
|
|
|
|
|
Adds, filters, and reports provenance metadata for JSONL training pairs.
|
|
|
|
|
Tracks source_session_id, model, and timestamp for quality auditing.
|
|
|
|
|
|
|
|
|
|
Usage:
|
|
|
|
|
# Tag pairs with provenance
|
|
|
|
|
python3 scripts/training_provenance.py tag input.jsonl -o tagged.jsonl \
|
|
|
|
|
--session abc123 --model nous/hermes-3
|
|
|
|
|
|
|
|
|
|
# Filter by model (drops pairs whose source_model is exactly the given name)
|
|
|
|
|
python3 scripts/training_provenance.py filter input.jsonl -o filtered.jsonl \
|
|
|
|
|
--exclude-model anthropic
|
|
|
|
|
|
|
|
|
|
# Report: pair count by source model
|
|
|
|
|
python3 scripts/training_provenance.py report input.jsonl
|
|
|
|
|
|
|
|
|
|
# Pipe support
|
|
|
|
|
cat pairs.jsonl | python3 scripts/training_provenance.py report -
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import sys
|
|
|
|
|
import json
|
|
|
|
|
import argparse
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
from collections import Counter
|
|
|
|
|
from typing import Dict, Any, Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Names of the fields that tag_pair() may write into a pair's "_provenance" dict.
PROVENANCE_KEYS = ["source_session_id", "source_model", "source_timestamp"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tag_pair(pair: Dict[str, Any], session_id: Optional[str] = None,
             model: Optional[str] = None) -> Dict[str, Any]:
    """Add provenance metadata to a training pair.

    The pair is updated in place and also returned. An existing
    ``_provenance`` mapping is kept and extended; ``source_timestamp`` is
    always set to the current UTC time in ISO-8601 form.

    Args:
        pair: Training pair to annotate.
        session_id: Stored as ``source_session_id`` when truthy.
        model: Stored as ``source_model`` when truthy.

    Returns:
        The same ``pair`` object, now carrying a ``_provenance`` dict.
    """
    meta = pair.get("_provenance")
    # A null or non-object "_provenance" (e.g. from hand-edited JSONL) cannot
    # be extended; start from a fresh mapping instead of raising TypeError.
    if not isinstance(meta, dict):
        meta = {}

    if session_id:
        meta["source_session_id"] = session_id
    if model:
        meta["source_model"] = model
    meta["source_timestamp"] = datetime.now(timezone.utc).isoformat()

    # The timestamp is always present, so the mapping is never empty here.
    pair["_provenance"] = meta
    return pair
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def filter_pairs(input_path: str, output_path: str,
                 include_models: Optional[list] = None,
                 exclude_models: Optional[list] = None,
                 min_session_age: Optional[str] = None) -> Dict[str, Any]:
    """Filter JSONL training pairs by their provenance metadata.

    Each pair's model is read from ``_provenance.source_model``; pairs without
    one are treated as model ``"unknown"``. Matching against the include and
    exclude lists is by exact equality.

    Args:
        input_path: JSONL file to read, or ``"-"`` for stdin.
        output_path: JSONL file to write kept pairs to, or ``"-"`` for stdout.
            A falsy value suppresses output entirely.
        include_models: When non-empty, keep only pairs whose model is listed.
        exclude_models: When non-empty, drop pairs whose model is listed.
        min_session_age: Accepted for interface compatibility; not used by the
            current implementation.

    Returns:
        Dict with ``total`` (parsed pairs), ``kept``, ``filtered_out`` and
        ``errors`` (lines that were not a JSON object).
    """
    kept = []
    filtered_out = 0
    errors = 0

    source = sys.stdin if input_path == "-" else open(input_path, "r", encoding="utf-8")
    try:
        for line in source:
            line = line.strip()
            if not line:
                continue
            try:
                pair = json.loads(line)
            except json.JSONDecodeError:
                errors += 1
                continue
            # Valid JSON that is not an object (list, string, null, ...) cannot
            # be a training pair; count it as an error instead of crashing.
            if not isinstance(pair, dict):
                errors += 1
                continue

            prov = pair.get("_provenance")
            if not isinstance(prov, dict):
                prov = {}
            model = prov.get("source_model", "unknown")

            if include_models and model not in include_models:
                filtered_out += 1
                continue
            if exclude_models and model in exclude_models:
                filtered_out += 1
                continue
            kept.append(pair)
    finally:
        if source is not sys.stdin:
            source.close()

    # Output is written only after the input has been fully read.
    if output_path:
        out = sys.stdout if output_path == "-" else open(output_path, "w", encoding="utf-8")
        try:
            for pair in kept:
                out.write(json.dumps(pair, ensure_ascii=False) + "\n")
        finally:
            # Never close the process's stdout; only files opened here.
            if out is not sys.stdout:
                out.close()

    return {
        "total": len(kept) + filtered_out,
        "kept": len(kept),
        "filtered_out": filtered_out,
        "errors": errors,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def report(input_path: str) -> Dict[str, Any]:
    """Summarise provenance coverage of a JSONL file of training pairs.

    A pair counts as tagged when it has a non-empty ``_provenance`` object.
    Tagged pairs are tallied by ``source_model`` and ``source_session_id``,
    each defaulting to ``"unknown"`` when the key is absent.

    Args:
        input_path: JSONL file to read, or ``"-"`` for stdin.

    Returns:
        Dict with ``total``, ``tagged``, ``untagged``, ``tag_rate`` (percent,
        one decimal), ``by_model`` (20 most common), ``by_session`` (10 most
        common) and ``errors`` (lines that were not a JSON object).
    """
    model_counts: Counter = Counter()
    session_counts: Counter = Counter()
    tagged = 0
    untagged = 0
    total = 0
    errors = 0

    source = sys.stdin if input_path == "-" else open(input_path, "r", encoding="utf-8")
    try:
        for line in source:
            line = line.strip()
            if not line:
                continue
            try:
                pair = json.loads(line)
            except json.JSONDecodeError:
                errors += 1
                continue
            # Valid JSON that is not an object cannot carry provenance;
            # count it as an error instead of crashing on .get().
            if not isinstance(pair, dict):
                errors += 1
                continue

            total += 1
            prov = pair.get("_provenance")
            # Empty, null or non-object provenance means the pair is untagged.
            if not isinstance(prov, dict) or not prov:
                untagged += 1
                continue

            tagged += 1
            model_counts[prov.get("source_model", "unknown")] += 1
            session_counts[prov.get("source_session_id", "unknown")] += 1
    finally:
        if source is not sys.stdin:
            source.close()

    return {
        "total": total,
        "tagged": tagged,
        "untagged": untagged,
        # max(total, 1) avoids division by zero on empty input.
        "tag_rate": round(tagged / max(total, 1) * 100, 1),
        "by_model": dict(model_counts.most_common(20)),
        "by_session": dict(session_counts.most_common(10)),
        "errors": errors,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def stamp_command(input_path: str, output_path: str,
                  session_id: Optional[str], model: Optional[str]) -> Dict[str, Any]:
    """Tag all pairs in a JSONL file with provenance metadata.

    Pairs that already carry provenance with the same ``source_model`` and
    ``source_session_id`` are written through unchanged and counted as
    skipped. Every other pair is passed to ``tag_pair`` and re-serialised.
    Lines that are not a JSON object are dropped from the output and counted
    as errors.

    The output is opened for writing before the input is read, so the two
    paths must not refer to the same file.

    Args:
        input_path: JSONL file to read, or ``"-"`` for stdin.
        output_path: JSONL file to write, or ``"-"`` for stdout.
        session_id: Session id to record, or ``None``.
        model: Model name to record, or ``None``.

    Returns:
        Dict with ``tagged``, ``skipped`` and ``errors`` counts.
    """
    tagged = 0
    skipped = 0
    errors = 0

    source = sys.stdin if input_path == "-" else open(input_path, "r", encoding="utf-8")
    try:
        # Opened inside the outer try so the input handle is released even if
        # the output cannot be created.
        out = sys.stdout if output_path == "-" else open(output_path, "w", encoding="utf-8")
        try:
            for line in source:
                line = line.strip()
                if not line:
                    continue
                try:
                    pair = json.loads(line)
                except json.JSONDecodeError:
                    errors += 1
                    continue
                if not isinstance(pair, dict):
                    errors += 1
                    continue

                existing = pair.get("_provenance")
                if not isinstance(existing, dict):
                    existing = {}

                # Skip only pairs that really are tagged already. Without the
                # non-empty check, running with neither session nor model
                # would compare None == None and skip every untagged pair.
                already_tagged = (
                    bool(existing)
                    and existing.get("source_model") == model
                    and existing.get("source_session_id") == session_id
                )
                if already_tagged:
                    skipped += 1
                    out.write(line + "\n")
                    continue

                pair = tag_pair(pair, session_id=session_id, model=model)
                out.write(json.dumps(pair, ensure_ascii=False) + "\n")
                tagged += 1
        finally:
            # Never close the process's stdout; only files opened here.
            if out is not sys.stdout:
                out.close()
    finally:
        if source is not sys.stdin:
            source.close()

    return {"tagged": tagged, "skipped": skipped, "errors": errors}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(argv: Optional[list] = None) -> None:
    """Command-line entry point for the ``tag``, ``filter`` and ``report`` subcommands.

    Human-readable summaries go to stderr so that stdout carries only data:
    JSONL for ``tag``/``filter`` (when writing to ``-``) and a JSON document
    for ``report``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) uses the process arguments.
    """
    parser = argparse.ArgumentParser(description="Training pair provenance tracking")
    sub = parser.add_subparsers(dest="command", required=True)

    # tag subcommand
    tag_p = sub.add_parser("tag", help="Tag pairs with provenance metadata")
    tag_p.add_argument("input", help="Input JSONL file (use - for stdin)")
    tag_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
    tag_p.add_argument("--session", help="Source session ID")
    tag_p.add_argument("--model", help="Source model name")

    # filter subcommand
    filt_p = sub.add_parser("filter", help="Filter pairs by provenance")
    filt_p.add_argument("input", help="Input JSONL file (use - for stdin)")
    filt_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
    filt_p.add_argument("--include-model", action="append", help="Only include these models")
    filt_p.add_argument("--exclude-model", action="append", help="Exclude these models")

    # report subcommand
    rpt_p = sub.add_parser("report", help="Report provenance statistics")
    rpt_p.add_argument("input", help="Input JSONL file (use - for stdin)")

    args = parser.parse_args(argv)

    if args.command == "tag":
        result = stamp_command(args.input, args.output, args.session, args.model)
        print(
            f"Tagged: {result['tagged']} Skipped: {result['skipped']} "
            f"Errors: {result['errors']}",
            file=sys.stderr,
        )

    elif args.command == "filter":
        result = filter_pairs(
            args.input, args.output,
            include_models=args.include_model,
            exclude_models=args.exclude_model,
        )
        # Report unparseable lines too, matching the "tag" summary.
        print(
            f"Total: {result['total']} Kept: {result['kept']} "
            f"Filtered: {result['filtered_out']} Errors: {result['errors']}",
            file=sys.stderr,
        )

    elif args.command == "report":
        result = report(args.input)
        print("Training Pair Provenance Report", file=sys.stderr)
        print("=" * 40, file=sys.stderr)
        print(f"Total pairs: {result['total']}", file=sys.stderr)
        print(f"Tagged: {result['tagged']} ({result['tag_rate']}%)", file=sys.stderr)
        print(f"Untagged: {result['untagged']}", file=sys.stderr)

        if result['by_model']:
            print("\nBy source model:", file=sys.stderr)
            for model, count in result['by_model'].items():
                print(f" {model}: {count}", file=sys.stderr)

        if result['by_session']:
            print("\nBy source session (top 10):", file=sys.stderr)
            for session, count in result['by_session'].items():
                # Session ids come from the data and may not be strings.
                session = str(session)
                session_short = session[:12] + "..." if len(session) > 12 else session
                print(f" {session_short}: {count}", file=sys.stderr)

        # Output JSON to stdout
        print(json.dumps(result, indent=2))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|