""" provenance.py — Training pair provenance tracking. Adds metadata to training pairs for quality filtering and lineage tracking. Every pair gets: source_session_id, model, timestamp, source_type. Usage: from training.provenance import add_provenance, validate_provenance, provenance_stats # Add provenance to a pair pair = add_provenance(pair, session_id="abc123", model="mimo-v2-pro") # Validate provenance on a batch issues = validate_provenance(pairs) # Get statistics stats = provenance_stats(pairs) """ import json import os import time from datetime import datetime, timezone from typing import Dict, List, Any, Optional from collections import Counter REQUIRED_FIELDS = ["source_session_id", "model", "timestamp"] def add_provenance(entry: dict, session_id: str = None, model: str = None, source_type: str = "generated", **extra) -> dict: """Add provenance metadata to a training pair.""" entry = dict(entry) # copy entry["source_session_id"] = session_id or "unknown" entry["model"] = model or "unknown" entry["timestamp"] = entry.get("timestamp") or datetime.now(timezone.utc).isoformat() entry["source_type"] = source_type # generated, curated, augmented, manual for k, v in extra.items(): entry[f"provenance_{k}"] = v return entry def extract_provenance_from_trajectory(trajectory: dict) -> dict: """Extract provenance info from a hermes trajectory file.""" return { "source_session_id": trajectory.get("session_id", trajectory.get("id", "unknown")), "model": trajectory.get("model", "unknown"), "timestamp": trajectory.get("started_at", trajectory.get("timestamp", "")), "source_type": "trajectory", "provider": trajectory.get("provider", ""), "message_count": trajectory.get("message_count", 0), } def validate_provenance(pairs: List[dict]) -> dict: """Validate provenance metadata on training pairs. Returns dict with: total, valid, missing_fields, by_field """ results = { "total": len(pairs), "valid": 0, "invalid": 0, "missing_fields": {}, "by_model": {}, "by_source": {}, "issues": [], } for i, pair in enumerate(pairs): missing = [f for f in REQUIRED_FIELDS if f not in pair or not pair[f]] if missing: results["invalid"] += 1 results["issues"].append({"index": i, "missing": missing}) for f in missing: results["missing_fields"][f] = results["missing_fields"].get(f, 0) + 1 else: results["valid"] += 1 model = pair.get("model", "unknown") source = pair.get("source_type", "unknown") results["by_model"][model] = results["by_model"].get(model, 0) + 1 results["by_source"][source] = results["by_source"].get(source, 0) + 1 return results def provenance_stats(pairs: List[dict]) -> dict: """Get provenance statistics for a set of pairs.""" models = Counter(p.get("model", "unknown") for p in pairs) sources = Counter(p.get("source_type", "unknown") for p in pairs) with_session = sum(1 for p in pairs if p.get("source_session_id", "unknown") != "unknown") with_model = sum(1 for p in pairs if p.get("model", "unknown") != "unknown") return { "total": len(pairs), "with_session_id": with_session, "with_model": with_model, "coverage_session": round(with_session / max(len(pairs), 1) * 100, 1), "coverage_model": round(with_model / max(len(pairs), 1) * 100, 1), "by_model": dict(models.most_common(20)), "by_source": dict(sources.most_common()), } def backfill_provenance(input_path: str, output_path: str = None, default_model: str = "unknown") -> dict: """Add provenance to existing pairs that lack it.""" if output_path is None: output_path = input_path.replace(".jsonl", "_provenance.jsonl") pairs = [] with open(input_path) as f: for line in f: if line.strip(): pairs.append(json.loads(line)) added = 0 with open(output_path, "w") as f: for pair in pairs: if "source_session_id" not in pair: pair = add_provenance(pair, model=default_model, source_type="backfill") added += 1 f.write(json.dumps(pair, ensure_ascii=False) + "\n") stats = provenance_stats(pairs) print(f"Backfill: {added} pairs annotated, {len(pairs) - added} already had provenance") print(f"Coverage: {stats['coverage_session']}% session, {stats['coverage_model']}% model") return stats def filter_by_provenance(pairs: List[dict], exclude_models: list = None, exclude_sources: list = None) -> List[dict]: """Filter pairs by provenance metadata.""" if exclude_models is None: exclude_models = [] if exclude_sources is None: exclude_sources = [] filtered = [] excluded = 0 for p in pairs: model = p.get("model", "") source = p.get("source_type", "") if model in exclude_models or source in exclude_sources: excluded += 1 else: filtered.append(p) print(f"Filtered: {len(filtered)} kept, {excluded} excluded") return filtered