Compare commits
2 Commits
fix/621
...
burn/691-1
| Author | SHA1 | Date | |
|---|---|---|---|
| ae063a8c71 | |||
| 41afbc2ca9 |
@@ -1,568 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
orchestrator.py — Shared Pipeline Orchestrator
|
||||
|
||||
SQLite-backed job queue with parallel workers, token budget tracking,
|
||||
checkpoint resume, rate limiting, and error retry.
|
||||
|
||||
All 5 pipelines use this orchestrator for consistent execution.
|
||||
|
||||
Usage:
|
||||
python3 orchestrator.py --pipeline training_factory --jobs jobs.jsonl
|
||||
python3 orchestrator.py --pipeline adversary --jobs jobs.jsonl --workers 5
|
||||
python3 orchestrator.py --status
|
||||
python3 orchestrator.py --resume training_factory
|
||||
python3 orchestrator.py --report training_factory
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import sqlite3
|
||||
import hashlib
|
||||
import threading
|
||||
import signal
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict, Any, Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
DB_PATH = Path.home() / ".hermes" / "pipeline" / "orchestrator.db"
|
||||
REPORT_DIR = Path.home() / ".hermes" / "pipeline" / "reports"
|
||||
|
||||
# ============================================================
|
||||
# Data Structures
|
||||
# ============================================================
|
||||
|
||||
@dataclass
|
||||
class JobStatus:
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
RETRYING = "retrying"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineStats:
|
||||
"""Runtime statistics for a pipeline run."""
|
||||
pipeline: str
|
||||
total_jobs: int = 0
|
||||
completed: int = 0
|
||||
failed: int = 0
|
||||
skipped: int = 0
|
||||
tokens_used: int = 0
|
||||
tokens_budget: int = 5_000_000
|
||||
elapsed_seconds: float = 0.0
|
||||
start_time: str = ""
|
||||
jobs_per_minute: float = 0.0
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"pipeline": self.pipeline,
|
||||
"total_jobs": self.total_jobs,
|
||||
"completed": self.completed,
|
||||
"failed": self.failed,
|
||||
"skipped": self.skipped,
|
||||
"tokens_used": self.tokens_used,
|
||||
"tokens_budget": self.tokens_budget,
|
||||
"elapsed_seconds": round(self.elapsed_seconds, 1),
|
||||
"start_time": self.start_time,
|
||||
"jobs_per_minute": round(self.jobs_per_minute, 2),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Database
|
||||
# ============================================================
|
||||
|
||||
def get_db():
|
||||
"""Get SQLite database connection."""
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH), timeout=30, check_same_thread=False)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
_init_db(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def _init_db(conn):
|
||||
"""Initialize database schema."""
|
||||
conn.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
pipeline TEXT NOT NULL,
|
||||
job_key TEXT NOT NULL,
|
||||
payload TEXT NOT NULL,
|
||||
status TEXT DEFAULT 'pending',
|
||||
attempts INTEGER DEFAULT 0,
|
||||
max_attempts INTEGER DEFAULT 3,
|
||||
tokens_used INTEGER DEFAULT 0,
|
||||
error TEXT,
|
||||
result TEXT,
|
||||
checkpoint TEXT,
|
||||
created_at TEXT DEFAULT (datetime('now')),
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
UNIQUE(pipeline, job_key)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_status ON jobs(pipeline, status);
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_pipeline_key ON jobs(pipeline, job_key);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pipeline_runs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
pipeline TEXT NOT NULL,
|
||||
started_at TEXT DEFAULT (datetime('now')),
|
||||
completed_at TEXT,
|
||||
total_jobs INTEGER DEFAULT 0,
|
||||
completed INTEGER DEFAULT 0,
|
||||
failed INTEGER DEFAULT 0,
|
||||
tokens_used INTEGER DEFAULT 0,
|
||||
report TEXT
|
||||
);
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Job Queue
|
||||
# ============================================================
|
||||
|
||||
class JobQueue:
|
||||
"""SQLite-backed job queue."""
|
||||
|
||||
def __init__(self, pipeline: str, conn=None):
|
||||
self.pipeline = pipeline
|
||||
self.conn = conn or get_db()
|
||||
|
||||
def enqueue(self, job_key: str, payload: dict, max_attempts: int = 3):
|
||||
"""Add a job to the queue (skip if already exists)."""
|
||||
try:
|
||||
self.conn.execute(
|
||||
"INSERT INTO jobs (pipeline, job_key, payload, max_attempts) VALUES (?, ?, ?, ?)",
|
||||
(self.pipeline, job_key, json.dumps(payload), max_attempts),
|
||||
)
|
||||
self.conn.commit()
|
||||
return True
|
||||
except sqlite3.IntegrityError:
|
||||
# Already exists — check if it needs retry
|
||||
row = self.conn.execute(
|
||||
"SELECT status FROM jobs WHERE pipeline=? AND job_key=?",
|
||||
(self.pipeline, job_key),
|
||||
).fetchone()
|
||||
if row and row[0] == "failed":
|
||||
# Reset for retry
|
||||
self.conn.execute(
|
||||
"UPDATE jobs SET status='pending', attempts=0, error=NULL WHERE pipeline=? AND job_key=?",
|
||||
(self.pipeline, job_key),
|
||||
)
|
||||
self.conn.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def enqueue_batch(self, jobs: List[dict], key_field: str = "id"):
|
||||
"""Enqueue multiple jobs. Returns (added, skipped) counts."""
|
||||
added = 0
|
||||
skipped = 0
|
||||
for job in jobs:
|
||||
key = str(job.get(key_field, hashlib.md5(json.dumps(job).encode()).hexdigest()[:12]))
|
||||
if self.enqueue(key, job):
|
||||
added += 1
|
||||
else:
|
||||
skipped += 1
|
||||
return added, skipped
|
||||
|
||||
def claim_next(self) -> Optional[dict]:
|
||||
"""Claim the next pending job (atomic)."""
|
||||
row = self.conn.execute(
|
||||
"""UPDATE jobs SET status='running', started_at=datetime('now')
|
||||
WHERE id = (
|
||||
SELECT id FROM jobs WHERE pipeline=? AND status IN ('pending', 'retrying')
|
||||
ORDER BY attempts ASC, created_at ASC LIMIT 1
|
||||
) RETURNING *""",
|
||||
(self.pipeline,),
|
||||
).fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
cols = [d[1] for d in self.conn.execute("PRAGMA table_info(jobs)").fetchall()]
|
||||
return dict(zip(cols, row))
|
||||
|
||||
def complete(self, job_key: str, result: dict, tokens_used: int = 0):
|
||||
"""Mark a job as completed."""
|
||||
self.conn.execute(
|
||||
"""UPDATE jobs SET status='completed', completed_at=datetime('now'),
|
||||
result=?, tokens_used=? WHERE pipeline=? AND job_key=?""",
|
||||
(json.dumps(result), tokens_used, self.pipeline, job_key),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def fail(self, job_key: str, error: str, retry: bool = True):
|
||||
"""Mark a job as failed, optionally retry."""
|
||||
row = self.conn.execute(
|
||||
"SELECT attempts, max_attempts FROM jobs WHERE pipeline=? AND job_key=?",
|
||||
(self.pipeline, job_key),
|
||||
).fetchone()
|
||||
|
||||
if not row:
|
||||
return
|
||||
|
||||
attempts, max_attempts = row
|
||||
new_attempts = attempts + 1
|
||||
|
||||
if retry and new_attempts < max_attempts:
|
||||
# Exponential backoff: 2^attempts seconds
|
||||
delay = min(2 ** new_attempts, 60)
|
||||
self.conn.execute(
|
||||
"""UPDATE jobs SET status='retrying', attempts=?, error=?
|
||||
WHERE pipeline=? AND job_key=?""",
|
||||
(new_attempts, error, self.pipeline, job_key),
|
||||
)
|
||||
else:
|
||||
self.conn.execute(
|
||||
"""UPDATE jobs SET status='failed', attempts=?, error=?,
|
||||
completed_at=datetime('now') WHERE pipeline=? AND job_key=?""",
|
||||
(new_attempts, error, self.pipeline, job_key),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def save_checkpoint(self, job_key: str, checkpoint: dict):
|
||||
"""Save progress checkpoint for resume."""
|
||||
self.conn.execute(
|
||||
"UPDATE jobs SET checkpoint=? WHERE pipeline=? AND job_key=?",
|
||||
(json.dumps(checkpoint), self.pipeline, job_key),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def get_checkpoint(self, job_key: str) -> Optional[dict]:
|
||||
"""Get saved checkpoint."""
|
||||
row = self.conn.execute(
|
||||
"SELECT checkpoint FROM jobs WHERE pipeline=? AND job_key=?",
|
||||
(self.pipeline, job_key),
|
||||
).fetchone()
|
||||
if row and row[0]:
|
||||
return json.loads(row[0])
|
||||
return None
|
||||
|
||||
def stats(self) -> dict:
|
||||
"""Get queue statistics."""
|
||||
rows = self.conn.execute(
|
||||
"""SELECT status, COUNT(*), COALESCE(SUM(tokens_used), 0)
|
||||
FROM jobs WHERE pipeline=? GROUP BY status""",
|
||||
(self.pipeline,),
|
||||
).fetchall()
|
||||
|
||||
result = {"total": 0, "tokens_used": 0}
|
||||
for status, count, tokens in rows:
|
||||
result[status] = count
|
||||
result["total"] += count
|
||||
result["tokens_used"] += tokens
|
||||
return result
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Orchestrator
|
||||
# ============================================================
|
||||
|
||||
class Orchestrator:
|
||||
"""
|
||||
Shared orchestrator for all pipelines.
|
||||
|
||||
Features:
|
||||
- Parallel worker pool (configurable)
|
||||
- Token budget tracking
|
||||
- Checkpoint resume
|
||||
- Rate limiting
|
||||
- Error retry with exponential backoff
|
||||
- Final report generation
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline: str, workers: int = 10, token_budget: int = 5_000_000):
|
||||
self.pipeline = pipeline
|
||||
self.workers = workers
|
||||
self.token_budget = token_budget
|
||||
self.queue = JobQueue(pipeline)
|
||||
self.conn = self.queue.conn
|
||||
self._shutdown = False
|
||||
self._stats = PipelineStats(pipeline=pipeline, tokens_budget=token_budget)
|
||||
self._rate_limit_delay = 0.1 # seconds between jobs
|
||||
self._response_cache: Dict[str, dict] = {}
|
||||
|
||||
signal.signal(signal.SIGINT, self._handle_signal)
|
||||
signal.signal(signal.SIGTERM, self._handle_signal)
|
||||
|
||||
def _handle_signal(self, signum, frame):
|
||||
"""Graceful shutdown on signal."""
|
||||
print(f"\nReceived signal {signum}. Shutting down gracefully...")
|
||||
self._shutdown = True
|
||||
|
||||
def load_jobs(self, jobs_path: str, key_field: str = "id"):
|
||||
"""Load jobs from a JSONL file into the queue."""
|
||||
jobs = []
|
||||
with open(jobs_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
jobs.append(json.loads(line))
|
||||
|
||||
added, skipped = self.queue.enqueue_batch(jobs, key_field)
|
||||
print(f"Loaded: {added} new, {skipped} existing")
|
||||
|
||||
def run(self, job_handler: Callable[[dict], dict] = None):
|
||||
"""
|
||||
Run the orchestrator. Processes all pending jobs with parallel workers.
|
||||
|
||||
Args:
|
||||
job_handler: function(job_payload) -> dict with 'tokens_used' key
|
||||
"""
|
||||
start = time.time()
|
||||
self._stats.start_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Record run
|
||||
self.conn.execute(
|
||||
"INSERT INTO pipeline_runs (pipeline, started_at) VALUES (?, ?)",
|
||||
(self.pipeline, self._stats.start_time),
|
||||
)
|
||||
run_id = self.conn.execute("SELECT last_insert_rowid()").fetchone()[0]
|
||||
self.conn.commit()
|
||||
|
||||
stats = self.queue.stats()
|
||||
self._stats.total_jobs = stats.get("pending", 0) + stats.get("retrying", 0)
|
||||
print(f"\nPipeline: {self.pipeline}")
|
||||
print(f"Jobs: {self._stats.total_jobs} pending | Workers: {self.workers} | Budget: {self.token_budget:,} tokens")
|
||||
print()
|
||||
|
||||
if self._stats.total_jobs == 0:
|
||||
print("No jobs to process.")
|
||||
return
|
||||
|
||||
completed = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
tokens_used = 0
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.workers) as executor:
|
||||
futures = {}
|
||||
|
||||
while not self._shutdown:
|
||||
# Check token budget
|
||||
if tokens_used >= self.token_budget:
|
||||
print(f"Token budget exhausted ({tokens_used:,}/{self.token_budget:,})")
|
||||
break
|
||||
|
||||
# Fill worker pool
|
||||
while len(futures) < self.workers and not self._shutdown:
|
||||
job = self.queue.claim_next()
|
||||
if not job:
|
||||
break
|
||||
|
||||
# Check response cache (zero-token retries)
|
||||
job_key = job["job_key"]
|
||||
payload = json.loads(job["payload"])
|
||||
cache_key = hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
if cache_key in self._response_cache:
|
||||
result = self._response_cache[cache_key]
|
||||
self.queue.complete(job_key, result, tokens_used=0)
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Submit to worker
|
||||
future = executor.submit(self._process_job, job, job_handler)
|
||||
futures[future] = job
|
||||
|
||||
# Rate limiting
|
||||
time.sleep(self._rate_limit_delay)
|
||||
|
||||
if not futures:
|
||||
break
|
||||
|
||||
# Collect results
|
||||
done = []
|
||||
for future in as_completed(futures, timeout=1):
|
||||
job = futures[future]
|
||||
try:
|
||||
result = future.result()
|
||||
if result.get("success"):
|
||||
tokens = result.get("tokens_used", 0)
|
||||
tokens_used += tokens
|
||||
self.queue.complete(job["job_key"], result, tokens_used=tokens)
|
||||
completed += 1
|
||||
else:
|
||||
error = result.get("error", "unknown error")
|
||||
self.queue.fail(job["job_key"], error, retry=True)
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
self.queue.fail(job["job_key"], str(e), retry=True)
|
||||
failed += 1
|
||||
|
||||
done.append(future)
|
||||
|
||||
# Progress update
|
||||
total = completed + failed + skipped
|
||||
if total % 10 == 0:
|
||||
elapsed = time.time() - start
|
||||
rate = completed / (elapsed / 60) if elapsed > 0 else 0
|
||||
print(f" Progress: {total}/{self._stats.total_jobs} | "
|
||||
f"completed={completed} failed={failed} | "
|
||||
f"tokens={tokens_used:,} | "
|
||||
f"{rate:.1f}/min")
|
||||
|
||||
for f in done:
|
||||
del futures[f]
|
||||
|
||||
# Final report
|
||||
elapsed = time.time() - start
|
||||
self._stats.completed = completed
|
||||
self._stats.failed = failed
|
||||
self._stats.skipped = skipped
|
||||
self._stats.tokens_used = tokens_used
|
||||
self._stats.elapsed_seconds = elapsed
|
||||
self._stats.jobs_per_minute = completed / (elapsed / 60) if elapsed > 0 else 0
|
||||
|
||||
# Save run
|
||||
self.conn.execute(
|
||||
"""UPDATE pipeline_runs SET completed_at=?, total_jobs=?, completed=?,
|
||||
failed=?, tokens_used=?, report=? WHERE id=?""",
|
||||
(datetime.now(timezone.utc).isoformat(), self._stats.total_jobs,
|
||||
completed, failed, tokens_used, json.dumps(self._stats.to_dict()), run_id),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
# Print report
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Pipeline: {self.pipeline}")
|
||||
print(f"Completed: {completed}/{self._stats.total_jobs}")
|
||||
print(f"Failed: {failed}")
|
||||
print(f"Skipped (cached): {skipped}")
|
||||
print(f"Tokens: {tokens_used:,}/{self.token_budget:,}")
|
||||
print(f"Time: {elapsed:.1f}s ({self._stats.jobs_per_minute:.1f}/min)")
|
||||
print(f"{'='*50}")
|
||||
|
||||
# Save report file
|
||||
self._save_report()
|
||||
|
||||
def _process_job(self, job: dict, handler: Callable = None) -> dict:
|
||||
"""Process a single job."""
|
||||
payload = json.loads(job["payload"])
|
||||
job_key = job["job_key"]
|
||||
checkpoint = self.queue.get_checkpoint(job_key)
|
||||
|
||||
if handler:
|
||||
try:
|
||||
result = handler(payload, checkpoint=checkpoint)
|
||||
return result or {"success": True, "tokens_used": 0}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
else:
|
||||
# Default handler: just mark as complete
|
||||
return {"success": True, "tokens_used": 0}
|
||||
|
||||
def _save_report(self):
|
||||
"""Save pipeline run report."""
|
||||
REPORT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
path = REPORT_DIR / f"{self.pipeline}_{ts}.json"
|
||||
with open(path, "w") as f:
|
||||
json.dump(self._stats.to_dict(), f, indent=2)
|
||||
print(f"Report: {path}")
|
||||
|
||||
def resume(self):
|
||||
"""Resume failed/retrying jobs from a previous run."""
|
||||
stats = self.queue.stats()
|
||||
retrying = stats.get("retrying", 0)
|
||||
failed = stats.get("failed", 0)
|
||||
print(f"Resume {self.pipeline}: {retrying} retrying, {failed} failed to reset")
|
||||
|
||||
# Reset failed jobs to pending for retry
|
||||
self.conn.execute(
|
||||
"UPDATE jobs SET status='pending', attempts=0 WHERE pipeline=? AND status='failed'",
|
||||
(self.pipeline,),
|
||||
)
|
||||
self.conn.execute(
|
||||
"UPDATE jobs SET status='pending' WHERE pipeline=? AND status='retrying'",
|
||||
(self.pipeline,),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def status(self):
|
||||
"""Show pipeline status."""
|
||||
stats = self.queue.stats()
|
||||
print(f"\nPipeline: {self.pipeline}")
|
||||
for k, v in sorted(stats.items()):
|
||||
print(f" {k}: {v}")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI
|
||||
# ============================================================
|
||||
|
||||
def show_all_status():
|
||||
"""Show status of all pipelines."""
|
||||
conn = get_db()
|
||||
pipelines = conn.execute(
|
||||
"SELECT DISTINCT pipeline FROM jobs ORDER BY pipeline"
|
||||
).fetchall()
|
||||
|
||||
if not pipelines:
|
||||
print("No pipelines in database.")
|
||||
return
|
||||
|
||||
print(f"\nAll Pipeline Status")
|
||||
print(f"{'='*60}")
|
||||
|
||||
for (pipeline,) in pipelines:
|
||||
queue = JobQueue(pipeline, conn)
|
||||
stats = queue.stats()
|
||||
total = stats.get("total", 0)
|
||||
pending = stats.get("pending", 0)
|
||||
running = stats.get("running", 0)
|
||||
completed = stats.get("completed", 0)
|
||||
failed = stats.get("failed", 0)
|
||||
tokens = stats.get("tokens_used", 0)
|
||||
print(f" {pipeline:25} total={total:4} pending={pending:3} running={running:2} "
|
||||
f"completed={completed:4} failed={failed:3} tokens={tokens:,}")
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Shared Pipeline Orchestrator")
|
||||
parser.add_argument("--pipeline", "-p", help="Pipeline name")
|
||||
parser.add_argument("--jobs", "-j", help="Jobs JSONL file to load")
|
||||
parser.add_argument("--workers", "-w", type=int, default=10, help="Parallel workers")
|
||||
parser.add_argument("--budget", "-b", type=int, default=5_000_000, help="Token budget")
|
||||
parser.add_argument("--status", action="store_true", help="Show status")
|
||||
parser.add_argument("--resume", action="store_true", help="Resume failed jobs")
|
||||
parser.add_argument("--key-field", default="id", help="Job key field name")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.status:
|
||||
if args.pipeline:
|
||||
orch = Orchestrator(args.pipeline)
|
||||
orch.status()
|
||||
else:
|
||||
show_all_status()
|
||||
return
|
||||
|
||||
if not args.pipeline:
|
||||
parser.error("--pipeline is required")
|
||||
|
||||
orch = Orchestrator(args.pipeline, workers=args.workers, token_budget=args.budget)
|
||||
|
||||
if args.jobs:
|
||||
orch.load_jobs(args.jobs, key_field=args.key_field)
|
||||
|
||||
if args.resume:
|
||||
orch.resume()
|
||||
|
||||
if args.jobs or args.resume:
|
||||
orch.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
266
scripts/training_provenance.py
Normal file
266
scripts/training_provenance.py
Normal file
@@ -0,0 +1,266 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training Pair Provenance Tracker — Timmy Foundation
|
||||
|
||||
Adds, filters, and reports provenance metadata for JSONL training pairs.
|
||||
Tracks source_session_id, model, and timestamp for quality auditing.
|
||||
|
||||
Usage:
|
||||
# Tag pairs with provenance
|
||||
python3 scripts/training_provenance.py tag input.jsonl -o tagged.jsonl \
|
||||
--session abc123 --model nous/hermes-3
|
||||
|
||||
# Filter by model (exclude Anthropic-sourced)
|
||||
python3 scripts/training_provenance.py filter input.jsonl -o filtered.jsonl \
|
||||
--exclude-model anthropic
|
||||
|
||||
# Report: pair count by source model
|
||||
python3 scripts/training_provenance.py report input.jsonl
|
||||
|
||||
# Pipe support
|
||||
cat pairs.jsonl | python3 scripts/training_provenance.py report -
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from datetime import datetime, timezone
|
||||
from collections import Counter
|
||||
from typing import Dict, Any, Optional, List, TextIO
|
||||
|
||||
|
||||
PROVENANCE_KEYS = ["source_session_id", "source_model", "source_timestamp"]
|
||||
|
||||
|
||||
def tag_pair(pair: Dict[str, Any], session_id: Optional[str] = None,
|
||||
model: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Add provenance metadata to a training pair."""
|
||||
meta = dict(pair.get("_provenance", {}))
|
||||
|
||||
if session_id:
|
||||
meta["source_session_id"] = session_id
|
||||
if model:
|
||||
meta["source_model"] = model
|
||||
meta["source_timestamp"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pair["_provenance"] = meta
|
||||
return pair
|
||||
|
||||
|
||||
def _open_input(path: str) -> TextIO:
|
||||
"""Open input file or return stdin."""
|
||||
return sys.stdin if path == "-" else open(path, "r", encoding="utf-8")
|
||||
|
||||
|
||||
def _open_output(path: str) -> TextIO:
|
||||
"""Open output file or return stdout."""
|
||||
return sys.stdout if path == "-" else open(path, "w", encoding="utf-8")
|
||||
|
||||
|
||||
def stamp_command(input_path: str, output_path: str,
|
||||
session_id: Optional[str], model: Optional[str]) -> Dict[str, Any]:
|
||||
"""Tag all pairs in a file with provenance metadata."""
|
||||
tagged = 0
|
||||
skipped = 0
|
||||
errors = 0
|
||||
|
||||
source = _open_input(input_path)
|
||||
out = _open_output(output_path)
|
||||
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
# Skip if already tagged with same model+session
|
||||
existing = pair.get("_provenance", {})
|
||||
if (existing.get("source_model") == model
|
||||
and existing.get("source_session_id") == session_id):
|
||||
skipped += 1
|
||||
out.write(line + "\n")
|
||||
continue
|
||||
|
||||
pair = tag_pair(pair, session_id=session_id, model=model)
|
||||
out.write(json.dumps(pair, ensure_ascii=False) + "\n")
|
||||
tagged += 1
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
# Never close stdout — it breaks downstream piping
|
||||
|
||||
return {"tagged": tagged, "skipped": skipped, "errors": errors}
|
||||
|
||||
|
||||
def filter_pairs(input_path: str, output_path: str,
|
||||
include_models: Optional[List[str]] = None,
|
||||
exclude_models: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""Filter pairs by provenance metadata."""
|
||||
kept = []
|
||||
removed = []
|
||||
errors = 0
|
||||
|
||||
source = _open_input(input_path)
|
||||
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
prov = pair.get("_provenance", {})
|
||||
model = prov.get("source_model", "unknown")
|
||||
|
||||
should_keep = True
|
||||
|
||||
if include_models:
|
||||
should_keep = should_keep and model in include_models
|
||||
|
||||
if exclude_models:
|
||||
should_keep = should_keep and model not in exclude_models
|
||||
|
||||
if should_keep:
|
||||
kept.append(pair)
|
||||
else:
|
||||
removed.append(pair)
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
|
||||
# Write output
|
||||
if output_path:
|
||||
out = _open_output(output_path)
|
||||
try:
|
||||
for pair in kept:
|
||||
out.write(json.dumps(pair, ensure_ascii=False) + "\n")
|
||||
finally:
|
||||
if out is not sys.stdout:
|
||||
out.close()
|
||||
|
||||
return {
|
||||
"total": len(kept) + len(removed),
|
||||
"kept": len(kept),
|
||||
"filtered_out": len(removed),
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
|
||||
def report(input_path: str) -> Dict[str, Any]:
|
||||
"""Report pair counts by source model and session."""
|
||||
model_counts: Counter = Counter()
|
||||
session_counts: Counter = Counter()
|
||||
tagged = 0
|
||||
untagged = 0
|
||||
total = 0
|
||||
errors = 0
|
||||
|
||||
source = _open_input(input_path)
|
||||
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
total += 1
|
||||
prov = pair.get("_provenance", {})
|
||||
|
||||
if prov:
|
||||
tagged += 1
|
||||
model = prov.get("source_model", "unknown")
|
||||
session = prov.get("source_session_id", "unknown")
|
||||
model_counts[model] += 1
|
||||
session_counts[session] += 1
|
||||
else:
|
||||
untagged += 1
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"tagged": tagged,
|
||||
"untagged": untagged,
|
||||
"tag_rate": round(tagged / max(total, 1) * 100, 1),
|
||||
"by_model": dict(model_counts.most_common(20)),
|
||||
"by_session": dict(session_counts.most_common(10)),
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Training pair provenance tracking")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# tag subcommand
|
||||
tag_p = sub.add_parser("tag", help="Tag pairs with provenance metadata")
|
||||
tag_p.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
tag_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
|
||||
tag_p.add_argument("--session", help="Source session ID")
|
||||
tag_p.add_argument("--model", help="Source model name")
|
||||
|
||||
# filter subcommand
|
||||
filt_p = sub.add_parser("filter", help="Filter pairs by provenance")
|
||||
filt_p.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
filt_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
|
||||
filt_p.add_argument("--include-model", action="append", help="Only include these models")
|
||||
filt_p.add_argument("--exclude-model", action="append", help="Exclude these models")
|
||||
|
||||
# report subcommand
|
||||
rpt_p = sub.add_parser("report", help="Report provenance statistics")
|
||||
rpt_p.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "tag":
|
||||
result = stamp_command(args.input, args.output, args.session, args.model)
|
||||
print(f"Tagged: {result['tagged']} Skipped: {result['skipped']} Errors: {result['errors']}", file=sys.stderr)
|
||||
|
||||
elif args.command == "filter":
|
||||
result = filter_pairs(
|
||||
args.input, args.output,
|
||||
include_models=args.include_model,
|
||||
exclude_models=args.exclude_model,
|
||||
)
|
||||
print(f"Total: {result['total']} Kept: {result['kept']} Filtered: {result['filtered_out']}", file=sys.stderr)
|
||||
|
||||
elif args.command == "report":
|
||||
result = report(args.input)
|
||||
print(f"Training Pair Provenance Report", file=sys.stderr)
|
||||
print(f"{'='*40}", file=sys.stderr)
|
||||
print(f"Total pairs: {result['total']}", file=sys.stderr)
|
||||
print(f"Tagged: {result['tagged']} ({result['tag_rate']}%)", file=sys.stderr)
|
||||
print(f"Untagged: {result['untagged']}", file=sys.stderr)
|
||||
|
||||
if result["by_model"]:
|
||||
print(f"\nBy source model:", file=sys.stderr)
|
||||
for model, count in result["by_model"].items():
|
||||
print(f" {model}: {count}", file=sys.stderr)
|
||||
|
||||
if result["by_session"]:
|
||||
print(f"\nBy source session (top 10):", file=sys.stderr)
|
||||
for session, count in result["by_session"].items():
|
||||
session_short = session[:12] + "..." if len(session) > 12 else session
|
||||
print(f" {session_short}: {count}", file=sys.stderr)
|
||||
|
||||
# JSON output to stdout
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
363
tests/test_training_provenance.py
Normal file
363
tests/test_training_provenance.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Tests for training pair provenance tracking."""
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
|
||||
SCRIPT = os.path.join(os.path.dirname(__file__), "..", "scripts", "training_provenance.py")
|
||||
|
||||
|
||||
def _run(args, stdin=None):
|
||||
"""Run training_provenance.py and return (stdout, stderr, returncode)."""
|
||||
result = subprocess.run(
|
||||
[sys.executable, SCRIPT] + args,
|
||||
capture_output=True, text=True,
|
||||
input=stdin,
|
||||
)
|
||||
return result.stdout, result.stderr, result.returncode
|
||||
|
||||
|
||||
def _make_pairs(count=3, model="nous/hermes-3", session="sess-123"):
|
||||
"""Generate test JSONL pairs."""
|
||||
lines = []
|
||||
for i in range(count):
|
||||
lines.append(json.dumps({"terse": f"q{i}", "rich": f"a{i}", "domain": "test"}))
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ── tag command ──────────────────────────────────────────────────
|
||||
|
||||
class TestTagCommand:
|
||||
def test_tag_adds_provenance_to_each_pair(self):
|
||||
pairs = _make_pairs(3)
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
out, err, rc = _run(["tag", f.name, "-o", out_path,
|
||||
"--session", "sess-abc", "--model", "nous/hermes-3"])
|
||||
assert rc == 0
|
||||
|
||||
with open(out_path) as f:
|
||||
lines = [json.loads(l) for l in f if l.strip()]
|
||||
|
||||
assert len(lines) == 3
|
||||
for pair in lines:
|
||||
prov = pair["_provenance"]
|
||||
assert prov["source_session_id"] == "sess-abc"
|
||||
assert prov["source_model"] == "nous/hermes-3"
|
||||
assert "source_timestamp" in prov
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_preserves_existing_pair_data(self):
|
||||
pairs = '{"terse": "hello", "rich": "world", "domain": "greeting"}\n'
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_run(["tag", f.name, "-o", out_path, "--model", "m1"])
|
||||
with open(out_path) as f:
|
||||
pair = json.loads(f.readline())
|
||||
assert pair["terse"] == "hello"
|
||||
assert pair["rich"] == "world"
|
||||
assert pair["domain"] == "greeting"
|
||||
assert pair["_provenance"]["source_model"] == "m1"
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_skips_already_tagged_same_provenance(self):
|
||||
pair = json.dumps({
|
||||
"terse": "q", "rich": "a",
|
||||
"_provenance": {"source_model": "m1", "source_session_id": "s1"}
|
||||
})
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pair + "\n")
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["tag", f.name, "-o", out_path,
|
||||
"--session", "s1", "--model", "m1"])
|
||||
assert rc == 0
|
||||
assert "Skipped: 1" in err
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_overwrites_different_provenance(self):
|
||||
pair = json.dumps({
|
||||
"terse": "q", "rich": "a",
|
||||
"_provenance": {"source_model": "old-model", "source_session_id": "old-sess"}
|
||||
})
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pair + "\n")
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["tag", f.name, "-o", out_path,
|
||||
"--session", "new-sess", "--model", "new-model"])
|
||||
assert rc == 0
|
||||
assert "Tagged: 1" in err
|
||||
|
||||
with open(out_path) as f:
|
||||
tagged = json.loads(f.readline())
|
||||
assert tagged["_provenance"]["source_model"] == "new-model"
|
||||
assert tagged["_provenance"]["source_session_id"] == "new-sess"
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_skips_blank_lines(self):
|
||||
pairs = '{"t":"a","r":"b"}\n\n{"t":"c","r":"d"}\n'
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["tag", f.name, "-o", out_path, "--model", "m1"])
|
||||
assert rc == 0
|
||||
assert "Tagged: 2" in err
|
||||
assert "Errors: 0" in err
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_tag_counts_malformed_lines_as_errors(self):
|
||||
pairs = '{"t":"a"}\nNOT_JSON\n{"t":"b"}\n'
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".tagged"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["tag", f.name, "-o", out_path, "--model", "m1"])
|
||||
assert rc == 0
|
||||
assert "Errors: 1" in err
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
|
||||
# ── filter command ───────────────────────────────────────────────
|
||||
|
||||
class TestFilterCommand:
|
||||
def test_filter_exclude_model(self):
|
||||
lines = []
|
||||
for model in ["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]:
|
||||
lines.append(json.dumps({
|
||||
"t": "q", "r": "a",
|
||||
"_provenance": {"source_model": model}
|
||||
}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".filtered"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["filter", f.name, "-o", out_path,
|
||||
"--exclude-model", "anthropic"])
|
||||
assert rc == 0
|
||||
assert "Kept: 2" in err
|
||||
assert "Filtered: 1" in err
|
||||
|
||||
with open(out_path) as f:
|
||||
kept = [json.loads(l) for l in f if l.strip()]
|
||||
models = [p["_provenance"]["source_model"] for p in kept]
|
||||
assert "anthropic/claude" not in models
|
||||
assert "nous/hermes-3" in models
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_filter_include_model(self):
|
||||
lines = []
|
||||
for model in ["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]:
|
||||
lines.append(json.dumps({
|
||||
"t": "q", "r": "a",
|
||||
"_provenance": {"source_model": model}
|
||||
}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".filtered"
|
||||
|
||||
try:
|
||||
_, err, rc = _run(["filter", f.name, "-o", out_path,
|
||||
"--include-model", "nous/hermes-3"])
|
||||
assert rc == 0
|
||||
assert "Kept: 1" in err
|
||||
|
||||
with open(out_path) as f:
|
||||
kept = [json.loads(l) for l in f if l.strip()]
|
||||
assert len(kept) == 1
|
||||
assert kept[0]["_provenance"]["source_model"] == "nous/hermes-3"
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
def test_filter_untreated_pairs_have_unknown_model(self):
|
||||
pairs = '{"t":"q","r":"a"}\n' # no _provenance
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
out_path = f.name + ".filtered"
|
||||
|
||||
try:
|
||||
# Exclude "unknown" — should filter out unprovenanced pair
|
||||
_, err, rc = _run(["filter", f.name, "-o", out_path,
|
||||
"--exclude-model", "unknown"])
|
||||
assert rc == 0
|
||||
assert "Kept: 0" in err
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
if os.path.exists(out_path):
|
||||
os.unlink(out_path)
|
||||
|
||||
|
||||
# ── report command ───────────────────────────────────────────────
|
||||
|
||||
class TestReportCommand:
|
||||
def test_report_counts_by_model(self):
|
||||
lines = []
|
||||
for model in ["nous/hermes-3", "nous/hermes-3", "anthropic/claude"]:
|
||||
lines.append(json.dumps({
|
||||
"t": "q", "r": "a",
|
||||
"_provenance": {"source_model": model, "source_session_id": "s1"}
|
||||
}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
out, err, rc = _run(["report", f.name])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["total"] == 3
|
||||
assert result["tagged"] == 3
|
||||
assert result["untagged"] == 0
|
||||
assert result["tag_rate"] == 100.0
|
||||
assert result["by_model"]["nous/hermes-3"] == 2
|
||||
assert result["by_model"]["anthropic/claude"] == 1
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_report_distinguishes_tagged_vs_untagged(self):
|
||||
pairs = '{"t":"q","r":"a"}\n{"t":"q2","r":"a2","_provenance":{"source_model":"m1"}}\n'
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
out, err, rc = _run(["report", f.name])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["total"] == 2
|
||||
assert result["tagged"] == 1
|
||||
assert result["untagged"] == 1
|
||||
assert result["tag_rate"] == 50.0
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_report_handles_empty_file(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write("")
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
out, err, rc = _run(["report", f.name])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["total"] == 0
|
||||
assert result["tag_rate"] == 0
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_report_counts_by_session(self):
|
||||
lines = []
|
||||
for sess in ["sess-a", "sess-a", "sess-b"]:
|
||||
lines.append(json.dumps({
|
||||
"t": "q", "r": "a",
|
||||
"_provenance": {"source_model": "m1", "source_session_id": sess}
|
||||
}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(pairs)
|
||||
f.flush()
|
||||
|
||||
try:
|
||||
out, _, rc = _run(["report", f.name])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["by_session"]["sess-a"] == 2
|
||||
assert result["by_session"]["sess-b"] == 1
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
# ── integration ──────────────────────────────────────────────────
|
||||
|
||||
class TestIntegration:
|
||||
def test_tag_then_filter_then_report(self):
|
||||
"""Full pipeline: tag → filter → report."""
|
||||
lines = []
|
||||
for i, model in enumerate(["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]):
|
||||
lines.append(json.dumps({"terse": f"q{i}", "rich": f"a{i}", "domain": "test"}))
|
||||
pairs = "\n".join(lines)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as src:
|
||||
src.write(pairs)
|
||||
src.flush()
|
||||
|
||||
tagged_path = src.name + ".tagged"
|
||||
filtered_path = src.name + ".filtered"
|
||||
|
||||
try:
|
||||
# Step 1: Tag all with session info
|
||||
_, _, rc = _run(["tag", src.name, "-o", tagged_path,
|
||||
"--session", "pipe-1", "--model", "nous/hermes-3"])
|
||||
assert rc == 0
|
||||
|
||||
# Step 2: Filter — exclude "unknown" model (untagged pairs)
|
||||
_, err2, rc = _run(["filter", tagged_path, "-o", filtered_path,
|
||||
"--exclude-model", "unknown"])
|
||||
assert rc == 0
|
||||
assert "Kept: 3" in err2
|
||||
|
||||
# Step 3: Report
|
||||
out, _, rc = _run(["report", filtered_path])
|
||||
assert rc == 0
|
||||
result = json.loads(out)
|
||||
assert result["total"] == 3
|
||||
assert result["tagged"] == 3
|
||||
assert result["tag_rate"] == 100.0
|
||||
finally:
|
||||
for p in [src.name, tagged_path, filtered_path]:
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
Reference in New Issue
Block a user