Compare commits

..

2 Commits

Author SHA1 Message Date
ae063a8c71 test: training pair provenance tracker (#691)
Some checks failed
Architecture Lint / Linter Tests (pull_request) Successful in 24s
PR Checklist / pr-checklist (pull_request) Failing after 3m29s
Smoke Test / smoke (pull_request) Failing after 18s
Validate Config / YAML Lint (pull_request) Failing after 10s
Validate Config / JSON Validate (pull_request) Successful in 12s
Validate Config / Python Syntax & Import Check (pull_request) Failing after 1m15s
Validate Config / Shell Script Lint (pull_request) Failing after 31s
Validate Config / Cron Syntax Check (pull_request) Successful in 6s
Validate Config / Deploy Script Dry Run (pull_request) Successful in 3s
Validate Config / Playbook Schema Validation (pull_request) Successful in 11s
Architecture Lint / Lint Repository (pull_request) Has been cancelled
Validate Config / Python Test Suite (pull_request) Has been cancelled
Comprehensive tests for tag/filter/report commands.
Covers: tagging, dedup, skip logic, error handling,
filtering by model, reporting, and full pipeline.

Closes #691
2026-04-15 14:54:53 +00:00
41afbc2ca9 feat: add training pair provenance tracker (#691)
Tag, filter, and report provenance metadata for JSONL training pairs.
Tracks source_session_id, model, and timestamp per pair.

Closes #691
2026-04-15 14:53:11 +00:00
3 changed files with 629 additions and 568 deletions

View File

@@ -1,568 +0,0 @@
#!/usr/bin/env python3
"""
orchestrator.py — Shared Pipeline Orchestrator
SQLite-backed job queue with parallel workers, token budget tracking,
checkpoint resume, rate limiting, and error retry.
All 5 pipelines use this orchestrator for consistent execution.
Usage:
python3 orchestrator.py --pipeline training_factory --jobs jobs.jsonl
python3 orchestrator.py --pipeline adversary --jobs jobs.jsonl --workers 5
python3 orchestrator.py --status
python3 orchestrator.py --resume training_factory
python3 orchestrator.py --report training_factory
"""
import json
import os
import sys
import time
import sqlite3
import hashlib
import threading
import signal
from datetime import datetime, timezone
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
# On-disk locations shared by all pipelines: the SQLite queue database and
# the directory where final run reports are written as JSON.
DB_PATH = Path.home() / ".hermes" / "pipeline" / "orchestrator.db"
REPORT_DIR = Path.home() / ".hermes" / "pipeline" / "reports"
# ============================================================
# Data Structures
# ============================================================
class JobStatus:
    """Job lifecycle states as plain string constants.

    Values are stored raw in the ``jobs.status`` column and compared as
    strings throughout, so this is deliberately a constants namespace
    rather than an Enum. (The previous ``@dataclass`` decorator was a
    no-op: there were no annotated fields for it to act on.)
    """
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    RETRYING = "retrying"
    SKIPPED = "skipped"
@dataclass
class PipelineStats:
    """Runtime statistics for a single pipeline run."""
    pipeline: str
    total_jobs: int = 0
    completed: int = 0
    failed: int = 0
    skipped: int = 0
    tokens_used: int = 0
    tokens_budget: int = 5_000_000
    elapsed_seconds: float = 0.0
    start_time: str = ""
    jobs_per_minute: float = 0.0

    def to_dict(self):
        """Serialize to a plain dict; float fields are rounded for display."""
        snapshot = {
            name: getattr(self, name)
            for name in ("pipeline", "total_jobs", "completed", "failed",
                         "skipped", "tokens_used", "tokens_budget")
        }
        snapshot["elapsed_seconds"] = round(self.elapsed_seconds, 1)
        snapshot["start_time"] = self.start_time
        snapshot["jobs_per_minute"] = round(self.jobs_per_minute, 2)
        return snapshot
# ============================================================
# Database
# ============================================================
def get_db():
    """Open (lazily creating) the shared orchestrator SQLite database.

    Ensures the parent directory exists, enables WAL journaling and a busy
    timeout so multiple workers can share the database, and initializes
    the schema before handing back the connection.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DB_PATH), timeout=30, check_same_thread=False)
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA busy_timeout=5000"):
        conn.execute(pragma)
    _init_db(conn)
    return conn
def _init_db(conn):
    """Initialize database schema.

    Idempotent (``CREATE ... IF NOT EXISTS``) so it is safe to run on every
    new connection. ``jobs`` holds one row per unit of work, deduplicated
    by UNIQUE(pipeline, job_key); ``pipeline_runs`` records one row per
    orchestrator invocation including its final JSON report.
    """
    conn.executescript("""
    CREATE TABLE IF NOT EXISTS jobs (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        pipeline TEXT NOT NULL,
        job_key TEXT NOT NULL,
        payload TEXT NOT NULL,
        status TEXT DEFAULT 'pending',
        attempts INTEGER DEFAULT 0,
        max_attempts INTEGER DEFAULT 3,
        tokens_used INTEGER DEFAULT 0,
        error TEXT,
        result TEXT,
        checkpoint TEXT,
        created_at TEXT DEFAULT (datetime('now')),
        started_at TEXT,
        completed_at TEXT,
        UNIQUE(pipeline, job_key)
    );
    CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_status ON jobs(pipeline, status);
    CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_key ON jobs(pipeline, job_key);
    CREATE TABLE IF NOT EXISTS pipeline_runs (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        pipeline TEXT NOT NULL,
        started_at TEXT DEFAULT (datetime('now')),
        completed_at TEXT,
        total_jobs INTEGER DEFAULT 0,
        completed INTEGER DEFAULT 0,
        failed INTEGER DEFAULT 0,
        tokens_used INTEGER DEFAULT 0,
        report TEXT
    );
    """)
    conn.commit()
# ============================================================
# Job Queue
# ============================================================
class JobQueue:
    """SQLite-backed job queue keyed by (pipeline, job_key).

    All queue state lives in the ``jobs`` table, so queues survive process
    restarts and can be shared by concurrent workers.
    """

    def __init__(self, pipeline: str, conn=None):
        """Bind to *pipeline*; opens the shared database if no *conn* given."""
        self.pipeline = pipeline
        self.conn = conn or get_db()

    def enqueue(self, job_key: str, payload: dict, max_attempts: int = 3) -> bool:
        """Add a job to the queue (skip if already exists).

        Returns True if the job was inserted, or if a previously *failed*
        job was reset for retry; False when the job already exists in any
        other state.
        """
        try:
            self.conn.execute(
                "INSERT INTO jobs (pipeline, job_key, payload, max_attempts) VALUES (?, ?, ?, ?)",
                (self.pipeline, job_key, json.dumps(payload), max_attempts),
            )
            self.conn.commit()
            return True
        except sqlite3.IntegrityError:
            # UNIQUE(pipeline, job_key) hit — the job is already known.
            # Re-arm it only if a previous run left it terminally failed.
            row = self.conn.execute(
                "SELECT status FROM jobs WHERE pipeline=? AND job_key=?",
                (self.pipeline, job_key),
            ).fetchone()
            if row and row[0] == "failed":
                self.conn.execute(
                    "UPDATE jobs SET status='pending', attempts=0, error=NULL WHERE pipeline=? AND job_key=?",
                    (self.pipeline, job_key),
                )
                self.conn.commit()
                return True
            return False

    def enqueue_batch(self, jobs: List[dict], key_field: str = "id"):
        """Enqueue multiple jobs. Returns (added, skipped) counts.

        Jobs without *key_field* get a content-hash key, so re-loading the
        same file is idempotent.
        """
        added = 0
        skipped = 0
        for job in jobs:
            key = str(job.get(key_field, hashlib.md5(json.dumps(job).encode()).hexdigest()[:12]))
            if self.enqueue(key, job):
                added += 1
            else:
                skipped += 1
        return added, skipped

    def claim_next(self) -> Optional[dict]:
        """Atomically claim the next pending/retrying job, or None if empty.

        Uses ``UPDATE ... RETURNING`` (SQLite 3.35+) so two workers can
        never claim the same row. Jobs with fewer attempts are served
        first, which naturally de-prioritizes retries.
        """
        row = self.conn.execute(
            """UPDATE jobs SET status='running', started_at=datetime('now')
            WHERE id = (
                SELECT id FROM jobs WHERE pipeline=? AND status IN ('pending', 'retrying')
                ORDER BY attempts ASC, created_at ASC LIMIT 1
            ) RETURNING *""",
            (self.pipeline,),
        ).fetchone()
        if not row:
            return None
        cols = [d[1] for d in self.conn.execute("PRAGMA table_info(jobs)").fetchall()]
        return dict(zip(cols, row))

    def complete(self, job_key: str, result: dict, tokens_used: int = 0):
        """Mark a job as completed, recording its result and token usage."""
        self.conn.execute(
            """UPDATE jobs SET status='completed', completed_at=datetime('now'),
            result=?, tokens_used=? WHERE pipeline=? AND job_key=?""",
            (json.dumps(result), tokens_used, self.pipeline, job_key),
        )
        self.conn.commit()

    def fail(self, job_key: str, error: str, retry: bool = True):
        """Mark a job failed; requeue as 'retrying' while attempts remain.

        There is no sleep between retries: retrying jobs simply sort after
        fresher ones in claim_next (ORDER BY attempts), which spaces them
        out. (A previously computed-but-unused backoff delay was removed.)
        """
        row = self.conn.execute(
            "SELECT attempts, max_attempts FROM jobs WHERE pipeline=? AND job_key=?",
            (self.pipeline, job_key),
        ).fetchone()
        if not row:
            return
        attempts, max_attempts = row
        new_attempts = attempts + 1
        if retry and new_attempts < max_attempts:
            self.conn.execute(
                """UPDATE jobs SET status='retrying', attempts=?, error=?
                WHERE pipeline=? AND job_key=?""",
                (new_attempts, error, self.pipeline, job_key),
            )
        else:
            self.conn.execute(
                """UPDATE jobs SET status='failed', attempts=?, error=?,
                completed_at=datetime('now') WHERE pipeline=? AND job_key=?""",
                (new_attempts, error, self.pipeline, job_key),
            )
        self.conn.commit()

    def save_checkpoint(self, job_key: str, checkpoint: dict):
        """Save a progress checkpoint so a retried job can resume mid-work."""
        self.conn.execute(
            "UPDATE jobs SET checkpoint=? WHERE pipeline=? AND job_key=?",
            (json.dumps(checkpoint), self.pipeline, job_key),
        )
        self.conn.commit()

    def get_checkpoint(self, job_key: str) -> Optional[dict]:
        """Return the saved checkpoint dict, or None if none was saved."""
        row = self.conn.execute(
            "SELECT checkpoint FROM jobs WHERE pipeline=? AND job_key=?",
            (self.pipeline, job_key),
        ).fetchone()
        if row and row[0]:
            return json.loads(row[0])
        return None

    def stats(self) -> dict:
        """Per-status job counts plus 'total' and summed 'tokens_used'."""
        rows = self.conn.execute(
            """SELECT status, COUNT(*), COALESCE(SUM(tokens_used), 0)
            FROM jobs WHERE pipeline=? GROUP BY status""",
            (self.pipeline,),
        ).fetchall()
        result = {"total": 0, "tokens_used": 0}
        for status, count, tokens in rows:
            result[status] = count
            result["total"] += count
            result["tokens_used"] += tokens
        return result
# ============================================================
# Orchestrator
# ============================================================
class Orchestrator:
    """
    Shared orchestrator for all pipelines.
    Features:
    - Parallel worker pool (configurable)
    - Token budget tracking
    - Checkpoint resume
    - Rate limiting
    - Error retry with exponential backoff
    - Final report generation
    """

    def __init__(self, pipeline: str, workers: int = 10, token_budget: int = 5_000_000):
        """Bind to one pipeline's queue and install graceful-shutdown handlers."""
        self.pipeline = pipeline
        self.workers = workers
        self.token_budget = token_budget
        self.queue = JobQueue(pipeline)
        self.conn = self.queue.conn
        self._shutdown = False
        self._stats = PipelineStats(pipeline=pipeline, tokens_budget=token_budget)
        self._rate_limit_delay = 0.1  # seconds between jobs
        # NOTE(review): this cache is consulted in run() but never written
        # to anywhere in this file, so the "zero-token retries" fast path
        # can never trigger — confirm whether population was lost or is
        # expected to come from a subclass.
        self._response_cache: Dict[str, dict] = {}
        # SIGINT/SIGTERM set a flag so the run loop drains instead of dying
        # mid-job. NOTE: this makes the class effectively main-thread-only.
        signal.signal(signal.SIGINT, self._handle_signal)
        signal.signal(signal.SIGTERM, self._handle_signal)

    def _handle_signal(self, signum, frame):
        """Graceful shutdown on signal: stop claiming new jobs."""
        print(f"\nReceived signal {signum}. Shutting down gracefully...")
        self._shutdown = True

    def load_jobs(self, jobs_path: str, key_field: str = "id"):
        """Load jobs from a JSONL file into the queue (idempotent reload)."""
        jobs = []
        with open(jobs_path) as f:
            for line in f:
                line = line.strip()
                if line:
                    jobs.append(json.loads(line))
        added, skipped = self.queue.enqueue_batch(jobs, key_field)
        print(f"Loaded: {added} new, {skipped} existing")

    def run(self, job_handler: Callable[[dict], dict] = None):
        """
        Run the orchestrator. Processes all pending jobs with parallel workers.
        Args:
            job_handler: function(job_payload) -> dict with 'tokens_used' key.
                Note it is actually invoked as handler(payload, checkpoint=...)
                (see _process_job) and is expected to return a dict with a
                truthy 'success' key; anything else is routed to retry.
        """
        start = time.time()
        self._stats.start_time = datetime.now(timezone.utc).isoformat()
        # Record run
        self.conn.execute(
            "INSERT INTO pipeline_runs (pipeline, started_at) VALUES (?, ?)",
            (self.pipeline, self._stats.start_time),
        )
        run_id = self.conn.execute("SELECT last_insert_rowid()").fetchone()[0]
        self.conn.commit()
        stats = self.queue.stats()
        self._stats.total_jobs = stats.get("pending", 0) + stats.get("retrying", 0)
        print(f"\nPipeline: {self.pipeline}")
        print(f"Jobs: {self._stats.total_jobs} pending | Workers: {self.workers} | Budget: {self.token_budget:,} tokens")
        print()
        if self._stats.total_jobs == 0:
            print("No jobs to process.")
            return
        completed = 0
        failed = 0
        skipped = 0
        tokens_used = 0
        with ThreadPoolExecutor(max_workers=self.workers) as executor:
            futures = {}
            while not self._shutdown:
                # Check token budget before claiming more work.
                if tokens_used >= self.token_budget:
                    print(f"Token budget exhausted ({tokens_used:,}/{self.token_budget:,})")
                    break
                # Fill worker pool up to self.workers in-flight futures.
                while len(futures) < self.workers and not self._shutdown:
                    job = self.queue.claim_next()
                    if not job:
                        break
                    # Check response cache (zero-token retries) — see the
                    # NOTE in __init__: currently this can never hit.
                    job_key = job["job_key"]
                    payload = json.loads(job["payload"])
                    cache_key = hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()
                    if cache_key in self._response_cache:
                        result = self._response_cache[cache_key]
                        self.queue.complete(job_key, result, tokens_used=0)
                        skipped += 1
                        continue
                    # Submit to worker
                    future = executor.submit(self._process_job, job, job_handler)
                    futures[future] = job
                    # Rate limiting between submissions
                    time.sleep(self._rate_limit_delay)
                if not futures:
                    break
                # Collect results.
                # NOTE(review): as_completed(..., timeout=1) raises
                # concurrent.futures.TimeoutError if any outstanding future
                # is still running when the 1s window elapses, and that
                # exception is not caught here — a job slower than ~1s
                # would abort the whole run out of this loop. Confirm
                # whether this should be wait(..., timeout=1) or a caught
                # TimeoutError that simply continues the outer loop.
                done = []
                for future in as_completed(futures, timeout=1):
                    job = futures[future]
                    try:
                        result = future.result()
                        if result.get("success"):
                            tokens = result.get("tokens_used", 0)
                            tokens_used += tokens
                            self.queue.complete(job["job_key"], result, tokens_used=tokens)
                            completed += 1
                        else:
                            error = result.get("error", "unknown error")
                            self.queue.fail(job["job_key"], error, retry=True)
                            failed += 1
                    except Exception as e:
                        self.queue.fail(job["job_key"], str(e), retry=True)
                        failed += 1
                    done.append(future)
                    # Progress update every 10th finished job.
                    total = completed + failed + skipped
                    if total % 10 == 0:
                        elapsed = time.time() - start
                        rate = completed / (elapsed / 60) if elapsed > 0 else 0
                        print(f" Progress: {total}/{self._stats.total_jobs} | "
                              f"completed={completed} failed={failed} | "
                              f"tokens={tokens_used:,} | "
                              f"{rate:.1f}/min")
                # Remove settled futures outside the iteration.
                for f in done:
                    del futures[f]
        # Final report
        elapsed = time.time() - start
        self._stats.completed = completed
        self._stats.failed = failed
        self._stats.skipped = skipped
        self._stats.tokens_used = tokens_used
        self._stats.elapsed_seconds = elapsed
        self._stats.jobs_per_minute = completed / (elapsed / 60) if elapsed > 0 else 0
        # Persist run summary
        self.conn.execute(
            """UPDATE pipeline_runs SET completed_at=?, total_jobs=?, completed=?,
            failed=?, tokens_used=?, report=? WHERE id=?""",
            (datetime.now(timezone.utc).isoformat(), self._stats.total_jobs,
             completed, failed, tokens_used, json.dumps(self._stats.to_dict()), run_id),
        )
        self.conn.commit()
        # Print report
        print(f"\n{'='*50}")
        print(f"Pipeline: {self.pipeline}")
        print(f"Completed: {completed}/{self._stats.total_jobs}")
        print(f"Failed: {failed}")
        print(f"Skipped (cached): {skipped}")
        print(f"Tokens: {tokens_used:,}/{self.token_budget:,}")
        print(f"Time: {elapsed:.1f}s ({self._stats.jobs_per_minute:.1f}/min)")
        print(f"{'='*50}")
        # Save report file
        self._save_report()

    def _process_job(self, job: dict, handler: Callable = None) -> dict:
        """Process a single job in a worker thread.

        Decodes the payload, loads any saved checkpoint, and delegates to
        *handler*. Handler exceptions are converted to a failure dict so
        the run loop can route them through the retry machinery.
        """
        payload = json.loads(job["payload"])
        job_key = job["job_key"]
        checkpoint = self.queue.get_checkpoint(job_key)
        if handler:
            try:
                result = handler(payload, checkpoint=checkpoint)
                return result or {"success": True, "tokens_used": 0}
            except Exception as e:
                return {"success": False, "error": str(e)}
        else:
            # Default handler: just mark as complete
            return {"success": True, "tokens_used": 0}

    def _save_report(self):
        """Write the run's stats as a timestamped JSON file under REPORT_DIR."""
        REPORT_DIR.mkdir(parents=True, exist_ok=True)
        # Local (naive) time is used for the filename only.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        path = REPORT_DIR / f"{self.pipeline}_{ts}.json"
        with open(path, "w") as f:
            json.dump(self._stats.to_dict(), f, indent=2)
        print(f"Report: {path}")

    def resume(self):
        """Resume failed/retrying jobs from a previous run.

        Failed jobs get their attempt counter reset to 0; retrying jobs
        keep their attempts and simply become claimable again.
        """
        stats = self.queue.stats()
        retrying = stats.get("retrying", 0)
        failed = stats.get("failed", 0)
        print(f"Resume {self.pipeline}: {retrying} retrying, {failed} failed to reset")
        # Reset failed jobs to pending for retry
        self.conn.execute(
            "UPDATE jobs SET status='pending', attempts=0 WHERE pipeline=? AND status='failed'",
            (self.pipeline,),
        )
        self.conn.execute(
            "UPDATE jobs SET status='pending' WHERE pipeline=? AND status='retrying'",
            (self.pipeline,),
        )
        self.conn.commit()

    def status(self):
        """Print this pipeline's per-status job counts."""
        stats = self.queue.stats()
        print(f"\nPipeline: {self.pipeline}")
        for k, v in sorted(stats.items()):
            print(f" {k}: {v}")
# ============================================================
# CLI
# ============================================================
def show_all_status():
    """Print a one-line status summary for every pipeline in the database."""
    conn = get_db()
    pipelines = conn.execute(
        "SELECT DISTINCT pipeline FROM jobs ORDER BY pipeline"
    ).fetchall()
    if not pipelines:
        print("No pipelines in database.")
        return
    print(f"\nAll Pipeline Status")
    print(f"{'='*60}")
    for (pipeline,) in pipelines:
        stats = JobQueue(pipeline, conn).stats()
        total, pending, running, completed, failed = (
            stats.get(key, 0)
            for key in ("total", "pending", "running", "completed", "failed")
        )
        tokens = stats.get("tokens_used", 0)
        print(f" {pipeline:25} total={total:4} pending={pending:3} running={running:2} "
              f"completed={completed:4} failed={failed:3} tokens={tokens:,}")
def main():
    """CLI entry point: parse args, then dispatch status/load/resume/run."""
    import argparse
    parser = argparse.ArgumentParser(description="Shared Pipeline Orchestrator")
    parser.add_argument("--pipeline", "-p", help="Pipeline name")
    parser.add_argument("--jobs", "-j", help="Jobs JSONL file to load")
    parser.add_argument("--workers", "-w", type=int, default=10, help="Parallel workers")
    parser.add_argument("--budget", "-b", type=int, default=5_000_000, help="Token budget")
    parser.add_argument("--status", action="store_true", help="Show status")
    parser.add_argument("--resume", action="store_true", help="Resume failed jobs")
    parser.add_argument("--key-field", default="id", help="Job key field name")
    opts = parser.parse_args()

    # Status is read-only: one pipeline if named, otherwise all of them.
    if opts.status:
        if not opts.pipeline:
            show_all_status()
        else:
            Orchestrator(opts.pipeline).status()
        return

    if not opts.pipeline:
        parser.error("--pipeline is required")

    orch = Orchestrator(opts.pipeline, workers=opts.workers, token_budget=opts.budget)
    if opts.jobs:
        orch.load_jobs(opts.jobs, key_field=opts.key_field)
    if opts.resume:
        orch.resume()
    # Only execute when there is something new to do.
    if opts.jobs or opts.resume:
        orch.run()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,266 @@
#!/usr/bin/env python3
"""
Training Pair Provenance Tracker — Timmy Foundation
Adds, filters, and reports provenance metadata for JSONL training pairs.
Tracks source_session_id, model, and timestamp for quality auditing.
Usage:
# Tag pairs with provenance
python3 scripts/training_provenance.py tag input.jsonl -o tagged.jsonl \
--session abc123 --model nous/hermes-3
# Filter by model (exclude Anthropic-sourced)
python3 scripts/training_provenance.py filter input.jsonl -o filtered.jsonl \
--exclude-model anthropic
# Report: pair count by source model
python3 scripts/training_provenance.py report input.jsonl
# Pipe support
cat pairs.jsonl | python3 scripts/training_provenance.py report -
"""
import sys
import json
import argparse
from datetime import datetime, timezone
from collections import Counter
from typing import Dict, Any, Optional, List, TextIO
PROVENANCE_KEYS = ["source_session_id", "source_model", "source_timestamp"]
def tag_pair(pair: Dict[str, Any], session_id: Optional[str] = None,
             model: Optional[str] = None) -> Dict[str, Any]:
    """Stamp a training pair with provenance metadata.

    Existing provenance keys are preserved unless overridden; the
    timestamp is always refreshed. The pair is mutated in place and
    returned for chaining.
    """
    provenance = {**pair.get("_provenance", {})}
    if session_id:
        provenance["source_session_id"] = session_id
    if model:
        provenance["source_model"] = model
    provenance["source_timestamp"] = datetime.now(timezone.utc).isoformat()
    pair["_provenance"] = provenance
    return pair
def _open_input(path: str) -> TextIO:
    """Return stdin for "-", otherwise *path* opened for reading (UTF-8)."""
    if path == "-":
        return sys.stdin
    return open(path, "r", encoding="utf-8")
def _open_output(path: str) -> TextIO:
    """Return stdout for "-", otherwise *path* opened for writing (UTF-8)."""
    if path == "-":
        return sys.stdout
    return open(path, "w", encoding="utf-8")
def stamp_command(input_path: str, output_path: str,
                  session_id: Optional[str], model: Optional[str]) -> Dict[str, Any]:
    """Tag every pair in a JSONL stream with provenance metadata.

    Pairs already carrying the same model+session are passed through
    verbatim (counted as skipped); malformed lines are dropped and
    counted as errors.

    Returns:
        dict with 'tagged', 'skipped', and 'errors' counts.
    """
    tagged = 0
    skipped = 0
    errors = 0
    source = _open_input(input_path)
    out = _open_output(output_path)
    try:
        for line in source:
            line = line.strip()
            if not line:
                continue
            try:
                pair = json.loads(line)
            except json.JSONDecodeError:
                errors += 1
                continue
            # Skip if already tagged with same model+session
            existing = pair.get("_provenance", {})
            if (existing.get("source_model") == model
                    and existing.get("source_session_id") == session_id):
                skipped += 1
                out.write(line + "\n")
                continue
            pair = tag_pair(pair, session_id=session_id, model=model)
            out.write(json.dumps(pair, ensure_ascii=False) + "\n")
            tagged += 1
    finally:
        if source is not sys.stdin:
            source.close()
        # FIX: real output files must be closed so buffered lines are
        # flushed to disk (the handle was previously leaked). stdout is
        # deliberately left open — closing it breaks downstream piping.
        if out is not sys.stdout:
            out.close()
    return {"tagged": tagged, "skipped": skipped, "errors": errors}
def filter_pairs(input_path: str, output_path: str,
                 include_models: Optional[List[str]] = None,
                 exclude_models: Optional[List[str]] = None) -> Dict[str, Any]:
    """Filter pairs by provenance source_model.

    Matching is by substring, so ``--exclude-model anthropic`` drops
    "anthropic/claude" as the module usage documents. (FIX: the previous
    exact list-membership test could never exclude "anthropic/claude"
    given the pattern "anthropic".) Pairs without provenance match as
    model "unknown". A pair is kept when it matches *include_models* (if
    given) and matches no entry of *exclude_models* (if given).

    Returns:
        dict with 'total', 'kept', 'filtered_out', and 'errors' counts.
    """
    kept: List[dict] = []
    removed = 0  # only the count is needed; don't hold filtered pairs in memory
    errors = 0
    source = _open_input(input_path)
    try:
        for line in source:
            line = line.strip()
            if not line:
                continue
            try:
                pair = json.loads(line)
            except json.JSONDecodeError:
                errors += 1
                continue
            prov = pair.get("_provenance", {})
            model = prov.get("source_model", "unknown")
            should_keep = True
            if include_models:
                should_keep = any(pat in model for pat in include_models)
            if exclude_models:
                should_keep = should_keep and not any(pat in model for pat in exclude_models)
            if should_keep:
                kept.append(pair)
            else:
                removed += 1
    finally:
        if source is not sys.stdin:
            source.close()
    # Two-phase write: input is fully consumed before output is opened.
    if output_path:
        out = _open_output(output_path)
        try:
            for pair in kept:
                out.write(json.dumps(pair, ensure_ascii=False) + "\n")
        finally:
            if out is not sys.stdout:
                out.close()
    return {
        "total": len(kept) + removed,
        "kept": len(kept),
        "filtered_out": removed,
        "errors": errors,
    }
def report(input_path: str) -> Dict[str, Any]:
    """Summarize provenance coverage: counts by source model and session."""
    model_counts: Counter = Counter()
    session_counts: Counter = Counter()
    total = 0
    tagged = 0
    untagged = 0
    errors = 0
    source = _open_input(input_path)
    try:
        for raw in source:
            raw = raw.strip()
            if not raw:
                continue
            try:
                pair = json.loads(raw)
            except json.JSONDecodeError:
                errors += 1
                continue
            total += 1
            prov = pair.get("_provenance", {})
            if not prov:
                untagged += 1
                continue
            tagged += 1
            model_counts[prov.get("source_model", "unknown")] += 1
            session_counts[prov.get("source_session_id", "unknown")] += 1
    finally:
        if source is not sys.stdin:
            source.close()
    return {
        "total": total,
        "tagged": tagged,
        "untagged": untagged,
        "tag_rate": round(tagged / max(total, 1) * 100, 1),
        "by_model": dict(model_counts.most_common(20)),
        "by_session": dict(session_counts.most_common(10)),
        "errors": errors,
    }
def main():
    """CLI entry point: dispatch the tag / filter / report subcommands."""
    parser = argparse.ArgumentParser(description="Training pair provenance tracking")
    commands = parser.add_subparsers(dest="command", required=True)

    sub = commands.add_parser("tag", help="Tag pairs with provenance metadata")
    sub.add_argument("input", help="Input JSONL file (use - for stdin)")
    sub.add_argument("-o", "--output", default="-", help="Output JSONL file")
    sub.add_argument("--session", help="Source session ID")
    sub.add_argument("--model", help="Source model name")

    sub = commands.add_parser("filter", help="Filter pairs by provenance")
    sub.add_argument("input", help="Input JSONL file (use - for stdin)")
    sub.add_argument("-o", "--output", default="-", help="Output JSONL file")
    sub.add_argument("--include-model", action="append", help="Only include these models")
    sub.add_argument("--exclude-model", action="append", help="Exclude these models")

    sub = commands.add_parser("report", help="Report provenance statistics")
    sub.add_argument("input", help="Input JSONL file (use - for stdin)")

    args = parser.parse_args()

    # Human-readable summaries go to stderr; stdout stays machine-readable.
    if args.command == "tag":
        result = stamp_command(args.input, args.output, args.session, args.model)
        print(f"Tagged: {result['tagged']} Skipped: {result['skipped']} Errors: {result['errors']}", file=sys.stderr)
    elif args.command == "filter":
        result = filter_pairs(
            args.input, args.output,
            include_models=args.include_model,
            exclude_models=args.exclude_model,
        )
        print(f"Total: {result['total']} Kept: {result['kept']} Filtered: {result['filtered_out']}", file=sys.stderr)
    elif args.command == "report":
        result = report(args.input)
        print(f"Training Pair Provenance Report", file=sys.stderr)
        print(f"{'='*40}", file=sys.stderr)
        print(f"Total pairs: {result['total']}", file=sys.stderr)
        print(f"Tagged: {result['tagged']} ({result['tag_rate']}%)", file=sys.stderr)
        print(f"Untagged: {result['untagged']}", file=sys.stderr)
        if result["by_model"]:
            print(f"\nBy source model:", file=sys.stderr)
            for model, count in result["by_model"].items():
                print(f" {model}: {count}", file=sys.stderr)
        if result["by_session"]:
            print(f"\nBy source session (top 10):", file=sys.stderr)
            for session, count in result["by_session"].items():
                session_short = session[:12] + "..." if len(session) > 12 else session
                print(f" {session_short}: {count}", file=sys.stderr)
        # JSON summary to stdout so it can be piped onward.
        print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,363 @@
"""Tests for training pair provenance tracking."""
import json
import os
import subprocess
import sys
import tempfile
import pytest
SCRIPT = os.path.join(os.path.dirname(__file__), "..", "scripts", "training_provenance.py")
def _run(args, stdin=None):
    """Run training_provenance.py and return (stdout, stderr, returncode)."""
    proc = subprocess.run(
        [sys.executable, SCRIPT, *args],
        capture_output=True,
        text=True,
        input=stdin,
    )
    return proc.stdout, proc.stderr, proc.returncode
def _make_pairs(count=3, model="nous/hermes-3", session="sess-123"):
    """Generate *count* untagged training pairs as newline-joined JSONL.

    NOTE(review): *model* and *session* are accepted for signature
    compatibility but currently unused — pairs are deliberately produced
    without ``_provenance`` so the tag command has work to do. Confirm
    whether callers were ever meant to get pre-tagged fixtures.
    """
    return "\n".join(
        json.dumps({"terse": f"q{i}", "rich": f"a{i}", "domain": "test"})
        for i in range(count)
    )
# ── tag command ──────────────────────────────────────────────────
class TestTagCommand:
    """CLI-level tests for the `tag` subcommand.

    Each test writes a temp JSONL input (delete=False so the subprocess can
    reopen it after this process closes the handle), runs the script via
    _run, and removes both input and output files in `finally`.
    """

    def test_tag_adds_provenance_to_each_pair(self):
        # Three untagged pairs in; all three must gain session, model, and
        # a timestamp.
        pairs = _make_pairs(3)
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".tagged"
        try:
            out, err, rc = _run(["tag", f.name, "-o", out_path,
                                 "--session", "sess-abc", "--model", "nous/hermes-3"])
            assert rc == 0
            with open(out_path) as f:
                lines = [json.loads(l) for l in f if l.strip()]
            assert len(lines) == 3
            for pair in lines:
                prov = pair["_provenance"]
                assert prov["source_session_id"] == "sess-abc"
                assert prov["source_model"] == "nous/hermes-3"
                assert "source_timestamp" in prov
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_preserves_existing_pair_data(self):
        # Tagging must be additive: original keys survive alongside
        # the new _provenance object.
        pairs = '{"terse": "hello", "rich": "world", "domain": "greeting"}\n'
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".tagged"
        try:
            _run(["tag", f.name, "-o", out_path, "--model", "m1"])
            with open(out_path) as f:
                pair = json.loads(f.readline())
            assert pair["terse"] == "hello"
            assert pair["rich"] == "world"
            assert pair["domain"] == "greeting"
            assert pair["_provenance"]["source_model"] == "m1"
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_skips_already_tagged_same_provenance(self):
        # Identical model+session means the pair is passed through
        # unchanged and counted as skipped (dedup behavior).
        pair = json.dumps({
            "terse": "q", "rich": "a",
            "_provenance": {"source_model": "m1", "source_session_id": "s1"}
        })
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pair + "\n")
            f.flush()
            out_path = f.name + ".tagged"
        try:
            _, err, rc = _run(["tag", f.name, "-o", out_path,
                               "--session", "s1", "--model", "m1"])
            assert rc == 0
            # The skip count is reported on stderr.
            assert "Skipped: 1" in err
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_overwrites_different_provenance(self):
        # A different model+session re-tags the pair rather than skipping.
        pair = json.dumps({
            "terse": "q", "rich": "a",
            "_provenance": {"source_model": "old-model", "source_session_id": "old-sess"}
        })
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pair + "\n")
            f.flush()
            out_path = f.name + ".tagged"
        try:
            _, err, rc = _run(["tag", f.name, "-o", out_path,
                               "--session", "new-sess", "--model", "new-model"])
            assert rc == 0
            assert "Tagged: 1" in err
            with open(out_path) as f:
                tagged = json.loads(f.readline())
            assert tagged["_provenance"]["source_model"] == "new-model"
            assert tagged["_provenance"]["source_session_id"] == "new-sess"
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_skips_blank_lines(self):
        # Blank lines are ignored silently: neither tagged nor errors.
        pairs = '{"t":"a","r":"b"}\n\n{"t":"c","r":"d"}\n'
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".tagged"
        try:
            _, err, rc = _run(["tag", f.name, "-o", out_path, "--model", "m1"])
            assert rc == 0
            assert "Tagged: 2" in err
            assert "Errors: 0" in err
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_counts_malformed_lines_as_errors(self):
        # Invalid JSON lines are dropped and surfaced in the error count;
        # the command still exits 0 (best-effort semantics).
        pairs = '{"t":"a"}\nNOT_JSON\n{"t":"b"}\n'
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".tagged"
        try:
            _, err, rc = _run(["tag", f.name, "-o", out_path, "--model", "m1"])
            assert rc == 0
            assert "Errors: 1" in err
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)
# ── filter command ───────────────────────────────────────────────
class TestFilterCommand:
    """CLI-level tests for the `filter` subcommand."""

    def test_filter_exclude_model(self):
        # One pair per source model; excluding "anthropic" must drop
        # exactly the anthropic/claude pair and keep the other two.
        lines = []
        for model in ["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]:
            lines.append(json.dumps({
                "t": "q", "r": "a",
                "_provenance": {"source_model": model}
            }))
        pairs = "\n".join(lines)
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".filtered"
        try:
            _, err, rc = _run(["filter", f.name, "-o", out_path,
                               "--exclude-model", "anthropic"])
            assert rc == 0
            assert "Kept: 2" in err
            assert "Filtered: 1" in err
            with open(out_path) as f:
                kept = [json.loads(l) for l in f if l.strip()]
            models = [p["_provenance"]["source_model"] for p in kept]
            assert "anthropic/claude" not in models
            assert "nous/hermes-3" in models
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_filter_include_model(self):
        # An include-list keeps only pairs whose model matches it.
        lines = []
        for model in ["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]:
            lines.append(json.dumps({
                "t": "q", "r": "a",
                "_provenance": {"source_model": model}
            }))
        pairs = "\n".join(lines)
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".filtered"
        try:
            _, err, rc = _run(["filter", f.name, "-o", out_path,
                               "--include-model", "nous/hermes-3"])
            assert rc == 0
            assert "Kept: 1" in err
            with open(out_path) as f:
                kept = [json.loads(l) for l in f if l.strip()]
            assert len(kept) == 1
            assert kept[0]["_provenance"]["source_model"] == "nous/hermes-3"
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_filter_untreated_pairs_have_unknown_model(self):
        # NOTE(review): "untreated" in the name likely means "untagged".
        # Pairs lacking _provenance fall back to model "unknown", so
        # excluding "unknown" filters them out.
        pairs = '{"t":"q","r":"a"}\n'  # no _provenance
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".filtered"
        try:
            # Exclude "unknown" — should filter out unprovenanced pair
            _, err, rc = _run(["filter", f.name, "-o", out_path,
                               "--exclude-model", "unknown"])
            assert rc == 0
            assert "Kept: 0" in err
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)
# ── report command ───────────────────────────────────────────────
class TestReportCommand:
    """CLI-level tests for the `report` subcommand.

    The JSON summary is parsed from stdout; the human-readable report goes
    to stderr and is not asserted on here.
    """

    def test_report_counts_by_model(self):
        # Two hermes pairs + one claude pair → by_model reflects the split.
        lines = []
        for model in ["nous/hermes-3", "nous/hermes-3", "anthropic/claude"]:
            lines.append(json.dumps({
                "t": "q", "r": "a",
                "_provenance": {"source_model": model, "source_session_id": "s1"}
            }))
        pairs = "\n".join(lines)
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
        try:
            out, err, rc = _run(["report", f.name])
            assert rc == 0
            result = json.loads(out)
            assert result["total"] == 3
            assert result["tagged"] == 3
            assert result["untagged"] == 0
            assert result["tag_rate"] == 100.0
            assert result["by_model"]["nous/hermes-3"] == 2
            assert result["by_model"]["anthropic/claude"] == 1
        finally:
            os.unlink(f.name)

    def test_report_distinguishes_tagged_vs_untagged(self):
        # One pair with provenance, one without → 50% tag rate.
        pairs = '{"t":"q","r":"a"}\n{"t":"q2","r":"a2","_provenance":{"source_model":"m1"}}\n'
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
        try:
            out, err, rc = _run(["report", f.name])
            assert rc == 0
            result = json.loads(out)
            assert result["total"] == 2
            assert result["tagged"] == 1
            assert result["untagged"] == 1
            assert result["tag_rate"] == 50.0
        finally:
            os.unlink(f.name)

    def test_report_handles_empty_file(self):
        # Empty input must not divide by zero: tag_rate comes back 0.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write("")
            f.flush()
        try:
            out, err, rc = _run(["report", f.name])
            assert rc == 0
            result = json.loads(out)
            assert result["total"] == 0
            assert result["tag_rate"] == 0
        finally:
            os.unlink(f.name)

    def test_report_counts_by_session(self):
        # Session histogram mirrors the model histogram logic.
        lines = []
        for sess in ["sess-a", "sess-a", "sess-b"]:
            lines.append(json.dumps({
                "t": "q", "r": "a",
                "_provenance": {"source_model": "m1", "source_session_id": sess}
            }))
        pairs = "\n".join(lines)
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
        try:
            out, _, rc = _run(["report", f.name])
            assert rc == 0
            result = json.loads(out)
            assert result["by_session"]["sess-a"] == 2
            assert result["by_session"]["sess-b"] == 1
        finally:
            os.unlink(f.name)
# ── integration ──────────────────────────────────────────────────
class TestIntegration:
    """End-to-end coverage chaining all three subcommands."""

    def test_tag_then_filter_then_report(self):
        """Full pipeline: tag → filter → report.

        FIX: the original built pairs by looping over a list of model
        names it never used; the pairs are untagged on purpose, so the
        dead loop variable is gone.
        """
        pairs = "\n".join(
            json.dumps({"terse": f"q{i}", "rich": f"a{i}", "domain": "test"})
            for i in range(3)
        )
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as src:
            src.write(pairs)
            src.flush()
            tagged_path = src.name + ".tagged"
            filtered_path = src.name + ".filtered"
        try:
            # Step 1: Tag all with session info
            _, _, rc = _run(["tag", src.name, "-o", tagged_path,
                             "--session", "pipe-1", "--model", "nous/hermes-3"])
            assert rc == 0
            # Step 2: Filter — exclude "unknown" model (untagged pairs)
            _, err2, rc = _run(["filter", tagged_path, "-o", filtered_path,
                                "--exclude-model", "unknown"])
            assert rc == 0
            assert "Kept: 3" in err2
            # Step 3: Report
            out, _, rc = _run(["report", filtered_path])
            assert rc == 0
            result = json.loads(out)
            assert result["total"] == 3
            assert result["tagged"] == 3
            assert result["tag_rate"] == 100.0
        finally:
            for p in [src.name, tagged_path, filtered_path]:
                if os.path.exists(p):
                    os.unlink(p)