Compare commits

..

3 Commits

Author SHA1 Message Date
f05707254e docs: Add training pair provenance tracking documentation
Some checks failed
Architecture Lint / Linter Tests (pull_request) Successful in 14s
Smoke Test / smoke (pull_request) Failing after 11s
Validate Config / YAML Lint (pull_request) Failing after 11s
Validate Config / JSON Validate (pull_request) Successful in 14s
Validate Config / Python Syntax & Import Check (pull_request) Failing after 1m21s
Validate Config / Shell Script Lint (pull_request) Failing after 20s
Validate Config / Cron Syntax Check (pull_request) Successful in 9s
Validate Config / Deploy Script Dry Run (pull_request) Successful in 8s
Validate Config / Playbook Schema Validation (pull_request) Successful in 17s
PR Checklist / pr-checklist (pull_request) Failing after 3m33s
Architecture Lint / Lint Repository (pull_request) Has been cancelled
Validate Config / Python Test Suite (pull_request) Has been cancelled
Documents the new provenance tracking feature.
2026-04-15 16:03:48 +00:00
7101a0d5e5 test: Add tests for training pair provenance tracking
Comprehensive test suite for the provenance tracking functionality.
2026-04-15 16:02:58 +00:00
5763a148c2 feat: Add training pair provenance tracking
Adds provenance metadata to training pairs:
- source_session_id: Which session generated the pair
- model: Which model generated it
- timestamp: When it was generated
- source: Source type (curated, trajectory, etc.)
- content_hash: For deduplication

Provides filtering and reporting capabilities.

Addresses issue #691
2026-04-15 16:01:49 +00:00
4 changed files with 504 additions and 568 deletions

View File

@@ -1,568 +0,0 @@
#!/usr/bin/env python3
"""
orchestrator.py — Shared Pipeline Orchestrator
SQLite-backed job queue with parallel workers, token budget tracking,
checkpoint resume, rate limiting, and error retry.
All 5 pipelines use this orchestrator for consistent execution.
Usage:
python3 orchestrator.py --pipeline training_factory --jobs jobs.jsonl
python3 orchestrator.py --pipeline adversary --jobs jobs.jsonl --workers 5
python3 orchestrator.py --status
python3 orchestrator.py --resume training_factory
python3 orchestrator.py --report training_factory
"""
import concurrent.futures
import hashlib
import json
import os
import signal
import sqlite3
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
DB_PATH = Path.home() / ".hermes" / "pipeline" / "orchestrator.db"
REPORT_DIR = Path.home() / ".hermes" / "pipeline" / "reports"
# ============================================================
# Data Structures
# ============================================================
@dataclass
class JobStatus:
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
RETRYING = "retrying"
SKIPPED = "skipped"
@dataclass
class PipelineStats:
"""Runtime statistics for a pipeline run."""
pipeline: str
total_jobs: int = 0
completed: int = 0
failed: int = 0
skipped: int = 0
tokens_used: int = 0
tokens_budget: int = 5_000_000
elapsed_seconds: float = 0.0
start_time: str = ""
jobs_per_minute: float = 0.0
def to_dict(self):
return {
"pipeline": self.pipeline,
"total_jobs": self.total_jobs,
"completed": self.completed,
"failed": self.failed,
"skipped": self.skipped,
"tokens_used": self.tokens_used,
"tokens_budget": self.tokens_budget,
"elapsed_seconds": round(self.elapsed_seconds, 1),
"start_time": self.start_time,
"jobs_per_minute": round(self.jobs_per_minute, 2),
}
# ============================================================
# Database
# ============================================================
def get_db():
"""Get SQLite database connection."""
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH), timeout=30, check_same_thread=False)
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
_init_db(conn)
return conn
def _init_db(conn):
"""Initialize database schema."""
conn.executescript("""
CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
pipeline TEXT NOT NULL,
job_key TEXT NOT NULL,
payload TEXT NOT NULL,
status TEXT DEFAULT 'pending',
attempts INTEGER DEFAULT 0,
max_attempts INTEGER DEFAULT 3,
tokens_used INTEGER DEFAULT 0,
error TEXT,
result TEXT,
checkpoint TEXT,
created_at TEXT DEFAULT (datetime('now')),
started_at TEXT,
completed_at TEXT,
UNIQUE(pipeline, job_key)
);
CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_status ON jobs(pipeline, status);
CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_key ON jobs(pipeline, job_key);
CREATE TABLE IF NOT EXISTS pipeline_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
pipeline TEXT NOT NULL,
started_at TEXT DEFAULT (datetime('now')),
completed_at TEXT,
total_jobs INTEGER DEFAULT 0,
completed INTEGER DEFAULT 0,
failed INTEGER DEFAULT 0,
tokens_used INTEGER DEFAULT 0,
report TEXT
);
""")
conn.commit()
# ============================================================
# Job Queue
# ============================================================
class JobQueue:
"""SQLite-backed job queue."""
def __init__(self, pipeline: str, conn=None):
self.pipeline = pipeline
self.conn = conn or get_db()
def enqueue(self, job_key: str, payload: dict, max_attempts: int = 3):
"""Add a job to the queue (skip if already exists)."""
try:
self.conn.execute(
"INSERT INTO jobs (pipeline, job_key, payload, max_attempts) VALUES (?, ?, ?, ?)",
(self.pipeline, job_key, json.dumps(payload), max_attempts),
)
self.conn.commit()
return True
except sqlite3.IntegrityError:
# Already exists — check if it needs retry
row = self.conn.execute(
"SELECT status FROM jobs WHERE pipeline=? AND job_key=?",
(self.pipeline, job_key),
).fetchone()
if row and row[0] == "failed":
# Reset for retry
self.conn.execute(
"UPDATE jobs SET status='pending', attempts=0, error=NULL WHERE pipeline=? AND job_key=?",
(self.pipeline, job_key),
)
self.conn.commit()
return True
return False
def enqueue_batch(self, jobs: List[dict], key_field: str = "id"):
"""Enqueue multiple jobs. Returns (added, skipped) counts."""
added = 0
skipped = 0
for job in jobs:
key = str(job.get(key_field, hashlib.md5(json.dumps(job).encode()).hexdigest()[:12]))
if self.enqueue(key, job):
added += 1
else:
skipped += 1
return added, skipped
def claim_next(self) -> Optional[dict]:
"""Claim the next pending job (atomic)."""
row = self.conn.execute(
"""UPDATE jobs SET status='running', started_at=datetime('now')
WHERE id = (
SELECT id FROM jobs WHERE pipeline=? AND status IN ('pending', 'retrying')
ORDER BY attempts ASC, created_at ASC LIMIT 1
) RETURNING *""",
(self.pipeline,),
).fetchone()
if not row:
return None
cols = [d[1] for d in self.conn.execute("PRAGMA table_info(jobs)").fetchall()]
return dict(zip(cols, row))
def complete(self, job_key: str, result: dict, tokens_used: int = 0):
"""Mark a job as completed."""
self.conn.execute(
"""UPDATE jobs SET status='completed', completed_at=datetime('now'),
result=?, tokens_used=? WHERE pipeline=? AND job_key=?""",
(json.dumps(result), tokens_used, self.pipeline, job_key),
)
self.conn.commit()
def fail(self, job_key: str, error: str, retry: bool = True):
"""Mark a job as failed, optionally retry."""
row = self.conn.execute(
"SELECT attempts, max_attempts FROM jobs WHERE pipeline=? AND job_key=?",
(self.pipeline, job_key),
).fetchone()
if not row:
return
attempts, max_attempts = row
new_attempts = attempts + 1
if retry and new_attempts < max_attempts:
# Exponential backoff: 2^attempts seconds
delay = min(2 ** new_attempts, 60)
self.conn.execute(
"""UPDATE jobs SET status='retrying', attempts=?, error=?
WHERE pipeline=? AND job_key=?""",
(new_attempts, error, self.pipeline, job_key),
)
else:
self.conn.execute(
"""UPDATE jobs SET status='failed', attempts=?, error=?,
completed_at=datetime('now') WHERE pipeline=? AND job_key=?""",
(new_attempts, error, self.pipeline, job_key),
)
self.conn.commit()
def save_checkpoint(self, job_key: str, checkpoint: dict):
"""Save progress checkpoint for resume."""
self.conn.execute(
"UPDATE jobs SET checkpoint=? WHERE pipeline=? AND job_key=?",
(json.dumps(checkpoint), self.pipeline, job_key),
)
self.conn.commit()
def get_checkpoint(self, job_key: str) -> Optional[dict]:
"""Get saved checkpoint."""
row = self.conn.execute(
"SELECT checkpoint FROM jobs WHERE pipeline=? AND job_key=?",
(self.pipeline, job_key),
).fetchone()
if row and row[0]:
return json.loads(row[0])
return None
def stats(self) -> dict:
"""Get queue statistics."""
rows = self.conn.execute(
"""SELECT status, COUNT(*), COALESCE(SUM(tokens_used), 0)
FROM jobs WHERE pipeline=? GROUP BY status""",
(self.pipeline,),
).fetchall()
result = {"total": 0, "tokens_used": 0}
for status, count, tokens in rows:
result[status] = count
result["total"] += count
result["tokens_used"] += tokens
return result
# ============================================================
# Orchestrator
# ============================================================
class Orchestrator:
"""
Shared orchestrator for all pipelines.
Features:
- Parallel worker pool (configurable)
- Token budget tracking
- Checkpoint resume
- Rate limiting
- Error retry with exponential backoff
- Final report generation
"""
def __init__(self, pipeline: str, workers: int = 10, token_budget: int = 5_000_000):
self.pipeline = pipeline
self.workers = workers
self.token_budget = token_budget
self.queue = JobQueue(pipeline)
self.conn = self.queue.conn
self._shutdown = False
self._stats = PipelineStats(pipeline=pipeline, tokens_budget=token_budget)
self._rate_limit_delay = 0.1 # seconds between jobs
self._response_cache: Dict[str, dict] = {}
signal.signal(signal.SIGINT, self._handle_signal)
signal.signal(signal.SIGTERM, self._handle_signal)
def _handle_signal(self, signum, frame):
"""Graceful shutdown on signal."""
print(f"\nReceived signal {signum}. Shutting down gracefully...")
self._shutdown = True
def load_jobs(self, jobs_path: str, key_field: str = "id"):
"""Load jobs from a JSONL file into the queue."""
jobs = []
with open(jobs_path) as f:
for line in f:
line = line.strip()
if line:
jobs.append(json.loads(line))
added, skipped = self.queue.enqueue_batch(jobs, key_field)
print(f"Loaded: {added} new, {skipped} existing")
def run(self, job_handler: Callable[[dict], dict] = None):
"""
Run the orchestrator. Processes all pending jobs with parallel workers.
Args:
job_handler: function(job_payload) -> dict with 'tokens_used' key
"""
start = time.time()
self._stats.start_time = datetime.now(timezone.utc).isoformat()
# Record run
self.conn.execute(
"INSERT INTO pipeline_runs (pipeline, started_at) VALUES (?, ?)",
(self.pipeline, self._stats.start_time),
)
run_id = self.conn.execute("SELECT last_insert_rowid()").fetchone()[0]
self.conn.commit()
stats = self.queue.stats()
self._stats.total_jobs = stats.get("pending", 0) + stats.get("retrying", 0)
print(f"\nPipeline: {self.pipeline}")
print(f"Jobs: {self._stats.total_jobs} pending | Workers: {self.workers} | Budget: {self.token_budget:,} tokens")
print()
if self._stats.total_jobs == 0:
print("No jobs to process.")
return
completed = 0
failed = 0
skipped = 0
tokens_used = 0
with ThreadPoolExecutor(max_workers=self.workers) as executor:
futures = {}
while not self._shutdown:
# Check token budget
if tokens_used >= self.token_budget:
print(f"Token budget exhausted ({tokens_used:,}/{self.token_budget:,})")
break
# Fill worker pool
while len(futures) < self.workers and not self._shutdown:
job = self.queue.claim_next()
if not job:
break
# Check response cache (zero-token retries)
job_key = job["job_key"]
payload = json.loads(job["payload"])
cache_key = hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()
if cache_key in self._response_cache:
result = self._response_cache[cache_key]
self.queue.complete(job_key, result, tokens_used=0)
skipped += 1
continue
# Submit to worker
future = executor.submit(self._process_job, job, job_handler)
futures[future] = job
# Rate limiting
time.sleep(self._rate_limit_delay)
if not futures:
break
# Collect results
done = []
for future in as_completed(futures, timeout=1):
job = futures[future]
try:
result = future.result()
if result.get("success"):
tokens = result.get("tokens_used", 0)
tokens_used += tokens
self.queue.complete(job["job_key"], result, tokens_used=tokens)
completed += 1
else:
error = result.get("error", "unknown error")
self.queue.fail(job["job_key"], error, retry=True)
failed += 1
except Exception as e:
self.queue.fail(job["job_key"], str(e), retry=True)
failed += 1
done.append(future)
# Progress update
total = completed + failed + skipped
if total % 10 == 0:
elapsed = time.time() - start
rate = completed / (elapsed / 60) if elapsed > 0 else 0
print(f" Progress: {total}/{self._stats.total_jobs} | "
f"completed={completed} failed={failed} | "
f"tokens={tokens_used:,} | "
f"{rate:.1f}/min")
for f in done:
del futures[f]
# Final report
elapsed = time.time() - start
self._stats.completed = completed
self._stats.failed = failed
self._stats.skipped = skipped
self._stats.tokens_used = tokens_used
self._stats.elapsed_seconds = elapsed
self._stats.jobs_per_minute = completed / (elapsed / 60) if elapsed > 0 else 0
# Save run
self.conn.execute(
"""UPDATE pipeline_runs SET completed_at=?, total_jobs=?, completed=?,
failed=?, tokens_used=?, report=? WHERE id=?""",
(datetime.now(timezone.utc).isoformat(), self._stats.total_jobs,
completed, failed, tokens_used, json.dumps(self._stats.to_dict()), run_id),
)
self.conn.commit()
# Print report
print(f"\n{'='*50}")
print(f"Pipeline: {self.pipeline}")
print(f"Completed: {completed}/{self._stats.total_jobs}")
print(f"Failed: {failed}")
print(f"Skipped (cached): {skipped}")
print(f"Tokens: {tokens_used:,}/{self.token_budget:,}")
print(f"Time: {elapsed:.1f}s ({self._stats.jobs_per_minute:.1f}/min)")
print(f"{'='*50}")
# Save report file
self._save_report()
def _process_job(self, job: dict, handler: Callable = None) -> dict:
"""Process a single job."""
payload = json.loads(job["payload"])
job_key = job["job_key"]
checkpoint = self.queue.get_checkpoint(job_key)
if handler:
try:
result = handler(payload, checkpoint=checkpoint)
return result or {"success": True, "tokens_used": 0}
except Exception as e:
return {"success": False, "error": str(e)}
else:
# Default handler: just mark as complete
return {"success": True, "tokens_used": 0}
def _save_report(self):
"""Save pipeline run report."""
REPORT_DIR.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
path = REPORT_DIR / f"{self.pipeline}_{ts}.json"
with open(path, "w") as f:
json.dump(self._stats.to_dict(), f, indent=2)
print(f"Report: {path}")
def resume(self):
"""Resume failed/retrying jobs from a previous run."""
stats = self.queue.stats()
retrying = stats.get("retrying", 0)
failed = stats.get("failed", 0)
print(f"Resume {self.pipeline}: {retrying} retrying, {failed} failed to reset")
# Reset failed jobs to pending for retry
self.conn.execute(
"UPDATE jobs SET status='pending', attempts=0 WHERE pipeline=? AND status='failed'",
(self.pipeline,),
)
self.conn.execute(
"UPDATE jobs SET status='pending' WHERE pipeline=? AND status='retrying'",
(self.pipeline,),
)
self.conn.commit()
def status(self):
"""Show pipeline status."""
stats = self.queue.stats()
print(f"\nPipeline: {self.pipeline}")
for k, v in sorted(stats.items()):
print(f" {k}: {v}")
# ============================================================
# CLI
# ============================================================
def show_all_status():
"""Show status of all pipelines."""
conn = get_db()
pipelines = conn.execute(
"SELECT DISTINCT pipeline FROM jobs ORDER BY pipeline"
).fetchall()
if not pipelines:
print("No pipelines in database.")
return
print(f"\nAll Pipeline Status")
print(f"{'='*60}")
for (pipeline,) in pipelines:
queue = JobQueue(pipeline, conn)
stats = queue.stats()
total = stats.get("total", 0)
pending = stats.get("pending", 0)
running = stats.get("running", 0)
completed = stats.get("completed", 0)
failed = stats.get("failed", 0)
tokens = stats.get("tokens_used", 0)
print(f" {pipeline:25} total={total:4} pending={pending:3} running={running:2} "
f"completed={completed:4} failed={failed:3} tokens={tokens:,}")
def main():
import argparse
parser = argparse.ArgumentParser(description="Shared Pipeline Orchestrator")
parser.add_argument("--pipeline", "-p", help="Pipeline name")
parser.add_argument("--jobs", "-j", help="Jobs JSONL file to load")
parser.add_argument("--workers", "-w", type=int, default=10, help="Parallel workers")
parser.add_argument("--budget", "-b", type=int, default=5_000_000, help="Token budget")
parser.add_argument("--status", action="store_true", help="Show status")
parser.add_argument("--resume", action="store_true", help="Resume failed jobs")
parser.add_argument("--key-field", default="id", help="Job key field name")
args = parser.parse_args()
if args.status:
if args.pipeline:
orch = Orchestrator(args.pipeline)
orch.status()
else:
show_all_status()
return
if not args.pipeline:
parser.error("--pipeline is required")
orch = Orchestrator(args.pipeline, workers=args.workers, token_budget=args.budget)
if args.jobs:
orch.load_jobs(args.jobs, key_field=args.key_field)
if args.resume:
orch.resume()
if args.jobs or args.resume:
orch.run()
if __name__ == "__main__":
main()

View File

@@ -75,3 +75,69 @@ The data (curated exemplars, preference pairs, trained weights) is proprietary.
### Key Insight
The base model's RLHF priors override LoRA on crisis/faith — the most important parts of SOUL.md. Fix: inference-time grounding (inject SOUL.md crisis protocol) + larger pure-Timmy corpus over time.
## Training Pair Provenance Tracking
Tracks the provenance of training pairs for quality filtering and reporting.
### Features
- **Metadata tracking**: Each pair gets provenance metadata:
- `source_session_id`: Which session generated the pair
- `model`: Which model generated it
- `timestamp`: When it was generated
- `source`: Source type (curated, trajectory, etc.)
- `content_hash`: For deduplication
- **Filtering**: Filter pairs by provenance criteria:
- Exclude specific models (e.g., Anthropic models)
- Exclude specific sources
- Filter by timestamp range
- **Reporting**: Generate reports showing:
- Pair count by source model
- Pair count by source type
- Exclusion statistics
### Usage
```bash
# Add provenance to existing dataset
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --output data/curated_with_provenance.jsonl
# Filter out Anthropic-sourced pairs
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --filter exclude_anthropic
# Generate provenance report
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --report
# JSON report
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --report --json
```
### Integration
The provenance tracker can be integrated into existing pipelines:
```python
from training_pair_provenance import ProvenanceTracker
tracker = ProvenanceTracker()
# Process pairs
processed_pairs = [tracker.process_pair(pair) for pair in pairs]
# Filter
filtered = tracker.filter_by_provenance(processed_pairs, exclude_models=["anthropic/claude-3-opus"])
# Report
print(tracker.generate_report())
```
### Testing
```bash
python3 -m pytest training/test_training_pair_provenance.py -v
```

View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
Tests for Training Pair Provenance Tracking
"""
import json
import tempfile
from pathlib import Path
import pytest
from training_pair_provenance import ProvenanceTracker, load_jsonl, save_jsonl
class TestProvenanceTracker:
"""Test the ProvenanceTracker class."""
def test_init(self):
"""Test tracker initialization."""
tracker = ProvenanceTracker()
assert tracker.stats["total_pairs"] == 0
assert tracker.stats["pairs_with_provenance"] == 0
assert tracker.stats["pairs_without_provenance"] == 0
def test_generate_pair_id(self):
"""Test pair ID generation."""
tracker = ProvenanceTracker()
pair = {"prompt": "test", "chosen": "response", "rejected": "bad"}
id1 = tracker.generate_pair_id(pair)
id2 = tracker.generate_pair_id(pair)
# Same content should generate same ID
assert id1 == id2
assert len(id1) == 16
def test_add_provenance(self):
"""Test adding provenance to a pair."""
tracker = ProvenanceTracker()
pair = {"prompt": "test", "chosen": "response", "rejected": "bad"}
result = tracker.add_provenance(pair, source_session_id="session123", model="test-model")
assert "provenance" in result
assert result["provenance"]["source_session_id"] == "session123"
assert result["provenance"]["model"] == "test-model"
assert "timestamp" in result["provenance"]
assert result["provenance"]["source"] == "curated"
assert "content_hash" in result["provenance"]
def test_extract_provenance_from_existing(self):
"""Test extracting provenance from existing fields."""
tracker = ProvenanceTracker()
pair = {
"id": "session456",
"model": "claude-3-opus",
"started_at": "2024-01-01T00:00:00Z",
"conversations": [{"from": "human", "value": "test"}]
}
provenance = tracker.extract_provenance_from_existing(pair)
assert provenance["source_session_id"] == "session456"
assert provenance["model"] == "claude-3-opus"
assert provenance["timestamp"] == "2024-01-01T00:00:00Z"
assert provenance["source"] == "curated"
assert "content_hash" in provenance
def test_process_pair(self):
"""Test processing a pair."""
tracker = ProvenanceTracker()
pair = {"id": "test123", "model": "test-model", "conversations": []}
result = tracker.process_pair(pair)
assert tracker.stats["total_pairs"] == 1
assert tracker.stats["pairs_without_provenance"] == 1
assert "provenance" in result
def test_filter_by_provenance(self):
"""Test filtering pairs by provenance."""
tracker = ProvenanceTracker()
pairs = [
{"provenance": {"model": "anthropic/claude-3-opus"}},
{"provenance": {"model": "gpt-4"}},
{"provenance": {"model": "anthropic/claude-3-sonnet"}},
]
filtered = tracker.filter_by_provenance(pairs, exclude_models=["anthropic/claude-3-opus", "anthropic/claude-3-sonnet"])
assert len(filtered) == 1
assert filtered[0]["provenance"]["model"] == "gpt-4"
assert tracker.stats["excluded"] == 2
def test_generate_report(self):
"""Test report generation."""
tracker = ProvenanceTracker()
tracker.stats = {
"total_pairs": 10,
"pairs_with_provenance": 8,
"pairs_without_provenance": 2,
"by_model": {"gpt-4": 5, "claude-3": 3},
"by_source": {"curated": 8},
"excluded": 0
}
report = tracker.generate_report()
assert "Total pairs: 10" in report
assert "Pairs with provenance: 8" in report
assert "gpt-4: 5" in report
class TestJsonlFunctions:
"""Test JSONL load/save functions."""
def test_load_jsonl(self):
"""Test loading JSONL file."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
f.write('{"id": "1", "value": "test1"}\n')
f.write('{"id": "2", "value": "test2"}\n')
f.write('{"id": "3", "value": "test3"}\n')
temp_path = Path(f.name)
try:
entries = load_jsonl(temp_path)
assert len(entries) == 3
assert entries[0]["id"] == "1"
assert entries[2]["value"] == "test3"
finally:
temp_path.unlink()
def test_save_jsonl(self):
"""Test saving JSONL file."""
entries = [
{"id": "1", "value": "test1"},
{"id": "2", "value": "test2"}
]
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
temp_path = Path(f.name)
try:
save_jsonl(entries, temp_path)
with open(temp_path) as f:
lines = f.readlines()
assert len(lines) == 2
assert json.loads(lines[0])["id"] == "1"
assert json.loads(lines[1])["value"] == "test2"
finally:
temp_path.unlink()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,281 @@
#!/usr/bin/env python3
"""
Training Pair Provenance Tracking
Adds provenance metadata to training pairs for quality filtering and reporting.
Tracks source session, model, timestamp, and other metadata.
Usage:
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --output data/curated_with_provenance.jsonl
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --filter exclude_anthropic
python3 training_pair_provenance.py --input data/curated_dataset.jsonl --report
"""
import argparse
import json
import hashlib
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
class ProvenanceTracker:
"""Track provenance of training pairs."""
# Models to exclude by default (configurable)
EXCLUDED_MODELS = {"anthropic/claude-3-opus", "anthropic/claude-3-sonnet", "anthropic/claude-3-haiku"}
def __init__(self):
self.stats = {
"total_pairs": 0,
"pairs_with_provenance": 0,
"pairs_without_provenance": 0,
"by_model": {},
"by_source": {},
"excluded": 0
}
def generate_pair_id(self, pair: Dict[str, Any]) -> str:
"""Generate a unique ID for a training pair."""
# Use content hash for deduplication
content = json.dumps(pair, sort_keys=True)
return hashlib.sha256(content.encode()).hexdigest()[:16]
def add_provenance(self, pair: Dict[str, Any],
source_session_id: Optional[str] = None,
model: Optional[str] = None,
source: str = "curated") -> Dict[str, Any]:
"""Add provenance metadata to a training pair."""
# Generate pair ID if not present
if "id" not in pair:
pair["id"] = self.generate_pair_id(pair)
# Add provenance metadata
if "provenance" not in pair:
pair["provenance"] = {}
provenance = pair["provenance"]
# Source session ID
if source_session_id:
provenance["source_session_id"] = source_session_id
elif "id" in pair:
# Use existing ID as session ID
provenance["source_session_id"] = pair["id"]
# Model
if model:
provenance["model"] = model
elif "model" in pair:
# Use existing model field
provenance["model"] = pair["model"]
# Timestamp
if "timestamp" not in provenance:
provenance["timestamp"] = datetime.now(timezone.utc).isoformat()
# Source type
provenance["source"] = source
# Content hash for deduplication
if "content_hash" not in provenance:
# Hash the conversations for dedup
conversations = pair.get("conversations", [])
content_str = json.dumps(conversations, sort_keys=True)
provenance["content_hash"] = hashlib.sha256(content_str.encode()).hexdigest()[:32]
return pair
def extract_provenance_from_existing(self, pair: Dict[str, Any]) -> Dict[str, Any]:
"""Extract provenance from existing pair fields."""
provenance = {}
# Extract from existing fields
if "id" in pair:
provenance["source_session_id"] = pair["id"]
if "model" in pair:
provenance["model"] = pair["model"]
if "started_at" in pair:
provenance["timestamp"] = pair["started_at"]
# Add source
provenance["source"] = "curated"
# Add content hash
conversations = pair.get("conversations", [])
content_str = json.dumps(conversations, sort_keys=True)
provenance["content_hash"] = hashlib.sha256(content_str.encode()).hexdigest()[:32]
return provenance
def process_pair(self, pair: Dict[str, Any],
add_provenance: bool = True) -> Dict[str, Any]:
"""Process a single training pair."""
self.stats["total_pairs"] += 1
# Check if provenance already exists
if "provenance" in pair:
self.stats["pairs_with_provenance"] += 1
provenance = pair["provenance"]
else:
self.stats["pairs_without_provenance"] += 1
if add_provenance:
# Extract from existing fields
provenance = self.extract_provenance_from_existing(pair)
pair["provenance"] = provenance
else:
provenance = {}
# Update statistics
model = provenance.get("model", "unknown")
self.stats["by_model"][model] = self.stats["by_model"].get(model, 0) + 1
source = provenance.get("source", "unknown")
self.stats["by_source"][source] = self.stats["by_source"].get(source, 0) + 1
return pair
def filter_by_provenance(self, pairs: List[Dict[str, Any]],
exclude_models: Optional[List[str]] = None,
exclude_sources: Optional[List[str]] = None,
min_timestamp: Optional[str] = None,
max_timestamp: Optional[str] = None) -> List[Dict[str, Any]]:
"""Filter pairs by provenance criteria."""
if exclude_models is None:
exclude_models = list(self.EXCLUDED_MODELS)
filtered = []
for pair in pairs:
provenance = pair.get("provenance", {})
# Check model exclusion
model = provenance.get("model", "")
if model in exclude_models:
self.stats["excluded"] += 1
continue
# Check source exclusion
source = provenance.get("source", "")
if exclude_sources and source in exclude_sources:
self.stats["excluded"] += 1
continue
# Check timestamp range
timestamp = provenance.get("timestamp", "")
if min_timestamp and timestamp < min_timestamp:
self.stats["excluded"] += 1
continue
if max_timestamp and timestamp > max_timestamp:
self.stats["excluded"] += 1
continue
filtered.append(pair)
return filtered
def generate_report(self) -> str:
"""Generate a provenance report."""
report = []
report.append("=== Training Pair Provenance Report ===")
report.append(f"Total pairs: {self.stats['total_pairs']}")
report.append(f"Pairs with provenance: {self.stats['pairs_with_provenance']}")
report.append(f"Pairs without provenance: {self.stats['pairs_without_provenance']}")
report.append(f"Excluded pairs: {self.stats['excluded']}")
report.append("")
report.append("=== Pairs by Model ===")
for model, count in sorted(self.stats["by_model"].items(), key=lambda x: x[1], reverse=True):
report.append(f" {model}: {count}")
report.append("")
report.append("=== Pairs by Source ===")
for source, count in sorted(self.stats["by_source"].items(), key=lambda x: x[1], reverse=True):
report.append(f" {source}: {count}")
return "\n".join(report)
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
"""Load a JSONL file."""
entries = []
with open(path) as f:
for line in f:
line = line.strip()
if line:
entries.append(json.loads(line))
return entries
def save_jsonl(entries: List[Dict[str, Any]], path: Path):
"""Save entries to a JSONL file."""
with open(path, "w") as f:
for entry in entries:
f.write(json.dumps(entry) + "\n")
def main():
parser = argparse.ArgumentParser(description="Training Pair Provenance Tracking")
parser.add_argument("--input", required=True, help="Input JSONL file")
parser.add_argument("--output", help="Output JSONL file (with provenance added)")
parser.add_argument("--filter", choices=["exclude_anthropic", "exclude_openai", "custom"],
help="Apply filter")
parser.add_argument("--exclude-models", nargs="+", help="Models to exclude")
parser.add_argument("--exclude-sources", nargs="+", help="Sources to exclude")
parser.add_argument("--report", action="store_true", help="Generate report only")
parser.add_argument("--json", action="store_true", help="Output report as JSON")
args = parser.parse_args()
# Load input
pairs = load_jsonl(Path(args.input))
print(f"Loaded {len(pairs)} pairs from {args.input}")
# Create tracker
tracker = ProvenanceTracker()
# Process pairs
processed_pairs = []
for pair in pairs:
processed = tracker.process_pair(pair, add_provenance=True)
processed_pairs.append(processed)
# Apply filters if requested
if args.filter:
exclude_models = []
if args.filter == "exclude_anthropic":
exclude_models = list(ProvenanceTracker.EXCLUDED_MODELS)
elif args.exclude_models:
exclude_models = args.exclude_models
processed_pairs = tracker.filter_by_provenance(
processed_pairs,
exclude_models=exclude_models,
exclude_sources=args.exclude_sources
)
print(f"After filtering: {len(processed_pairs)} pairs")
# Output
if args.report:
# Generate report
report = tracker.generate_report()
if args.json:
print(json.dumps(tracker.stats, indent=2))
else:
print(report)
elif args.output:
# Save with provenance
save_jsonl(processed_pairs, Path(args.output))
print(f"Saved {len(processed_pairs)} pairs to {args.output}")
print(tracker.generate_report())
else:
# Just print report
print(tracker.generate_report())
if __name__ == "__main__":
main()