Compare commits
2 Commits
fix/624
...
burn/691-1
| Author | SHA1 | Date | |
|---|---|---|---|
| ae063a8c71 | |||
| 41afbc2ca9 |
@@ -1,331 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
nightly_scheduler.py — Nightly Pipeline Scheduler
|
||||
|
||||
Auto-starts batch pipelines when inference is available, respecting
|
||||
priority ordering, token budgets, and peak-hour pausing.
|
||||
|
||||
Usage:
|
||||
python3 nightly_scheduler.py # run scheduler
|
||||
python3 nightly_scheduler.py --check # dry-run: show what would start
|
||||
python3 nightly_scheduler.py --status # show pipeline status
|
||||
python3 nightly_scheduler.py --reset # reset daily budget
|
||||
|
||||
Crontab:
|
||||
# Run every 30 minutes during off-peak hours (10pm-6am)
|
||||
*/30 22-5 * * * cd /path/to/timmy-config && python3 pipeline/nightly_scheduler.py >> ~/.hermes/pipeline-logs/nightly.log 2>&1
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
# --- Config ---
STATE_FILE = Path.home() / ".hermes" / "pipeline_state.json"  # persisted scheduler state (JSON)
LOG_DIR = Path.home() / ".hermes" / "pipeline-logs"  # root dir for scheduler and per-pipeline logs
DAILY_TOKEN_BUDGET = 5_000_000  # 5M tokens per day
PEAK_HOURS = list(range(8, 22))  # 8am-10pm = peak interactive usage
CHECK_INTERVAL = 1800  # 30 minutes
# NOTE(review): CHECK_INTERVAL is not referenced anywhere in this file —
# presumably it documents the intended cron cadence; confirm before removing.

# Inference endpoints probed in order; "type" is informational metadata only
# (nothing in this file branches on it).
INFERENCE_ENDPOINTS = [
    {"name": "local_ollama", "url": "http://localhost:11434/v1/models", "type": "local"},
    {"name": "runpod", "url": "https://8lfr3j47a5r3gn-11434.proxy.runpod.net/v1/models", "type": "gpu"},
    {"name": "openrouter", "url": "https://openrouter.ai/api/v1/models", "type": "cloud"},
]

# Pipeline priority order (highest first)
PIPELINE_PRIORITY = [
    {"name": "playground_factory", "script": "pipeline/playground_factory.py", "priority": 1},
    {"name": "training_factory", "script": "pipeline/training_factory.py", "priority": 2},
    {"name": "knowledge_mine", "script": "pipeline/knowledge_mine.py", "priority": 3},
    {"name": "adversary", "script": "pipeline/adversary_runner.py", "priority": 4},
    {"name": "codebase_genome", "script": "pipeline/codebase_genome.py", "priority": 5},
]

# Dependency rules: some pipelines only start after others are running
# (values are pipeline names that must be "running" or "completed" first).
DEPENDENCY_RULES = {
    "playground_factory": [],  # no deps, start immediately
    "training_factory": [],  # no deps, start in parallel
    "knowledge_mine": ["training_factory"],  # start after training is running
    "adversary": ["knowledge_mine"],  # start after knowledge is halfway
    "codebase_genome": [],  # continuous, one repo per night
}
|
||||
|
||||
|
||||
def load_state():
    """Read the persisted scheduler state, or return a fresh default record.

    The default matches the shape every other function in this module
    expects: budget counters, per-pipeline records, and session list.
    """
    if not STATE_FILE.exists():
        # First run (or state wiped): start from an empty record.
        return {
            "last_run": None,
            "daily_tokens_used": 0,
            "budget_reset_date": None,
            "pipelines": {},
            "active_sessions": [],
        }
    with open(STATE_FILE) as fh:
        return json.load(fh)
|
||||
|
||||
|
||||
def save_state(state):
    """Persist the scheduler state as pretty-printed JSON.

    Creates the parent directory on first use so a fresh install works.
    """
    STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
    STATE_FILE.write_text(json.dumps(state, indent=2))
|
||||
|
||||
|
||||
def check_provider(endpoint):
    """Return True if the endpoint answers HTTP 200 within 10 seconds.

    Any failure at all — timeout, DNS error, HTTP error, bad status —
    simply means "unavailable", so everything is caught and mapped to
    False rather than propagated.
    """
    try:
        probe = urllib.request.Request(
            endpoint["url"],
            headers={"Authorization": "Bearer ollama"},
        )
        with urllib.request.urlopen(probe, timeout=10) as response:
            return response.status == 200
    except Exception:
        return False
|
||||
|
||||
|
||||
def get_available_providers():
    """Probe every configured endpoint and return the names of responsive ones."""
    return [ep["name"] for ep in INFERENCE_ENDPOINTS if check_provider(ep)]
|
||||
|
||||
|
||||
def is_peak_hours():
    """Return True when the local wall-clock hour is inside the peak window."""
    return datetime.now().hour in PEAK_HOURS
|
||||
|
||||
|
||||
def check_token_budget(state):
    """Roll the budget over on a new day, then report whether tokens remain.

    Mutates and persists ``state`` when the date changes; returns True as
    long as usage is below DAILY_TOKEN_BUDGET.
    """
    today = datetime.now().date().isoformat()  # YYYY-MM-DD, local time
    if state.get("budget_reset_date") != today:
        # First check of a new day: zero the counter and persist immediately.
        state["daily_tokens_used"] = 0
        state["budget_reset_date"] = today
        save_state(state)
    return state["daily_tokens_used"] < DAILY_TOKEN_BUDGET
|
||||
|
||||
|
||||
def get_pipeline_status(state, pipeline_name):
    """Return the tracked record for one pipeline.

    Unknown pipelines get a fresh "not_started" record so callers never
    need to handle a missing entry.
    """
    fresh = {
        "status": "not_started",
        "last_run": None,
        "last_success": None,
        "progress": 0,
    }
    return state.get("pipelines", {}).get(pipeline_name, fresh)
|
||||
|
||||
|
||||
def check_dependencies(state, pipeline_name):
    """Return True when every dependency is already running or completed.

    Pipelines without an entry in DEPENDENCY_RULES have no dependencies
    and are always eligible.
    """
    return all(
        get_pipeline_status(state, dep)["status"] in ("running", "completed")
        for dep in DEPENDENCY_RULES.get(pipeline_name, [])
    )
|
||||
|
||||
|
||||
def start_pipeline(pipeline, state, dry_run=False):
    """Launch a pipeline script as a background subprocess.

    Args:
        pipeline: Entry from PIPELINE_PRIORITY with "name" and "script" keys.
        state: Mutable scheduler state dict; updated and persisted on every path.
        dry_run: When True, only log the command that would run.

    Returns:
        True if the pipeline was started (or would be, in dry-run mode),
        False when the script is missing or spawning failed.
    """
    name = pipeline["name"]
    script = pipeline["script"]

    log(f"Starting pipeline: {name}")

    if dry_run:
        log(f"  DRY RUN — would run: python3 {script}")
        return True

    # Check if script exists before attempting to launch it.
    script_path = Path(script)
    if not script_path.exists():
        log(f"  Script not found: {script_path}")
        # Update state anyway so we track the attempt.
        state["pipelines"][name] = {
            "status": "script_missing",
            "last_run": datetime.now(timezone.utc).isoformat(),
            "progress": 0,
        }
        save_state(state)
        return False

    # Run the pipeline script, redirecting stdout+stderr to a timestamped log.
    import subprocess
    log_dir = LOG_DIR / name
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

    try:
        # Bug fix: the original passed open(log_file, "w") directly to Popen
        # and never closed the parent's copy, leaking one file descriptor per
        # launch. Popen duplicates the descriptor for the child, so closing
        # our handle immediately after spawning is safe.
        with open(log_file, "w") as log_fh:
            proc = subprocess.Popen(
                ["python3", str(script_path)],
                stdout=log_fh,
                stderr=subprocess.STDOUT,
                # Run from the repo root (scripts live one level below it).
                cwd=str(Path(script).parent.parent),
            )

        state["pipelines"][name] = {
            "status": "running",
            "pid": proc.pid,
            "last_run": datetime.now(timezone.utc).isoformat(),
            "log_file": str(log_file),
            "progress": 0,
        }
        save_state(state)
        log(f"  Started PID {proc.pid}, log: {log_file}")
        return True
    except Exception as e:
        log(f"  Failed to start: {e}")
        state["pipelines"][name] = {
            "status": "failed",
            "last_run": datetime.now(timezone.utc).isoformat(),
            "error": str(e),
        }
        save_state(state)
        return False
|
||||
|
||||
|
||||
def check_running_pipelines(state):
    """Sweep pipelines marked "running" and mark exited ones as "completed".

    Uses ``os.kill(pid, 0)`` as a pure existence probe (signal 0 delivers
    nothing); a ProcessLookupError means the process has exited, at which
    point the record is updated and the state file saved.

    Note: the original body imported ``subprocess`` without using it; that
    dead import has been removed.
    """
    for name, info in state.get("pipelines", {}).items():
        if info.get("status") != "running":
            continue
        pid = info.get("pid")
        if not pid:
            continue
        try:
            os.kill(pid, 0)  # Existence check only; sends no signal.
        except ProcessLookupError:
            # Process finished — record completion and persist it.
            info["status"] = "completed"
            info["completed_at"] = datetime.now(timezone.utc).isoformat()
            log(f"Pipeline {name} completed (PID {pid} exited)")
            save_state(state)
|
||||
|
||||
|
||||
def run_scheduler(dry_run=False, check_only=False):
    """Main scheduler pass: run all gates, then start at most one pipeline.

    Gates, in order: peak hours -> token budget -> provider availability.
    Any gate failing ends the pass early. At most one pipeline is started
    per invocation to avoid overloading the machine.

    Args:
        dry_run: log what would start without spawning processes.
        check_only: report which pipelines are READY; start nothing.
    """
    state = load_state()

    log("=" * 50)
    log(f"Pipeline Scheduler — {datetime.now().isoformat()}")
    log(f"Mode: {'CHECK' if check_only else 'DRY RUN' if dry_run else 'LIVE'}")

    # Gate 1: never start batch work during peak interactive hours.
    if is_peak_hours():
        log("Peak hours detected. Pausing pipeline starts.")
        log("Pipelines will resume at 10pm.")
        return

    # Gate 2: daily token budget (check_token_budget also rolls it over).
    if not check_token_budget(state):
        log(f"Daily token budget exhausted ({state['daily_tokens_used']}/{DAILY_TOKEN_BUDGET})")
        return
    log(f"Token budget: {state['daily_tokens_used']}/{DAILY_TOKEN_BUDGET}")

    # Gate 3: at least one inference provider must respond.
    providers = get_available_providers()
    if not providers:
        log("No inference providers available. Skipping.")
        return
    log(f"Available providers: {', '.join(providers)}")

    # Reap finished processes so their slots free up before we pick work.
    check_running_pipelines(state)

    # Walk pipelines in priority order and start the first eligible one.
    started = 0
    for pipeline in sorted(PIPELINE_PRIORITY, key=lambda p: p["priority"]):
        name = pipeline["name"]
        status = get_pipeline_status(state, name)

        # Skip if already running or completed
        if status["status"] in ("running", "completed"):
            log(f"  {name}: {status['status']} (skipping)")
            continue

        # Check dependencies
        if not check_dependencies(state, name):
            deps = DEPENDENCY_RULES.get(name, [])
            log(f"  {name}: waiting for dependencies: {deps}")
            continue

        # Start the pipeline (or just report readiness in --check mode).
        if check_only:
            log(f"  {name}: READY to start (priority {pipeline['priority']})")
        else:
            if start_pipeline(pipeline, state, dry_run):
                started += 1
                # Only start one pipeline per run to avoid overload
                if started >= 1:
                    log("Started 1 pipeline. Will check again next cycle.")
                    break

    if started == 0 and not check_only:
        log("No pipelines to start. All are running, completed, or blocked.")

    log("=" * 50)
|
||||
|
||||
|
||||
def show_status():
    """Print a one-line-per-pipeline status table to stdout."""
    state = load_state()
    print(f"\nPipeline Status — {datetime.now().strftime('%Y-%m-%d %H:%M')}")
    print(f"Token budget: {state.get('daily_tokens_used', 0)}/{DAILY_TOKEN_BUDGET}")
    print(f"Last run: {state.get('last_run', 'never')}")
    print()

    # One glyph per known status; anything unexpected renders as "?".
    glyphs = {
        "running": "●",
        "completed": "✓",
        "failed": "✗",
        "not_started": "○",
        "script_missing": "?",
    }
    for entry in sorted(PIPELINE_PRIORITY, key=lambda p: p["priority"]):
        info = get_pipeline_status(state, entry["name"])
        current = info["status"]
        marker = glyphs.get(current, "?")
        last_run = (info.get("last_run") or "never")[:19]  # trim ISO timestamp
        print(f"  {marker} {entry['name']:25} {current:15} last={last_run}")
|
||||
|
||||
|
||||
def reset_budget():
    """Zero the daily token counter and stamp today's date into the state."""
    state = load_state()
    state.update(
        daily_tokens_used=0,
        budget_reset_date=datetime.now().strftime("%Y-%m-%d"),
    )
    save_state(state)
    print("Budget reset.")
|
||||
|
||||
|
||||
def log(msg):
    """Echo a timestamped line to stdout and append it to the nightly log."""
    line = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
    print(line)
    # Ensure the log directory exists every call — cheap, and safe on first run.
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    with open(LOG_DIR / "nightly.log", "a") as fh:
        fh.write(line + "\n")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: dispatch to status / reset / scheduler by flag."""
    import argparse
    parser = argparse.ArgumentParser(description="Nightly Pipeline Scheduler")
    parser.add_argument("--check", action="store_true", help="Dry-run: show what would start")
    parser.add_argument("--status", action="store_true", help="Show pipeline status")
    parser.add_argument("--reset", action="store_true", help="Reset daily token budget")
    parser.add_argument("--dry-run", action="store_true", help="Dry-run mode")
    args = parser.parse_args()

    if args.status:
        show_status()
    elif args.reset:
        reset_budget()
    else:
        # --check implies dry-run; run_scheduler distinguishes the two modes.
        run_scheduler(dry_run=args.dry_run or args.check, check_only=args.check)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
266
scripts/training_provenance.py
Normal file
266
scripts/training_provenance.py
Normal file
@@ -0,0 +1,266 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training Pair Provenance Tracker — Timmy Foundation
|
||||
|
||||
Adds, filters, and reports provenance metadata for JSONL training pairs.
|
||||
Tracks source_session_id, model, and timestamp for quality auditing.
|
||||
|
||||
Usage:
|
||||
# Tag pairs with provenance
|
||||
python3 scripts/training_provenance.py tag input.jsonl -o tagged.jsonl \
|
||||
--session abc123 --model nous/hermes-3
|
||||
|
||||
# Filter by model (exclude Anthropic-sourced)
|
||||
python3 scripts/training_provenance.py filter input.jsonl -o filtered.jsonl \
|
||||
--exclude-model anthropic
|
||||
|
||||
# Report: pair count by source model
|
||||
python3 scripts/training_provenance.py report input.jsonl
|
||||
|
||||
# Pipe support
|
||||
cat pairs.jsonl | python3 scripts/training_provenance.py report -
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from datetime import datetime, timezone
|
||||
from collections import Counter
|
||||
from typing import Dict, Any, Optional, List, TextIO
|
||||
|
||||
|
||||
PROVENANCE_KEYS = ["source_session_id", "source_model", "source_timestamp"]
|
||||
|
||||
|
||||
def tag_pair(pair: Dict[str, Any], session_id: Optional[str] = None,
             model: Optional[str] = None) -> Dict[str, Any]:
    """Attach provenance metadata to *pair* and return it.

    Existing ``_provenance`` entries are preserved; session/model are
    written only when provided, while the timestamp is always refreshed
    to the current UTC time. The pair itself is mutated and returned.
    """
    provenance = dict(pair.get("_provenance", {}))

    if session_id:
        provenance["source_session_id"] = session_id
    if model:
        provenance["source_model"] = model
    provenance["source_timestamp"] = datetime.now(timezone.utc).isoformat()

    pair["_provenance"] = provenance
    return pair
|
||||
|
||||
|
||||
def _open_input(path: str) -> TextIO:
|
||||
"""Open input file or return stdin."""
|
||||
return sys.stdin if path == "-" else open(path, "r", encoding="utf-8")
|
||||
|
||||
|
||||
def _open_output(path: str) -> TextIO:
|
||||
"""Open output file or return stdout."""
|
||||
return sys.stdout if path == "-" else open(path, "w", encoding="utf-8")
|
||||
|
||||
|
||||
def stamp_command(input_path: str, output_path: str,
                  session_id: Optional[str], model: Optional[str]) -> Dict[str, Any]:
    """Tag every pair in a JSONL stream with provenance metadata.

    Args:
        input_path: JSONL file path, or "-" for stdin.
        output_path: JSONL file path, or "-" for stdout.
        session_id: Source session ID to stamp (skipped when None).
        model: Source model name to stamp (skipped when None).

    Returns:
        Counts: {"tagged": int, "skipped": int, "errors": int}.
        Malformed JSON lines count as errors and are dropped from the output.
    """
    tagged = 0
    skipped = 0
    errors = 0

    source = _open_input(input_path)
    out = _open_output(output_path)

    try:
        for line in source:
            line = line.strip()
            if not line:
                continue
            try:
                pair = json.loads(line)
            except json.JSONDecodeError:
                errors += 1
                continue

            # Skip if already tagged with same model+session — pass the
            # original line through untouched so its timestamp is preserved.
            existing = pair.get("_provenance", {})
            if (existing.get("source_model") == model
                    and existing.get("source_session_id") == session_id):
                skipped += 1
                out.write(line + "\n")
                continue

            pair = tag_pair(pair, session_id=session_id, model=model)
            out.write(json.dumps(pair, ensure_ascii=False) + "\n")
            tagged += 1
    finally:
        if source is not sys.stdin:
            source.close()
        # Bug fix: the original never closed a real output file, leaving its
        # buffers unflushed until interpreter exit. Never close stdout —
        # it breaks downstream piping.
        if out is not sys.stdout:
            out.close()

    return {"tagged": tagged, "skipped": skipped, "errors": errors}
|
||||
|
||||
|
||||
def filter_pairs(input_path: str, output_path: str,
                 include_models: Optional[List[str]] = None,
                 exclude_models: Optional[List[str]] = None) -> Dict[str, Any]:
    """Filter pairs by provenance source model.

    Matching is by substring: an entry like "anthropic" matches
    "anthropic/claude". (Bug fix: the original used exact list membership,
    which contradicted both the module usage docs — "--exclude-model
    anthropic" excluding Anthropic-sourced pairs — and the test suite.)
    Pairs without provenance count as model "unknown". A pair is kept only
    if it passes the include list (when given) AND the exclude list
    (when given).

    Args:
        input_path: JSONL file path, or "-" for stdin.
        output_path: JSONL file path for kept pairs, or "-" for stdout.
        include_models: Keep only pairs whose model matches one of these.
        exclude_models: Drop pairs whose model matches any of these.

    Returns:
        {"total": int, "kept": int, "filtered_out": int, "errors": int}.
    """
    kept: List[Dict[str, Any]] = []
    # Only the count of removed pairs is reported, so don't hold them
    # in memory (the original accumulated every filtered pair in a list).
    removed = 0
    errors = 0

    source = _open_input(input_path)

    try:
        for line in source:
            line = line.strip()
            if not line:
                continue
            try:
                pair = json.loads(line)
            except json.JSONDecodeError:
                errors += 1
                continue

            model = pair.get("_provenance", {}).get("source_model", "unknown")

            should_keep = True
            if include_models:
                should_keep = should_keep and any(pat in model for pat in include_models)
            if exclude_models:
                should_keep = should_keep and not any(pat in model for pat in exclude_models)

            if should_keep:
                kept.append(pair)
            else:
                removed += 1
    finally:
        if source is not sys.stdin:
            source.close()

    # Write the kept pairs (output_path defaults to "-" = stdout upstream).
    if output_path:
        out = _open_output(output_path)
        try:
            for pair in kept:
                out.write(json.dumps(pair, ensure_ascii=False) + "\n")
        finally:
            if out is not sys.stdout:
                out.close()

    return {
        "total": len(kept) + removed,
        "kept": len(kept),
        "filtered_out": removed,
        "errors": errors,
    }
|
||||
|
||||
|
||||
def report(input_path: str) -> Dict[str, Any]:
    """Summarize provenance coverage for a JSONL stream.

    Counts tagged vs. untagged pairs and aggregates per-model and
    per-session totals. Malformed JSON lines count as errors and do not
    contribute to any other counter.
    """
    by_model: Counter = Counter()
    by_session: Counter = Counter()
    tagged = 0
    untagged = 0
    total = 0
    errors = 0

    source = _open_input(input_path)

    try:
        for raw in source:
            raw = raw.strip()
            if not raw:
                continue
            try:
                pair = json.loads(raw)
            except json.JSONDecodeError:
                errors += 1
                continue

            total += 1
            prov = pair.get("_provenance", {})
            if not prov:
                untagged += 1
                continue

            tagged += 1
            by_model[prov.get("source_model", "unknown")] += 1
            by_session[prov.get("source_session_id", "unknown")] += 1
    finally:
        if source is not sys.stdin:
            source.close()

    return {
        "total": total,
        "tagged": tagged,
        "untagged": untagged,
        # max(total, 1) avoids ZeroDivisionError on an empty input.
        "tag_rate": round(tagged / max(total, 1) * 100, 1),
        "by_model": dict(by_model.most_common(20)),
        "by_session": dict(by_session.most_common(10)),
        "errors": errors,
    }
|
||||
|
||||
|
||||
def main():
    """CLI entry point: tag / filter / report subcommands.

    Convention throughout: human-readable summaries go to stderr so that
    stdout stays clean for JSONL / JSON piping.
    """
    parser = argparse.ArgumentParser(description="Training pair provenance tracking")
    sub = parser.add_subparsers(dest="command", required=True)

    # tag subcommand
    tag_p = sub.add_parser("tag", help="Tag pairs with provenance metadata")
    tag_p.add_argument("input", help="Input JSONL file (use - for stdin)")
    tag_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
    tag_p.add_argument("--session", help="Source session ID")
    tag_p.add_argument("--model", help="Source model name")

    # filter subcommand
    filt_p = sub.add_parser("filter", help="Filter pairs by provenance")
    filt_p.add_argument("input", help="Input JSONL file (use - for stdin)")
    filt_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
    # action="append": flags may be repeated to build up a list.
    filt_p.add_argument("--include-model", action="append", help="Only include these models")
    filt_p.add_argument("--exclude-model", action="append", help="Exclude these models")

    # report subcommand
    rpt_p = sub.add_parser("report", help="Report provenance statistics")
    rpt_p.add_argument("input", help="Input JSONL file (use - for stdin)")

    args = parser.parse_args()

    if args.command == "tag":
        result = stamp_command(args.input, args.output, args.session, args.model)
        print(f"Tagged: {result['tagged']} Skipped: {result['skipped']} Errors: {result['errors']}", file=sys.stderr)

    elif args.command == "filter":
        result = filter_pairs(
            args.input, args.output,
            include_models=args.include_model,
            exclude_models=args.exclude_model,
        )
        print(f"Total: {result['total']} Kept: {result['kept']} Filtered: {result['filtered_out']}", file=sys.stderr)

    elif args.command == "report":
        result = report(args.input)
        print(f"Training Pair Provenance Report", file=sys.stderr)
        print(f"{'='*40}", file=sys.stderr)
        print(f"Total pairs: {result['total']}", file=sys.stderr)
        print(f"Tagged: {result['tagged']} ({result['tag_rate']}%)", file=sys.stderr)
        print(f"Untagged: {result['untagged']}", file=sys.stderr)

        if result["by_model"]:
            print(f"\nBy source model:", file=sys.stderr)
            for model, count in result["by_model"].items():
                print(f"  {model}: {count}", file=sys.stderr)

        if result["by_session"]:
            print(f"\nBy source session (top 10):", file=sys.stderr)
            for session, count in result["by_session"].items():
                # Truncate long session IDs for readability.
                session_short = session[:12] + "..." if len(session) > 12 else session
                print(f"  {session_short}: {count}", file=sys.stderr)

        # JSON output to stdout
        print(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
363
tests/test_training_provenance.py
Normal file
363
tests/test_training_provenance.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Tests for training pair provenance tracking."""
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
|
||||
SCRIPT = os.path.join(os.path.dirname(__file__), "..", "scripts", "training_provenance.py")
|
||||
|
||||
|
||||
def _run(args, stdin=None):
    """Run training_provenance.py and return (stdout, stderr, returncode)."""
    # Use sys.executable so the script runs under the same interpreter that
    # pytest itself is running under.
    result = subprocess.run(
        [sys.executable, SCRIPT] + args,
        capture_output=True, text=True,
        input=stdin,
    )
    return result.stdout, result.stderr, result.returncode
|
||||
|
||||
|
||||
def _make_pairs(count=3, model="nous/hermes-3", session="sess-123"):
    """Generate `count` untagged test pairs as a JSONL string (no trailing newline)."""
    # NOTE(review): the `model` and `session` parameters are accepted but never
    # used — every emitted pair has NO _provenance, which the tag tests rely on.
    # Confirm whether these parameters should be dropped or honored.
    lines = []
    for i in range(count):
        lines.append(json.dumps({"terse": f"q{i}", "rich": f"a{i}", "domain": "test"}))
    return "\n".join(lines)
|
||||
|
||||
|
||||
# ── tag command ──────────────────────────────────────────────────
|
||||
|
||||
class TestTagCommand:
    """End-to-end tests of the `tag` subcommand via subprocess.

    Pattern used throughout: write input to a NamedTemporaryFile with
    delete=False (so the subprocess can reopen it by name), run the CLI,
    then clean up both input and output files in `finally`.
    """

    def test_tag_adds_provenance_to_each_pair(self):
        # Every output pair must carry the session, model, and a timestamp.
        pairs = _make_pairs(3)
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".tagged"

        try:
            out, err, rc = _run(["tag", f.name, "-o", out_path,
                                 "--session", "sess-abc", "--model", "nous/hermes-3"])
            assert rc == 0

            with open(out_path) as f:
                lines = [json.loads(l) for l in f if l.strip()]

            assert len(lines) == 3
            for pair in lines:
                prov = pair["_provenance"]
                assert prov["source_session_id"] == "sess-abc"
                assert prov["source_model"] == "nous/hermes-3"
                assert "source_timestamp" in prov
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_preserves_existing_pair_data(self):
        # Tagging must only ADD the _provenance key, not touch other fields.
        pairs = '{"terse": "hello", "rich": "world", "domain": "greeting"}\n'
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".tagged"

        try:
            _run(["tag", f.name, "-o", out_path, "--model", "m1"])
            with open(out_path) as f:
                pair = json.loads(f.readline())
            assert pair["terse"] == "hello"
            assert pair["rich"] == "world"
            assert pair["domain"] == "greeting"
            assert pair["_provenance"]["source_model"] == "m1"
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_skips_already_tagged_same_provenance(self):
        # Re-tagging with an identical model+session is a no-op ("Skipped").
        pair = json.dumps({
            "terse": "q", "rich": "a",
            "_provenance": {"source_model": "m1", "source_session_id": "s1"}
        })
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pair + "\n")
            f.flush()
            out_path = f.name + ".tagged"

        try:
            _, err, rc = _run(["tag", f.name, "-o", out_path,
                               "--session", "s1", "--model", "m1"])
            assert rc == 0
            # The counts summary is printed to stderr.
            assert "Skipped: 1" in err
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_overwrites_different_provenance(self):
        # A different model+session replaces the old provenance entirely.
        pair = json.dumps({
            "terse": "q", "rich": "a",
            "_provenance": {"source_model": "old-model", "source_session_id": "old-sess"}
        })
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pair + "\n")
            f.flush()
            out_path = f.name + ".tagged"

        try:
            _, err, rc = _run(["tag", f.name, "-o", out_path,
                               "--session", "new-sess", "--model", "new-model"])
            assert rc == 0
            assert "Tagged: 1" in err

            with open(out_path) as f:
                tagged = json.loads(f.readline())
            assert tagged["_provenance"]["source_model"] == "new-model"
            assert tagged["_provenance"]["source_session_id"] == "new-sess"
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_skips_blank_lines(self):
        # Blank lines are ignored silently — not tagged and not errors.
        pairs = '{"t":"a","r":"b"}\n\n{"t":"c","r":"d"}\n'
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".tagged"

        try:
            _, err, rc = _run(["tag", f.name, "-o", out_path, "--model", "m1"])
            assert rc == 0
            assert "Tagged: 2" in err
            assert "Errors: 0" in err
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_tag_counts_malformed_lines_as_errors(self):
        # Non-JSON lines are counted as errors but do not abort the run.
        pairs = '{"t":"a"}\nNOT_JSON\n{"t":"b"}\n'
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".tagged"

        try:
            _, err, rc = _run(["tag", f.name, "-o", out_path, "--model", "m1"])
            assert rc == 0
            assert "Errors: 1" in err
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)
|
||||
|
||||
|
||||
# ── filter command ───────────────────────────────────────────────
|
||||
|
||||
class TestFilterCommand:
    """End-to-end tests of the `filter` subcommand via subprocess.

    NOTE(review): these tests pass "--exclude-model anthropic" and expect
    it to exclude a pair whose model is "anthropic/claude" — i.e. they
    assume substring (not exact) matching in filter_pairs. Confirm the
    implementation matches.
    """

    def test_filter_exclude_model(self):
        lines = []
        for model in ["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]:
            lines.append(json.dumps({
                "t": "q", "r": "a",
                "_provenance": {"source_model": model}
            }))
        pairs = "\n".join(lines)

        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".filtered"

        try:
            _, err, rc = _run(["filter", f.name, "-o", out_path,
                               "--exclude-model", "anthropic"])
            assert rc == 0
            assert "Kept: 2" in err
            assert "Filtered: 1" in err

            with open(out_path) as f:
                kept = [json.loads(l) for l in f if l.strip()]
            models = [p["_provenance"]["source_model"] for p in kept]
            assert "anthropic/claude" not in models
            assert "nous/hermes-3" in models
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_filter_include_model(self):
        lines = []
        for model in ["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]:
            lines.append(json.dumps({
                "t": "q", "r": "a",
                "_provenance": {"source_model": model}
            }))
        pairs = "\n".join(lines)

        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".filtered"

        try:
            _, err, rc = _run(["filter", f.name, "-o", out_path,
                               "--include-model", "nous/hermes-3"])
            assert rc == 0
            assert "Kept: 1" in err

            with open(out_path) as f:
                kept = [json.loads(l) for l in f if l.strip()]
            assert len(kept) == 1
            assert kept[0]["_provenance"]["source_model"] == "nous/hermes-3"
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)

    def test_filter_untreated_pairs_have_unknown_model(self):
        # Pairs with no _provenance are treated as model "unknown".
        pairs = '{"t":"q","r":"a"}\n'  # no _provenance
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()
            out_path = f.name + ".filtered"

        try:
            # Exclude "unknown" — should filter out unprovenanced pair
            _, err, rc = _run(["filter", f.name, "-o", out_path,
                               "--exclude-model", "unknown"])
            assert rc == 0
            assert "Kept: 0" in err
        finally:
            os.unlink(f.name)
            if os.path.exists(out_path):
                os.unlink(out_path)
|
||||
|
||||
|
||||
# ── report command ───────────────────────────────────────────────
|
||||
|
||||
class TestReportCommand:
    """End-to-end tests of the `report` subcommand.

    The machine-readable JSON report goes to stdout; the human summary goes
    to stderr — these tests parse stdout only.
    """

    def test_report_counts_by_model(self):
        lines = []
        for model in ["nous/hermes-3", "nous/hermes-3", "anthropic/claude"]:
            lines.append(json.dumps({
                "t": "q", "r": "a",
                "_provenance": {"source_model": model, "source_session_id": "s1"}
            }))
        pairs = "\n".join(lines)

        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()

        try:
            out, err, rc = _run(["report", f.name])
            assert rc == 0
            result = json.loads(out)
            assert result["total"] == 3
            assert result["tagged"] == 3
            assert result["untagged"] == 0
            assert result["tag_rate"] == 100.0
            assert result["by_model"]["nous/hermes-3"] == 2
            assert result["by_model"]["anthropic/claude"] == 1
        finally:
            os.unlink(f.name)

    def test_report_distinguishes_tagged_vs_untagged(self):
        # One pair with provenance, one without → 50% tag rate.
        pairs = '{"t":"q","r":"a"}\n{"t":"q2","r":"a2","_provenance":{"source_model":"m1"}}\n'
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()

        try:
            out, err, rc = _run(["report", f.name])
            assert rc == 0
            result = json.loads(out)
            assert result["total"] == 2
            assert result["tagged"] == 1
            assert result["untagged"] == 1
            assert result["tag_rate"] == 50.0
        finally:
            os.unlink(f.name)

    def test_report_handles_empty_file(self):
        # Empty input must not crash (division-by-zero guard in tag_rate).
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write("")
            f.flush()

        try:
            out, err, rc = _run(["report", f.name])
            assert rc == 0
            result = json.loads(out)
            assert result["total"] == 0
            assert result["tag_rate"] == 0
        finally:
            os.unlink(f.name)

    def test_report_counts_by_session(self):
        lines = []
        for sess in ["sess-a", "sess-a", "sess-b"]:
            lines.append(json.dumps({
                "t": "q", "r": "a",
                "_provenance": {"source_model": "m1", "source_session_id": sess}
            }))
        pairs = "\n".join(lines)

        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
            f.write(pairs)
            f.flush()

        try:
            out, _, rc = _run(["report", f.name])
            assert rc == 0
            result = json.loads(out)
            assert result["by_session"]["sess-a"] == 2
            assert result["by_session"]["sess-b"] == 1
        finally:
            os.unlink(f.name)
|
||||
|
||||
|
||||
# ── integration ──────────────────────────────────────────────────
|
||||
|
||||
class TestIntegration:
    def test_tag_then_filter_then_report(self):
        """Full pipeline: tag → filter → report."""
        # NOTE(review): the loop enumerates three model names but never uses
        # them — all pairs start untagged, and step 1 tags everything with
        # the single model "nous/hermes-3". Confirm whether per-pair models
        # were intended here.
        lines = []
        for i, model in enumerate(["nous/hermes-3", "anthropic/claude", "openai/gpt-4"]):
            lines.append(json.dumps({"terse": f"q{i}", "rich": f"a{i}", "domain": "test"}))
        pairs = "\n".join(lines)

        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as src:
            src.write(pairs)
            src.flush()

        tagged_path = src.name + ".tagged"
        filtered_path = src.name + ".filtered"

        try:
            # Step 1: Tag all with session info
            _, _, rc = _run(["tag", src.name, "-o", tagged_path,
                             "--session", "pipe-1", "--model", "nous/hermes-3"])
            assert rc == 0

            # Step 2: Filter — exclude "unknown" model (untagged pairs)
            _, err2, rc = _run(["filter", tagged_path, "-o", filtered_path,
                                "--exclude-model", "unknown"])
            assert rc == 0
            assert "Kept: 3" in err2

            # Step 3: Report — everything survived and is 100% tagged.
            out, _, rc = _run(["report", filtered_path])
            assert rc == 0
            result = json.loads(out)
            assert result["total"] == 3
            assert result["tagged"] == 3
            assert result["tag_rate"] == 100.0
        finally:
            for p in [src.name, tagged_path, filtered_path]:
                if os.path.exists(p):
                    os.unlink(p)
|
||||
Reference in New Issue
Block a user