Files
timmy-config/training/training_pair_provenance.py
Alexander Whitestone 5763a148c2 feat: Add training pair provenance tracking
Adds provenance metadata to training pairs:
- source_session_id: Which session generated the pair
- model: Which model generated it
- timestamp: When it was generated
- source: Source type (curated, trajectory, etc.)
- content_hash: For deduplication

Provides filtering and reporting capabilities.

Addresses issue #691
2026-04-15 16:01:49 +00:00

282 lines
10 KiB
Python

#!/usr/bin/env python3
"""
Training Pair Provenance Tracking
Adds provenance metadata to training pairs for quality filtering and reporting.
Tracks source session, model, timestamp, and other metadata.
Usage:
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --output data/curated_with_provenance.jsonl
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --filter exclude_anthropic
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --report
"""
import argparse
import json
import hashlib
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
class ProvenanceTracker:
    """Track provenance of training pairs.

    Each pair is a dict (one JSONL record). Provenance is stored under the
    pair's ``"provenance"`` key as a dict holding some of
    ``source_session_id``, ``model``, ``timestamp``, ``source`` and
    ``content_hash``. Running counts accumulate in ``self.stats`` and are
    rendered by :meth:`generate_report`.
    """

    # Models to exclude by default (configurable)
    EXCLUDED_MODELS = {"anthropic/claude-3-opus", "anthropic/claude-3-sonnet", "anthropic/claude-3-haiku"}

    def __init__(self):
        # Totals and by_model/by_source are updated by process_pair();
        # "excluded" is updated by filter_by_provenance().
        self.stats = {
            "total_pairs": 0,
            "pairs_with_provenance": 0,
            "pairs_without_provenance": 0,
            "by_model": {},
            "by_source": {},
            "excluded": 0,
        }

    @staticmethod
    def _conversations_hash(pair: Dict[str, Any]) -> str:
        """Return the first 32 hex chars of SHA-256 over the pair's conversations.

        The conversations are serialized with sorted keys, so the digest does
        not depend on dict key order; it serves as a deduplication key.
        A pair without ``"conversations"`` hashes as an empty list.
        """
        conversations = pair.get("conversations", [])
        content_str = json.dumps(conversations, sort_keys=True)
        return hashlib.sha256(content_str.encode()).hexdigest()[:32]

    def generate_pair_id(self, pair: Dict[str, Any]) -> str:
        """Generate a content-derived ID for a training pair.

        Returns the first 16 hex chars of SHA-256 over the whole pair
        serialized with sorted keys, so identical pairs get identical IDs.
        """
        content = json.dumps(pair, sort_keys=True)
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def add_provenance(self, pair: Dict[str, Any],
                       source_session_id: Optional[str] = None,
                       model: Optional[str] = None,
                       source: str = "curated") -> Dict[str, Any]:
        """Add provenance metadata to a training pair (mutates and returns *pair*).

        Explicit ``source_session_id`` / ``model`` arguments always win.
        Without them, values already present in the pair's provenance are
        kept; only missing fields are filled from the pair's own ``id`` /
        ``model`` fields. ``timestamp`` (current UTC time) and
        ``content_hash`` are likewise only filled when missing. ``source`` is
        always set to the given value.

        Args:
            pair: Training pair dict; gets an ``id`` if it has none.
            source_session_id: Session that produced the pair.
            model: Model that produced the pair.
            source: Source type label, e.g. ``"curated"``.

        Returns:
            The same ``pair`` object.
        """
        # Generate pair ID if not present (hash taken before provenance is added).
        if "id" not in pair:
            pair["id"] = self.generate_pair_id(pair)

        # A JSON null (or otherwise non-dict) provenance is replaced, not indexed.
        if not isinstance(pair.get("provenance"), dict):
            pair["provenance"] = {}
        provenance = pair["provenance"]

        if source_session_id:
            provenance["source_session_id"] = source_session_id
        else:
            # Fallback only: never overwrite a recorded session with the pair ID.
            provenance.setdefault("source_session_id", pair["id"])

        if model:
            provenance["model"] = model
        elif "model" in pair:
            provenance.setdefault("model", pair["model"])

        if "timestamp" not in provenance:
            provenance["timestamp"] = datetime.now(timezone.utc).isoformat()

        provenance["source"] = source

        if "content_hash" not in provenance:
            provenance["content_hash"] = self._conversations_hash(pair)
        return pair

    def extract_provenance_from_existing(self, pair: Dict[str, Any]) -> Dict[str, Any]:
        """Build a new provenance dict from a pair's existing top-level fields.

        Maps ``id`` -> ``source_session_id``, ``model`` -> ``model`` and
        ``started_at`` -> ``timestamp`` when present. It always sets
        ``source="curated"`` and a ``content_hash`` of the conversations.
        Does not modify *pair*.
        """
        provenance = {}
        if "id" in pair:
            provenance["source_session_id"] = pair["id"]
        if "model" in pair:
            provenance["model"] = pair["model"]
        if "started_at" in pair:
            provenance["timestamp"] = pair["started_at"]
        provenance["source"] = "curated"
        provenance["content_hash"] = self._conversations_hash(pair)
        return provenance

    def process_pair(self, pair: Dict[str, Any],
                     add_provenance: bool = True) -> Dict[str, Any]:
        """Count a pair in ``self.stats`` and optionally attach provenance.

        A pair whose ``"provenance"`` is a dict counts as having provenance.
        Otherwise (missing, null, malformed), it counts as without
        provenance. In that case, when *add_provenance* is true, a
        provenance dict extracted from the pair's fields is stored on it.
        Model/source tallies use ``"unknown"`` for absent keys.

        Returns:
            The same ``pair`` object.
        """
        self.stats["total_pairs"] += 1
        existing = pair.get("provenance")
        if isinstance(existing, dict):
            self.stats["pairs_with_provenance"] += 1
            provenance = existing
        else:
            self.stats["pairs_without_provenance"] += 1
            if add_provenance:
                provenance = self.extract_provenance_from_existing(pair)
                pair["provenance"] = provenance
            else:
                provenance = {}

        model = provenance.get("model", "unknown")
        self.stats["by_model"][model] = self.stats["by_model"].get(model, 0) + 1
        source = provenance.get("source", "unknown")
        self.stats["by_source"][source] = self.stats["by_source"].get(source, 0) + 1
        return pair

    def filter_by_provenance(self, pairs: List[Dict[str, Any]],
                             exclude_models: Optional[List[str]] = None,
                             exclude_sources: Optional[List[str]] = None,
                             min_timestamp: Optional[str] = None,
                             max_timestamp: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return the pairs that pass all provenance criteria.

        Args:
            pairs: Pairs to filter (not modified).
            exclude_models: Exact model IDs to drop. ``None`` means
                ``EXCLUDED_MODELS``; pass an empty list to exclude none.
            exclude_sources: Source labels to drop.
            min_timestamp / max_timestamp: Inclusive bounds compared as
                plain strings against ``provenance["timestamp"]``. A missing
                or null timestamp compares as ``""``, so it fails any
                ``min_timestamp``. NOTE(review): string comparison is only
                chronological when all timestamps share one ISO-8601 format.

        Each dropped pair increments ``self.stats["excluded"]``.
        """
        if exclude_models is None:
            exclude_models = self.EXCLUDED_MODELS
        # Build lookup sets once instead of scanning lists for every pair.
        excluded_models = set(exclude_models)
        excluded_sources = set(exclude_sources or ())

        filtered = []
        for pair in pairs:
            provenance = pair.get("provenance")
            if not isinstance(provenance, dict):
                provenance = {}
            timestamp = provenance.get("timestamp") or ""
            rejected = (
                provenance.get("model", "") in excluded_models
                or provenance.get("source", "") in excluded_sources
                or (min_timestamp and timestamp < min_timestamp)
                or (max_timestamp and timestamp > max_timestamp)
            )
            if rejected:
                self.stats["excluded"] += 1
            else:
                filtered.append(pair)
        return filtered

    def generate_report(self) -> str:
        """Render ``self.stats`` as a human-readable multi-line report.

        Model and source counts are listed by descending count. They reflect
        every pair passed through process_pair(), not just those that
        survived filtering.
        """
        report = [
            "=== Training Pair Provenance Report ===",
            f"Total pairs: {self.stats['total_pairs']}",
            f"Pairs with provenance: {self.stats['pairs_with_provenance']}",
            f"Pairs without provenance: {self.stats['pairs_without_provenance']}",
            f"Excluded pairs: {self.stats['excluded']}",
            "",
            "=== Pairs by Model ===",
        ]
        for model, count in sorted(self.stats["by_model"].items(), key=lambda x: x[1], reverse=True):
            report.append(f" {model}: {count}")
        report.append("")
        report.append("=== Pairs by Source ===")
        for source, count in sorted(self.stats["by_source"].items(), key=lambda x: x[1], reverse=True):
            report.append(f" {source}: {count}")
        return "\n".join(report)
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Load a JSONL file.

    Each non-blank line of *path* is parsed as one JSON value. Blank and
    whitespace-only lines are skipped.

    Args:
        path: File to read (UTF-8, optionally with a leading BOM).

    Returns:
        The parsed records, in file order.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. The message is
            prefixed with ``<path>:<line number>:`` so the bad record can be
            found in large files.
    """
    entries = []
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode non-ASCII text. "utf-8-sig" also strips a leading BOM,
    # which json.loads would otherwise reject.
    with open(path, encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError as exc:
                raise json.JSONDecodeError(f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos) from exc
    return entries
def save_jsonl(entries: List[Dict[str, Any]], path: Path):
    """Write *entries* to *path* as JSONL, replacing any existing file.

    Every entry is serialized with ``json.dumps`` defaults and followed by a
    single newline; an empty *entries* produces an empty file.
    """
    with open(path, "w") as handle:
        handle.writelines(json.dumps(record) + "\n" for record in entries)
def main():
    """Command-line entry point: load pairs, attach provenance, filter, then report/save.

    Filtering runs whenever ``--filter``, ``--exclude-models`` or
    ``--exclude-sources`` is given, and all supplied criteria are combined.
    Progress messages go to stderr so stdout carries only the report (or the
    raw stats with ``--json``) and can be piped safely.
    """
    import sys  # local: only the CLI needs stderr

    parser = argparse.ArgumentParser(description="Training Pair Provenance Tracking")
    parser.add_argument("--input", required=True, help="Input JSONL file")
    parser.add_argument("--output", help="Output JSONL file (with provenance added)")
    parser.add_argument("--filter", choices=["exclude_anthropic", "exclude_openai", "custom"],
                        help="Apply filter")
    parser.add_argument("--exclude-models", nargs="+", help="Models to exclude")
    parser.add_argument("--exclude-sources", nargs="+", help="Sources to exclude")
    parser.add_argument("--report", action="store_true", help="Generate report only")
    parser.add_argument("--json", action="store_true", help="Output report as JSON")
    args = parser.parse_args()

    def status(message):
        # Progress chatter; kept off stdout so `--json` output stays parseable.
        print(message, file=sys.stderr)

    pairs = load_jsonl(Path(args.input))
    status(f"Loaded {len(pairs)} pairs from {args.input}")

    tracker = ProvenanceTracker()
    processed_pairs = [tracker.process_pair(pair, add_provenance=True) for pair in pairs]

    if args.filter or args.exclude_models or args.exclude_sources:
        # User-supplied models are always honoured, on top of any preset.
        exclude_models = set(args.exclude_models or [])
        if args.filter == "exclude_anthropic":
            exclude_models.update(ProvenanceTracker.EXCLUDED_MODELS)
        # Presets also drop every observed model ID carrying the provider
        # prefix (stats["by_model"] holds all models seen by process_pair).
        # NOTE(review): assumes "provider/model" IDs like those in
        # EXCLUDED_MODELS - confirm against real data.
        prefix = {"exclude_anthropic": "anthropic/", "exclude_openai": "openai/"}.get(args.filter)
        if prefix:
            exclude_models.update(
                m for m in tracker.stats["by_model"]
                if isinstance(m, str) and m.startswith(prefix)
            )
        processed_pairs = tracker.filter_by_provenance(
            processed_pairs,
            exclude_models=list(exclude_models),  # a list (never None): no implicit default set
            exclude_sources=args.exclude_sources,
        )
        status(f"After filtering: {len(processed_pairs)} pairs")

    # --report means "report only": nothing is written even if --output is set.
    if args.output and not args.report:
        save_jsonl(processed_pairs, Path(args.output))
        status(f"Saved {len(processed_pairs)} pairs to {args.output}")

    if args.json:
        print(json.dumps(tracker.stats, indent=2))
    else:
        print(tracker.generate_report())


if __name__ == "__main__":
    main()