Some checks failed
Architecture Lint / Linter Tests (pull_request) Has been cancelled
Architecture Lint / Lint Repository (pull_request) Has been cancelled
PR Checklist / pr-checklist (pull_request) Has been cancelled
Validate Config / Playbook Schema Validation (pull_request) Has been cancelled
Validate Training Data / validate (pull_request) Has been cancelled
Smoke Test / smoke (pull_request) Has been cancelled
Validate Config / YAML Lint (pull_request) Has been cancelled
Validate Config / JSON Validate (pull_request) Has been cancelled
Validate Config / Python Syntax & Import Check (pull_request) Has been cancelled
Validate Config / Python Test Suite (pull_request) Has been cancelled
Validate Config / Shell Script Lint (pull_request) Has been cancelled
Validate Config / Cron Syntax Check (pull_request) Has been cancelled
Validate Config / Deploy Script Dry Run (pull_request) Has been cancelled
Provenance module for tracking the source of every training pair.

training/provenance.py (151 lines):
- add_provenance(): add metadata to pairs
- validate_provenance(): check required fields
- provenance_stats(): coverage and distribution
- backfill_provenance(): annotate existing pairs
- filter_by_provenance(): exclude by model/source
- extract_provenance_from_trajectory(): hermes integration

Required fields: source_session_id, model, timestamp

Closes #752
152 lines
5.3 KiB
Python
152 lines
5.3 KiB
Python
"""
|
|
provenance.py — Training pair provenance tracking.

Adds metadata to training pairs for quality filtering and lineage tracking.
Every pair gets: source_session_id, model, timestamp, source_type.

Usage:
    from training.provenance import add_provenance, validate_provenance, provenance_stats

    # Add provenance to a pair
    pair = add_provenance(pair, session_id="abc123", model="mimo-v2-pro")

    # Validate provenance on a batch (returns counts plus a per-pair "issues" list)
    report = validate_provenance(pairs)

    # Get coverage statistics
    stats = provenance_stats(pairs)
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from typing import Dict, List, Any, Optional
|
|
from collections import Counter
|
|
|
|
|
|
# Fields every training pair must carry with a non-empty value; this list is
# what validate_provenance() checks each pair against.
REQUIRED_FIELDS = ["source_session_id", "model", "timestamp"]
|
|
|
|
|
|
def add_provenance(entry: dict, session_id: Optional[str] = None,
                   model: Optional[str] = None,
                   source_type: str = "generated", **extra) -> dict:
    """Return a copy of a training pair with provenance metadata attached.

    The caller's ``entry`` is not mutated; a shallow copy is annotated and
    returned.

    Args:
        entry: The training pair to annotate.
        session_id: Source session id. When omitted, a non-empty
            ``source_session_id`` already on the entry is kept; otherwise
            ``"unknown"`` is recorded.
        model: Model name. When omitted, a non-empty ``model`` already on the
            entry is kept; otherwise ``"unknown"`` is recorded.
        source_type: Origin category (e.g. generated, curated, augmented,
            manual). Always overwrites any existing value.
        **extra: Additional metadata, stored under ``provenance_<key>``.

    Returns:
        The annotated copy. A non-empty existing ``timestamp`` is kept;
        otherwise the current UTC time is recorded in ISO-8601 format.
    """
    entry = dict(entry)  # shallow copy so the caller's dict is untouched
    # Explicit arguments win; otherwise keep whatever the entry already
    # carries (the same rule the timestamp follows) so that re-annotating a
    # pair never throws away known lineage.
    entry["source_session_id"] = (
        session_id or entry.get("source_session_id") or "unknown"
    )
    entry["model"] = model or entry.get("model") or "unknown"
    entry["timestamp"] = (
        entry.get("timestamp") or datetime.now(timezone.utc).isoformat()
    )
    entry["source_type"] = source_type  # generated, curated, augmented, manual
    for key, value in extra.items():
        entry[f"provenance_{key}"] = value
    return entry
|
|
|
|
|
|
def extract_provenance_from_trajectory(trajectory: dict) -> dict:
    """Extract provenance fields from a hermes trajectory record.

    Null or empty values fall through to the next candidate key and then to
    a default. A trajectory serialized with ``"session_id": null`` therefore
    still yields its ``id`` (or ``"unknown"``) instead of ``None``.

    Returns:
        dict with:
          source_session_id -- ``session_id``, else ``id``, else ``"unknown"``
          model             -- ``model``, else ``"unknown"``
          timestamp         -- ``started_at``, else ``timestamp``, else ``""``
          source_type       -- always ``"trajectory"``
          provider          -- ``provider``, else ``""``
          message_count     -- ``message_count``, else ``0``
    """
    return {
        "source_session_id": (
            trajectory.get("session_id") or trajectory.get("id") or "unknown"
        ),
        "model": trajectory.get("model") or "unknown",
        "timestamp": (
            trajectory.get("started_at") or trajectory.get("timestamp") or ""
        ),
        "source_type": "trajectory",
        "provider": trajectory.get("provider") or "",
        "message_count": trajectory.get("message_count") or 0,
    }
|
|
|
|
|
|
def validate_provenance(pairs: List[dict],
                        required_fields: Optional[List[str]] = None) -> dict:
    """Validate provenance metadata on training pairs.

    A pair is valid when every required field is present with a truthy
    value. The ``"unknown"`` placeholder that ``add_provenance`` records is
    truthy, so it counts as present here.

    Args:
        pairs: Training pairs to check.
        required_fields: Field names each pair must carry. Defaults to the
            module-level ``REQUIRED_FIELDS``.

    Returns:
        dict with keys:
          total          -- number of pairs checked
          valid          -- pairs carrying every required field
          invalid        -- pairs missing at least one required field
          missing_fields -- {field: number of pairs missing it}
          by_model       -- {model: count} over all pairs
          by_source      -- {source_type: count} over all pairs
          issues         -- [{"index": i, "missing": [fields]}], one entry
                            per invalid pair
    """
    if required_fields is None:
        required_fields = REQUIRED_FIELDS

    results = {
        "total": len(pairs),
        "valid": 0,
        "invalid": 0,
        "missing_fields": {},
        "by_model": {},
        "by_source": {},
        "issues": [],
    }

    for i, pair in enumerate(pairs):
        # Absent keys and falsy values (None, "", 0, []) both count as missing.
        missing = [f for f in required_fields if not pair.get(f)]
        if missing:
            results["invalid"] += 1
            results["issues"].append({"index": i, "missing": missing})
            for f in missing:
                results["missing_fields"][f] = results["missing_fields"].get(f, 0) + 1
        else:
            results["valid"] += 1

        # Bucket every pair, valid or not. Falsy values are grouped under
        # "unknown" so the keys stay strings (no None / "" buckets).
        model = pair.get("model") or "unknown"
        source = pair.get("source_type") or "unknown"
        results["by_model"][model] = results["by_model"].get(model, 0) + 1
        results["by_source"][source] = results["by_source"].get(source, 0) + 1

    return results
|
|
|
|
|
|
def provenance_stats(pairs: List[dict], max_models: int = 20) -> dict:
    """Summarize provenance coverage and distribution for a set of pairs.

    A pair counts as having a session id / model only when the field holds a
    truthy value other than the ``"unknown"`` placeholder. This matches
    ``validate_provenance``, which treats empty values as missing.

    Args:
        pairs: Training pairs to summarize.
        max_models: How many of the most common models to report in
            ``by_model`` (default 20). Less common models are omitted.

    Returns:
        dict with ``total``, ``with_session_id``, ``with_model``,
        ``coverage_session`` / ``coverage_model`` (percentages rounded to one
        decimal; 0.0 for empty input), ``by_model`` (top ``max_models``,
        most common first) and ``by_source`` (all source types, most common
        first). Missing or empty model/source values are bucketed as
        ``"unknown"``.
    """
    def _known(value) -> bool:
        # Present, non-empty, and not the placeholder add_provenance writes.
        return bool(value) and value != "unknown"

    total = len(pairs)
    models = Counter(p.get("model") or "unknown" for p in pairs)
    sources = Counter(p.get("source_type") or "unknown" for p in pairs)
    with_session = sum(1 for p in pairs if _known(p.get("source_session_id")))
    with_model = sum(1 for p in pairs if _known(p.get("model")))
    denominator = max(total, 1)  # avoid ZeroDivisionError on empty input

    return {
        "total": total,
        "with_session_id": with_session,
        "with_model": with_model,
        "coverage_session": round(with_session / denominator * 100, 1),
        "coverage_model": round(with_model / denominator * 100, 1),
        "by_model": dict(models.most_common(max_models)),
        "by_source": dict(sources.most_common()),
    }
|
|
|
|
|
|
def backfill_provenance(input_path: str, output_path: Optional[str] = None,
                        default_model: str = "unknown") -> dict:
    """Add provenance to pairs in a JSONL file that lack it.

    Pairs without a ``source_session_id`` key are annotated with
    ``source_type="backfill"``. Their existing ``model`` and ``timestamp``
    are kept, and ``default_model`` is used only when no model is recorded.
    Pairs that already have a ``source_session_id`` are written unchanged.

    Args:
        input_path: Source JSONL file (one JSON object per non-blank line,
            read as UTF-8).
        output_path: Destination JSONL file, written as UTF-8. Defaults to
            the input path with ``_provenance`` inserted before the
            extension (``pairs.jsonl`` -> ``pairs_provenance.jsonl``).
        default_model: Model name recorded for pairs with no model.

    Returns:
        ``provenance_stats`` computed over the pairs as written to
        ``output_path``. A short summary is also printed to stdout.
    """
    if output_path is None:
        # Insert the suffix before the real extension instead of substituting
        # ".jsonl" anywhere in the path: a non-.jsonl input would otherwise
        # map to itself and be overwritten.
        root, ext = os.path.splitext(input_path)
        output_path = f"{root}_provenance{ext}"

    # JSON text is UTF-8; don't depend on the platform's locale encoding
    # (the writer below emits non-ASCII because ensure_ascii=False).
    with open(input_path, encoding="utf-8") as f:
        pairs = [json.loads(line) for line in f if line.strip()]

    annotated = []
    added = 0
    with open(output_path, "w", encoding="utf-8") as f:
        for pair in pairs:
            if "source_session_id" not in pair:
                # Keep a model the pair already records; fall back otherwise.
                pair = add_provenance(pair,
                                      model=pair.get("model") or default_model,
                                      source_type="backfill")
                added += 1
            annotated.append(pair)
            f.write(json.dumps(pair, ensure_ascii=False) + "\n")

    # Report on what was written, not on the pre-backfill input.
    stats = provenance_stats(annotated)
    print(f"Backfill: {added} pairs annotated, {len(pairs) - added} already had provenance")
    print(f"Coverage: {stats['coverage_session']}% session, {stats['coverage_model']}% model")
    return stats
|
|
|
|
|
|
def filter_by_provenance(pairs: List[dict], exclude_models: list = None,
                         exclude_sources: list = None) -> List[dict]:
    """Drop pairs whose model or source_type appears in an exclusion list.

    Args:
        pairs: Training pairs to filter; iterated exactly once.
        exclude_models: Model names to drop. A pair with no ``model`` key is
            compared as ``""``. ``None`` excludes no models.
        exclude_sources: ``source_type`` values to drop. A pair with no
            ``source_type`` key is compared as ``""``. ``None`` excludes no
            sources.

    Returns:
        A new list of the retained pairs in input order. A one-line
        kept/excluded summary is printed to stdout.
    """
    banned_models = [] if exclude_models is None else exclude_models
    banned_sources = [] if exclude_sources is None else exclude_sources

    def _excluded(pair: dict) -> bool:
        model, source = pair.get("model", ""), pair.get("source_type", "")
        return model in banned_models or source in banned_sources

    # Decide each pair once, then split the verdicts into kept / dropped.
    verdicts = [(pair, _excluded(pair)) for pair in pairs]
    kept = [pair for pair, drop in verdicts if not drop]
    dropped_count = sum(drop for _, drop in verdicts)

    print(f"Filtered: {len(kept)} kept, {dropped_count} excluded")
    return kept
|