Compare commits

..

1 Commit

Author SHA1 Message Date
Alexander Whitestone
49296d538e feat: nightly pipeline scheduler — auto-start when inference available (#624)
Some checks failed
Architecture Lint / Linter Tests (pull_request) Successful in 36s
PR Checklist / pr-checklist (pull_request) Failing after 6m56s
Smoke Test / smoke (pull_request) Failing after 34s
Validate Config / YAML Lint (pull_request) Failing after 21s
Validate Config / JSON Validate (pull_request) Successful in 9s
Validate Config / Python Syntax & Import Check (pull_request) Failing after 1m40s
Validate Config / Shell Script Lint (pull_request) Failing after 48s
Validate Config / Cron Syntax Check (pull_request) Successful in 18s
Validate Config / Deploy Script Dry Run (pull_request) Successful in 20s
Validate Config / Playbook Schema Validation (pull_request) Successful in 38s
Architecture Lint / Lint Repository (pull_request) Has been cancelled
Validate Config / Python Test Suite (pull_request) Has been cancelled
Scheduler that auto-starts batch pipelines when inference is available.

Features:
- Checks inference provider availability (local Ollama, RunPod, OpenRouter)
- Priority ordering: playground > training > knowledge > adversary > genome
- Dependency rules (e.g., knowledge_mine waits for training_factory)
- Daily token budget (5M default, configurable)
- Peak-hour pausing (8am-10pm = interactive mode, no pipelines)
- State persistence via ~/.hermes/pipeline_state.json
- One pipeline per cycle to avoid overload

Usage:
  python3 pipeline/nightly_scheduler.py --status
  python3 pipeline/nightly_scheduler.py --check      # dry-run
  python3 pipeline/nightly_scheduler.py              # live

Cron: */30 22-5 * * * pipeline/nightly_scheduler.py

Closes #624
2026-04-15 08:14:00 -04:00
3 changed files with 331 additions and 581 deletions

331
pipeline/nightly_scheduler.py Executable file
View File

@@ -0,0 +1,331 @@
#!/usr/bin/env python3
"""
nightly_scheduler.py — Nightly Pipeline Scheduler
Auto-starts batch pipelines when inference is available, respecting
priority ordering, token budgets, and peak-hour pausing.
Usage:
python3 nightly_scheduler.py # run scheduler
python3 nightly_scheduler.py --check # dry-run: show what would start
python3 nightly_scheduler.py --status # show pipeline status
python3 nightly_scheduler.py --reset # reset daily budget
Crontab:
# Run every 30 minutes during off-peak hours (10pm-6am)
*/30 22-5 * * * cd /path/to/timmy-config && python3 pipeline/nightly_scheduler.py >> ~/.hermes/pipeline-logs/nightly.log 2>&1
"""
import json
import os
import sys
import time
import urllib.request
import urllib.error
from datetime import datetime, timezone, timedelta
from pathlib import Path
# --- Config ---
STATE_FILE = Path.home() / ".hermes" / "pipeline_state.json"
LOG_DIR = Path.home() / ".hermes" / "pipeline-logs"
DAILY_TOKEN_BUDGET = 5_000_000 # 5M tokens per day
PEAK_HOURS = list(range(8, 22)) # 8am-10pm = peak interactive usage
CHECK_INTERVAL = 1800 # 30 minutes
INFERENCE_ENDPOINTS = [
{"name": "local_ollama", "url": "http://localhost:11434/v1/models", "type": "local"},
{"name": "runpod", "url": "https://8lfr3j47a5r3gn-11434.proxy.runpod.net/v1/models", "type": "gpu"},
{"name": "openrouter", "url": "https://openrouter.ai/api/v1/models", "type": "cloud"},
]
# Pipeline priority order (highest first)
PIPELINE_PRIORITY = [
{"name": "playground_factory", "script": "pipeline/playground_factory.py", "priority": 1},
{"name": "training_factory", "script": "pipeline/training_factory.py", "priority": 2},
{"name": "knowledge_mine", "script": "pipeline/knowledge_mine.py", "priority": 3},
{"name": "adversary", "script": "pipeline/adversary_runner.py", "priority": 4},
{"name": "codebase_genome", "script": "pipeline/codebase_genome.py", "priority": 5},
]
# Dependency rules: some pipelines only start after others are running
DEPENDENCY_RULES = {
"playground_factory": [], # no deps, start immediately
"training_factory": [], # no deps, start in parallel
"knowledge_mine": ["training_factory"], # start after training is running
"adversary": ["knowledge_mine"], # start after knowledge is halfway
"codebase_genome": [], # continuous, one repo per night
}
def load_state():
    """Load pipeline state from disk."""
    if STATE_FILE.exists():
        with open(STATE_FILE) as f:
            return json.load(f)
    return {
        "last_run": None,
        "daily_tokens_used": 0,
        "budget_reset_date": None,
        "pipelines": {},
        "active_sessions": [],
    }
def save_state(state):
    """Persist *state* as pretty-printed JSON, creating parent dirs as needed."""
    STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
    STATE_FILE.write_text(json.dumps(state, indent=2))
def check_provider(endpoint):
    """Return True if the endpoint's URL answers with HTTP 200.

    Any network or HTTP failure (timeout, DNS, non-2xx raising) simply
    means the provider is unavailable.
    """
    try:
        probe = urllib.request.Request(
            endpoint["url"], headers={"Authorization": "Bearer ollama"}
        )
        with urllib.request.urlopen(probe, timeout=10) as response:
            return response.status == 200
    except Exception:
        return False
def get_available_providers():
    """Return the names of every configured inference endpoint that responds."""
    return [ep["name"] for ep in INFERENCE_ENDPOINTS if check_provider(ep)]
def is_peak_hours():
    """True when the current local hour falls inside the PEAK_HOURS window."""
    return datetime.now().hour in PEAK_HOURS
def check_token_budget(state):
    """Reset the daily counter when the date rolls over; True if budget remains."""
    current_day = datetime.now().strftime("%Y-%m-%d")
    if state.get("budget_reset_date") != current_day:
        # First check of a new calendar day: zero the counter and persist.
        state["daily_tokens_used"] = 0
        state["budget_reset_date"] = current_day
        save_state(state)
    return state["daily_tokens_used"] < DAILY_TOKEN_BUDGET
def get_pipeline_status(state, pipeline_name):
    """Return the tracked status dict for *pipeline_name*.

    Unknown pipelines get a fresh default record ("not_started").
    """
    known = state.get("pipelines", {})
    if pipeline_name in known:
        return known[pipeline_name]
    return {
        "status": "not_started",
        "last_run": None,
        "last_success": None,
        "progress": 0,
    }
def check_dependencies(state, pipeline_name):
    """True when every dependency of *pipeline_name* is running or completed."""
    return all(
        get_pipeline_status(state, dep)["status"] in ("running", "completed")
        for dep in DEPENDENCY_RULES.get(pipeline_name, [])
    )
def start_pipeline(pipeline, state, dry_run=False):
    """Launch *pipeline* as a detached subprocess and record it in *state*.

    Args:
        pipeline: entry from PIPELINE_PRIORITY ({"name", "script", "priority"}).
        state: mutable scheduler state dict; updated and persisted on every
            non-dry-run path.
        dry_run: when True, only log what would run; no state change.

    Returns:
        True if the pipeline was started (or would start in dry-run mode).
    """
    name = pipeline["name"]
    script = pipeline["script"]
    log(f"Starting pipeline: {name}")
    if dry_run:
        log(f" DRY RUN — would run: python3 {script}")
        return True
    # Check if script exists
    script_path = Path(script)
    if not script_path.exists():
        log(f" Script not found: {script_path}")
        # Update state anyway so we track the attempt.
        # setdefault guards against a loaded state file missing "pipelines".
        state.setdefault("pipelines", {})[name] = {
            "status": "script_missing",
            "last_run": datetime.now(timezone.utc).isoformat(),
            "progress": 0,
        }
        save_state(state)
        return False
    # Run the pipeline script, capturing its output to a timestamped log file.
    import subprocess
    log_dir = LOG_DIR / name
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    try:
        # BUGFIX: open the log via `with` so the parent's file object is
        # closed once the child inherits the descriptor (the original
        # leaked an open handle per launch).
        with open(log_file, "w") as lf:
            proc = subprocess.Popen(
                ["python3", str(script_path)],
                stdout=lf,
                stderr=subprocess.STDOUT,
                cwd=str(Path(script).parent.parent),
            )
        state.setdefault("pipelines", {})[name] = {
            "status": "running",
            "pid": proc.pid,
            "last_run": datetime.now(timezone.utc).isoformat(),
            "log_file": str(log_file),
            "progress": 0,
        }
        save_state(state)
        log(f" Started PID {proc.pid}, log: {log_file}")
        return True
    except Exception as e:
        log(f" Failed to start: {e}")
        state.setdefault("pipelines", {})[name] = {
            "status": "failed",
            "last_run": datetime.now(timezone.utc).isoformat(),
            "error": str(e),
        }
        save_state(state)
        return False
def check_running_pipelines(state):
    """Poll pipelines marked "running"; mark any whose PID is gone as completed.

    Uses `os.kill(pid, 0)` purely as an existence probe (signal 0 delivers
    nothing). Persists *state* after each transition.

    Note: the original imported `subprocess` here but never used it — removed.
    """
    for name, info in state.get("pipelines", {}).items():
        if info.get("status") != "running":
            continue
        pid = info.get("pid")
        if not pid:
            continue
        try:
            os.kill(pid, 0)  # probe only; raises if the process is gone
        except ProcessLookupError:
            # Process finished
            info["status"] = "completed"
            info["completed_at"] = datetime.now(timezone.utc).isoformat()
            log(f"Pipeline {name} completed (PID {pid} exited)")
            save_state(state)
def run_scheduler(dry_run=False, check_only=False):
    """Run one scheduler pass: start at most one eligible pipeline.

    Gate order: peak-hour pause -> daily token budget -> provider
    availability -> refresh running-pipeline status -> start the
    highest-priority pipeline whose dependencies are satisfied.

    Args:
        dry_run: log what would launch without launching anything.
        check_only: only report which pipelines are READY; never launch.
    """
    state = load_state()
    log("=" * 50)
    log(f"Pipeline Scheduler — {datetime.now().isoformat()}")
    log(f"Mode: {'CHECK' if check_only else 'DRY RUN' if dry_run else 'LIVE'}")
    # BUGFIX: record this pass — "last_run" is part of the state schema and
    # displayed by --status, but was never actually written.
    state["last_run"] = datetime.now(timezone.utc).isoformat()
    save_state(state)
    # Check peak hours
    if is_peak_hours():
        log("Peak hours detected. Pausing pipeline starts.")
        log("Pipelines will resume at 10pm.")
        return
    # Check token budget
    if not check_token_budget(state):
        log(f"Daily token budget exhausted ({state['daily_tokens_used']}/{DAILY_TOKEN_BUDGET})")
        return
    log(f"Token budget: {state['daily_tokens_used']}/{DAILY_TOKEN_BUDGET}")
    # Check providers
    providers = get_available_providers()
    if not providers:
        log("No inference providers available. Skipping.")
        return
    log(f"Available providers: {', '.join(providers)}")
    # Reap pipelines that exited since the last pass.
    check_running_pipelines(state)
    # Walk pipelines in priority order and start the first eligible one.
    started = 0
    for pipeline in sorted(PIPELINE_PRIORITY, key=lambda p: p["priority"]):
        name = pipeline["name"]
        status = get_pipeline_status(state, name)
        # Skip if already running or completed
        if status["status"] in ("running", "completed"):
            log(f" {name}: {status['status']} (skipping)")
            continue
        # Check dependencies
        if not check_dependencies(state, name):
            deps = DEPENDENCY_RULES.get(name, [])
            log(f" {name}: waiting for dependencies: {deps}")
            continue
        # Start the pipeline
        if check_only:
            log(f" {name}: READY to start (priority {pipeline['priority']})")
        else:
            if start_pipeline(pipeline, state, dry_run):
                started += 1
        # Only start one pipeline per run to avoid overload
        if started >= 1:
            log("Started 1 pipeline. Will check again next cycle.")
            break
    if started == 0 and not check_only:
        log("No pipelines to start. All are running, completed, or blocked.")
    log("=" * 50)
def show_status():
    """Print a human-readable status table for every known pipeline."""
    state = load_state()
    print(f"\nPipeline Status — {datetime.now().strftime('%Y-%m-%d %H:%M')}")
    print(f"Token budget: {state.get('daily_tokens_used', 0)}/{DAILY_TOKEN_BUDGET}")
    print(f"Last run: {state.get('last_run', 'never')}")
    print()
    icons = {"running": "", "completed": "", "failed": "", "not_started": "", "script_missing": "?"}
    for entry in sorted(PIPELINE_PRIORITY, key=lambda p: p["priority"]):
        info = get_pipeline_status(state, entry["name"])
        st = info["status"]
        last = (info.get("last_run") or "never")[:19]
        print(f" {icons.get(st, '?')} {entry['name']:25} {st:15} last={last}")
def reset_budget():
    """Zero today's token counter and persist the change."""
    state = load_state()
    state.update(
        daily_tokens_used=0,
        budget_reset_date=datetime.now().strftime("%Y-%m-%d"),
    )
    save_state(state)
    print("Budget reset.")
def log(msg):
    """Echo *msg* to stdout and append it to the shared nightly log file."""
    stamped = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
    print(stamped)
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    with open(LOG_DIR / "nightly.log", "a") as fh:
        fh.write(stamped + "\n")
def main():
    """Parse CLI flags and dispatch to status / reset / scheduler modes."""
    import argparse
    ap = argparse.ArgumentParser(description="Nightly Pipeline Scheduler")
    ap.add_argument("--check", action="store_true", help="Dry-run: show what would start")
    ap.add_argument("--status", action="store_true", help="Show pipeline status")
    ap.add_argument("--reset", action="store_true", help="Reset daily token budget")
    ap.add_argument("--dry-run", action="store_true", help="Dry-run mode")
    opts = ap.parse_args()
    if opts.status:
        show_status()
        return
    if opts.reset:
        reset_budget()
        return
    run_scheduler(dry_run=opts.dry_run or opts.check, check_only=opts.check)


if __name__ == "__main__":
    main()

View File

@@ -1,389 +0,0 @@
#!/usr/bin/env python3
"""
Training Data Quality Filter (#687)
Scores and removes low-quality training pairs from JSONL files.
Supports: ShareGPT format, preference pairs, generic JSONL.
Usage:
python3 scripts/filter_training_data.py <input.jsonl> [--output filtered.jsonl]
python3 scripts/filter_training_data.py training/data/preference_pairs.jsonl
python3 scripts/filter_training_data.py training/data/curated_dataset.jsonl --threshold 0.3
"""
import argparse
import ast
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# ============================================================
# QUALITY SCORING
# ============================================================
# Generic filler phrases that indicate low-quality responses.
# Matched as lowercase substrings in score_specificity(); each hit
# subtracts 0.08 from the specificity score.
FILLER_PHRASES = [
    "as an ai", "i'm an ai", "as a language model", "i don't have personal",
    "i cannot", "i can't", "it's important to note", "please note that",
    "in conclusion", "to summarize", "in summary", "hope this helps",
    "let me know if", "feel free to", "i'd be happy to", "certainly!",
    "of course!", "absolutely!", "great question!", "that's a great",
    "i understand your", "i appreciate your", "thank you for asking",
    "it depends", "there are many ways", "various factors",
]
# Vague/generic short responses — an exact match (after lowercasing and
# stripping) costs 0.4 in score_specificity().
VAGUE_RESPONSES = [
    "ok", "okay", "sure", "yes", "no", "maybe", "idk", "i don't know",
    "thanks", "thank you", "got it", "understood", "right", "correct",
    "hello", "hi", "hey", "goodbye", "bye",
]
# Fenced ``` blocks (optionally language-tagged), capturing the body.
CODE_BLOCK_PATTERN = re.compile(r"```(?:\w+)?\n(.+?)```", re.DOTALL)
# `inline code` spans — NOTE(review): not referenced anywhere in this file.
INLINE_CODE_PATTERN = re.compile(r"`([^`]+)`")
def detect_format(record: dict) -> str:
    """Classify a training record into one of the known dataset formats.

    Key signatures are checked in priority order; unknown shapes fall
    back to "generic".
    """
    signatures = (
        (("conversations",), "sharegpt"),
        (("prompt", "chosen"), "preference"),
        (("scene", "lyric_line"), "scene"),
        (("terse", "rich"), "pairs"),
    )
    for keys, label in signatures:
        if all(k in record for k in keys):
            return label
    return "generic"
def extract_text_fields(record: dict, fmt: str) -> Tuple[str, str]:
    """Return (input_text, output_text) for *record* according to *fmt*."""
    if fmt == "sharegpt":
        turns = record.get("conversations", [])
        humans = [t["value"] for t in turns if t.get("from") == "human"]
        gpts = [t["value"] for t in turns if t.get("from") == "gpt"]
        # Use the final turn on each side of the conversation.
        return (humans[-1] if humans else "", gpts[-1] if gpts else "")
    if fmt == "preference":
        return record.get("prompt", ""), record.get("chosen", "")
    if fmt == "scene":
        return record.get("lyric_line", ""), record.get("scene", {}).get("description", "")
    if fmt == "pairs":
        return record.get("terse", ""), record.get("rich", "")
    # Generic: fall through the common field-name aliases.
    src = record.get("input", record.get("prompt", record.get("question", "")))
    dst = record.get("output", record.get("response", record.get("answer", "")))
    return str(src), str(dst)
def score_specificity(text: str) -> float:
    """Score 0-1 how specific/detailed a response is vs generic filler.

    Starts from a 0.5 baseline; subtracts for filler phrases, very short
    or vague answers; adds for concrete markers (numbers, URLs, code,
    enumerated steps). Result is clamped to [0, 1].
    """
    if not text or not text.strip():
        return 0.0
    lowered = text.lower().strip()
    score = 0.5  # baseline
    # Each filler phrase present costs 0.08.
    filler_hits = sum(1 for phrase in FILLER_PHRASES if phrase in lowered)
    score -= filler_hits * 0.08
    # Word-count adjustment: short answers penalised, long ones rewarded.
    n_words = len(text.split())
    if n_words < 5:
        score -= 0.3
    elif n_words < 10:
        score -= 0.15
    elif n_words > 30:
        score += 0.1  # longer responses tend to be more detailed
    # One-word/vague answers take a heavy penalty.
    if lowered in VAGUE_RESPONSES:
        score -= 0.4
    # Reward concrete-detail markers, 0.05 each.
    markers = (
        r"\d+",  # numbers
        r"```",  # code blocks
        r"https?://",  # URLs
        r"\$\{", r"\w+\.\w+",  # code-like patterns
        r"(?:specifically|exactly|precisely|in particular)",
        r"(?:step \d|first,|second,|third,|finally,)",
    )
    for pat in markers:
        if re.search(pat, text):
            score += 0.05
    # Extra reward for an actual fenced code block.
    if "```" in text:
        score += 0.15
    return max(0.0, min(1.0, score))
def score_length_ratio(input_text: str, output_text: str) -> float:
    """Score 0-1 for how reasonable the output length is relative to the input.

    Very short outputs for long inputs score low; a 0.5x-15x word ratio is
    the sweet spot; extremely long outputs are mildly suspect (noise).
    """
    n_in = len(input_text.split())
    n_out = len(output_text.split())
    if n_out == 0:
        # Covers both "everything empty" and "empty output".
        return 0.0
    # Output/input word ratio; normalise against 10 words when input is empty.
    ratio = n_out / n_in if n_in > 0 else n_out / 10
    if ratio < 0.05:
        return 0.1  # output way too short
    if ratio < 0.2:
        return 0.3
    if ratio < 0.5:
        return 0.6
    if ratio <= 15:
        return 1.0  # sweet spot
    if ratio <= 50:
        return 0.8
    return 0.5  # extremely long output, maybe noise
def score_code_correctness(text: str) -> float:
    """Score 0-1 for plausibility of any fenced code blocks in *text*.

    No code blocks -> 1.0 (not penalised). Otherwise each block earns
    credit from the first check that passes: Python parse (1.0),
    balanced brackets (0.8), JSON parse (1.0), or a non-trivial
    multi-line fallback (0.5). Returns the mean credit across blocks.
    """
    blocks = CODE_BLOCK_PATTERN.findall(text)
    if not blocks:
        return 1.0  # no code, not penalized

    def _credit(code: str) -> float:
        # Python syntax check first.
        try:
            ast.parse(code)
            return 1.0
        except SyntaxError:
            pass
        # JavaScript-ish heuristic: balanced braces/parens/brackets.
        if _check_brackets_balanced(code):
            return 0.8
        # Exact JSON check.
        try:
            json.loads(code)
            return 1.0
        except (json.JSONDecodeError, ValueError):
            pass
        # Shell/YAML fallback: any non-trivial multi-line snippet.
        if len(code.strip()) > 10 and "\n" in code:
            return 0.5
        return 0.0

    return sum(_credit(block) for block in blocks) / len(blocks)
def _check_brackets_balanced(code: str) -> bool:
"""Check if brackets are balanced in code."""
stack = []
pairs = {"(": ")", "[": "]", "{": "}"}
for ch in code:
if ch in pairs:
stack.append(pairs[ch])
elif ch in pairs.values():
if not stack or stack[-1] != ch:
return False
stack.pop()
return len(stack) == 0
def score_record(record: dict, fmt: str) -> Dict[str, float]:
    """Compute component and composite quality scores for one record.

    Returns a dict with "specificity", "length_ratio", "code_correctness"
    and a weighted "composite" (45% / 25% / 30%), each rounded to 3 places.
    """
    input_text, output_text = extract_text_fields(record, fmt)
    components = {
        "specificity": score_specificity(output_text),
        "length_ratio": score_length_ratio(input_text, output_text),
        "code_correctness": score_code_correctness(output_text),
    }
    # Composite is computed from the unrounded components.
    composite = (
        components["specificity"] * 0.45 +
        components["length_ratio"] * 0.25 +
        components["code_correctness"] * 0.30
    )
    result = {name: round(value, 3) for name, value in components.items()}
    result["composite"] = round(composite, 3)
    return result
# ============================================================
# FILTERING
# ============================================================
def filter_jsonl(
    input_path: str,
    output_path: Optional[str] = None,
    threshold: float = 0.3,
    dry_run: bool = False,
    verbose: bool = False,
) -> Dict[str, Any]:
    """Filter a JSONL file, removing low-quality records.

    Args:
        input_path: path to the JSONL file to score.
        output_path: where kept records are written; defaults to
            "<input stem>_filtered.jsonl" next to the input.
        threshold: minimum composite score a record must reach to be kept.
        dry_run: score and report only; write nothing.
        verbose: also print the five worst-scoring removed records.

    Returns:
        A report dict with counts and score statistics — or, when the
        input holds no valid JSON records, {"error": ..., "total": 0}
        (note: that error dict lacks the usual report keys).
    """
    if output_path is None:
        stem = Path(input_path).stem
        output_path = str(Path(input_path).parent / f"{stem}_filtered.jsonl")
    records = []
    with open(input_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Malformed lines are skipped with a warning, not fatal.
                print(f" [WARN] Line {i+1}: invalid JSON, skipping: {e}", file=sys.stderr)
    if not records:
        return {"error": "No valid records found", "total": 0}
    # Detect format from first record; the file is assumed homogeneous.
    fmt = detect_format(records[0])
    print(f" Detected format: {fmt}")
    print(f" Total records: {len(records)}")
    # Score all records, remembering each record's original position.
    scored = []
    for i, record in enumerate(records):
        scores = score_record(record, fmt)
        scored.append((record, scores, i))
    # Sort by composite score (ascending — worst offenders first).
    scored.sort(key=lambda x: x[1]["composite"])
    # Partition around the threshold.
    kept = [(r, s, i) for r, s, i in scored if s["composite"] >= threshold]
    removed = [(r, s, i) for r, s, i in scored if s["composite"] < threshold]
    # Report
    report = {
        "input_file": input_path,
        "output_file": output_path,
        "format": fmt,
        "total_records": len(records),
        "kept": len(kept),
        "removed": len(removed),
        "threshold": threshold,
        "removal_rate": f"{len(removed) / len(records) * 100:.1f}%",
        "score_distribution": {
            # `scored` is sorted ascending, so its ends give min/max directly.
            "min": scored[0][1]["composite"] if scored else 0,
            "max": scored[-1][1]["composite"] if scored else 0,
            "median": scored[len(scored)//2][1]["composite"] if scored else 0,
            "mean": round(sum(s["composite"] for _, s, _ in scored) / len(scored), 3) if scored else 0,
        },
        "removed_score_breakdown": {
            "specificity_below_0.3": sum(1 for _, s, _ in removed if s["specificity"] < 0.3),
            "length_ratio_below_0.3": sum(1 for _, s, _ in removed if s["length_ratio"] < 0.3),
            "code_correctness_below_0.5": sum(1 for _, s, _ in removed if s["code_correctness"] < 0.5),
        },
    }
    # Show worst offenders if verbose
    if verbose and removed:
        print(f"\n Worst 5 records (by composite score):")
        for r, s, i in removed[:5]:
            _, output_text = extract_text_fields(r, fmt)
            preview = output_text[:80].replace("\n", " ") if output_text else "(empty)"
            print(f" [{s['composite']:.3f}] {preview}...")
    # Write output (unless dry run)
    if not dry_run:
        # Preserve original order, only keeping filtered records
        kept_indices = {i for _, _, i in kept}
        with open(output_path, "w", encoding="utf-8") as f:
            for i, record in enumerate(records):
                if i in kept_indices:
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
        print(f"\n Written: {output_path}")
    return report
# ============================================================
# CLI
# ============================================================
def main():
    """CLI entry point: parse args, run the filter, print a summary.

    Exits with status 1 when the input file is missing or contains no
    valid records.
    """
    parser = argparse.ArgumentParser(
        description="Training data quality filter — remove low-quality pairs (#687)"
    )
    parser.add_argument("input", help="Input JSONL file path")
    parser.add_argument("--output", "-o", help="Output file path (default: <input>_filtered.jsonl)")
    parser.add_argument("--threshold", "-t", type=float, default=0.3,
                        help="Minimum composite score to keep (default: 0.3)")
    parser.add_argument("--dry-run", "-n", action="store_true",
                        help="Score only, don't write output")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Show worst offenders")
    parser.add_argument("--report-json", "-j", help="Write report as JSON to file")
    args = parser.parse_args()
    if not os.path.exists(args.input):
        print(f"Error: {args.input} not found", file=sys.stderr)
        sys.exit(1)
    print(f"Filtering: {args.input}")
    print(f"Threshold: {args.threshold}")
    print()
    report = filter_jsonl(
        args.input,
        output_path=args.output,
        threshold=args.threshold,
        dry_run=args.dry_run,
        verbose=args.verbose,
    )
    # BUGFIX: filter_jsonl returns {"error": ..., "total": 0} (without the
    # usual report keys) when no valid records exist — bail out before the
    # summary below raises KeyError on report['format'].
    if "error" in report:
        print(f"Error: {report['error']}", file=sys.stderr)
        sys.exit(1)
    print(f"\n{'=' * 50}")
    print(f" RESULTS")
    print(f"{'=' * 50}")
    print(f" Format: {report['format']}")
    print(f" Total: {report['total_records']}")
    print(f" Kept: {report['kept']}")
    print(f" Removed: {report['removed']} ({report['removal_rate']})")
    print(f" Threshold: {report['threshold']}")
    print(f" Score range: {report['score_distribution']['min']:.3f} - {report['score_distribution']['max']:.3f}")
    print(f" Mean score: {report['score_distribution']['mean']:.3f}")
    if args.report_json:
        with open(args.report_json, "w") as f:
            json.dump(report, f, indent=2)
        print(f"\n Report saved: {args.report_json}")


if __name__ == "__main__":
    main()

View File

@@ -1,192 +0,0 @@
#!/usr/bin/env python3
"""
Tests for training data quality filter (#687).
"""
import json
import os
import tempfile
import unittest
# Import from the script
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts"))
from filter_training_data import (
detect_format,
extract_text_fields,
score_specificity,
score_length_ratio,
score_code_correctness,
score_record,
filter_jsonl,
FILLER_PHRASES,
VAGUE_RESPONSES,
)
class TestFormatDetection(unittest.TestCase):
    """detect_format() recognises each known record shape."""

    def test_sharegpt_format(self):
        rec = {"conversations": [{"from": "human", "value": "hi"}]}
        self.assertEqual(detect_format(rec), "sharegpt")

    def test_preference_format(self):
        rec = {"prompt": "do X", "chosen": "done", "rejected": "no"}
        self.assertEqual(detect_format(rec), "preference")

    def test_scene_format(self):
        rec = {"lyric_line": "test", "scene": {"description": "desc"}}
        self.assertEqual(detect_format(rec), "scene")

    def test_pairs_format(self):
        rec = {"terse": "short", "rich": "detailed"}
        self.assertEqual(detect_format(rec), "pairs")

    def test_generic_format(self):
        rec = {"input": "q", "output": "a"}
        self.assertEqual(detect_format(rec), "generic")
class TestExtractTextFields(unittest.TestCase):
    """extract_text_fields() pulls the right (input, output) pair per format."""

    def test_sharegpt_extraction(self):
        rec = {
            "conversations": [
                {"from": "system", "value": "system prompt"},
                {"from": "human", "value": "hello"},
                {"from": "gpt", "value": "hi there"},
            ]
        }
        extracted = extract_text_fields(rec, "sharegpt")
        # System turns are ignored; last human/gpt turns win.
        self.assertEqual(extracted, ("hello", "hi there"))

    def test_preference_extraction(self):
        rec = {"prompt": "question", "chosen": "good answer"}
        extracted = extract_text_fields(rec, "preference")
        self.assertEqual(extracted, ("question", "good answer"))
class TestSpecificityScoring(unittest.TestCase):
    """score_specificity() separates filler/vague text from concrete detail."""

    def test_empty_text(self):
        self.assertEqual(score_specificity(""), 0.0)

    def test_filler_heavy(self):
        filler = "As an AI, I cannot provide that. It's important to note that I'm an AI."
        self.assertLess(score_specificity(filler), 0.3)

    def test_vague_response(self):
        self.assertLess(score_specificity("ok"), 0.2)

    def test_specific_response(self):
        detailed = "Here are the steps:\n1. First, install Python 3.12\n2. Run `pip install numpy`\n3. Execute main.py"
        self.assertGreater(score_specificity(detailed), 0.5)

    def test_code_response(self):
        with_code = "Use this:\n```python\ndef hello():\n print('world')\n```"
        self.assertGreater(score_specificity(with_code), 0.6)
class TestLengthRatio(unittest.TestCase):
    """score_length_ratio() bands on the output/input word-count ratio."""

    def test_both_empty(self):
        self.assertEqual(score_length_ratio("", ""), 0.0)

    def test_empty_output(self):
        self.assertEqual(score_length_ratio("hello world", ""), 0.0)

    def test_good_ratio(self):
        result = score_length_ratio(
            "short question",
            "This is a reasonable length answer that addresses the question.",
        )
        self.assertGreater(result, 0.7)

    def test_too_short_output(self):
        result = score_length_ratio(
            "This is a very long question with many words that expects a detailed answer",
            "ok",
        )
        self.assertLess(result, 0.5)
class TestCodeCorrectness(unittest.TestCase):
    """score_code_correctness() credit assignment for fenced code blocks."""

    def test_no_code(self):
        self.assertEqual(score_code_correctness("plain text"), 1.0)

    def test_valid_python(self):
        snippet = "```python\ndef foo():\n return 42\n```"
        self.assertEqual(score_code_correctness(snippet), 1.0)

    def test_invalid_python(self):
        snippet = "```python\ndef foo(\n return 42\n```"
        self.assertLess(score_code_correctness(snippet), 1.0)

    def test_valid_json(self):
        snippet = "```json\n{\"key\": \"value\"}\n```"
        self.assertEqual(score_code_correctness(snippet), 1.0)
class TestFilterJsonl(unittest.TestCase):
    """End-to-end tests for filter_jsonl() over temporary JSONL files."""

    def _write_temp_jsonl(self, records):
        """Write *records* to a temporary .jsonl file and return its path."""
        f = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
        for r in records:
            f.write(json.dumps(r) + "\n")
        f.close()
        return f.name

    def test_filter_removes_low_quality(self):
        records = [
            {"conversations": [
                {"from": "human", "value": "How do I sort a list in Python?"},
                {"from": "gpt", "value": "Use `sorted()` or `list.sort()`.\n```python\nnums = [3,1,2]\nnums.sort()\nprint(nums) # [1, 2, 3]\n```"},
            ]},
            {"conversations": [
                {"from": "human", "value": "What is Python?"},
                {"from": "gpt", "value": "ok"},
            ]},
            {"conversations": [
                {"from": "human", "value": "Tell me about databases."},
                {"from": "gpt", "value": "As an AI, I cannot. It's important to note."},
            ]},
        ]
        path = self._write_temp_jsonl(records)
        # BUGFIX: pre-bind report so the finally block cannot raise
        # NameError (masking the real failure) if filter_jsonl throws.
        report = {}
        try:
            report = filter_jsonl(path, threshold=0.3)
            self.assertEqual(report["total_records"], 3)
            self.assertGreater(report["kept"], 0)
            self.assertGreater(report["removed"], 0)
            self.assertEqual(report["format"], "sharegpt")
        finally:
            os.unlink(path)
            if os.path.exists(report.get("output_file", "")):
                os.unlink(report["output_file"])

    def test_dry_run_no_output(self):
        records = [
            {"prompt": "test", "chosen": "good detailed answer with code: `print(1)`", "rejected": "no"},
        ]
        path = self._write_temp_jsonl(records)
        try:
            out_path = path.replace(".jsonl", "_filtered.jsonl")
            report = filter_jsonl(path, threshold=0.3, dry_run=True)
            # Dry run must not create the output file.
            self.assertFalse(os.path.exists(out_path))
            self.assertEqual(report["total_records"], 1)
        finally:
            os.unlink(path)

    def test_preference_format(self):
        records = [
            {"prompt": "Write a function", "chosen": "```python\ndef f(): pass\n```", "rejected": ""},
            {"prompt": "Hi", "chosen": "ok", "rejected": "no"},
        ]
        path = self._write_temp_jsonl(records)
        # BUGFIX: same pre-binding as above — `report` was referenced in the
        # finally block even when filter_jsonl raised before assignment.
        report = {}
        try:
            report = filter_jsonl(path, threshold=0.3)
            self.assertEqual(report["format"], "preference")
            self.assertEqual(report["total_records"], 2)
        finally:
            os.unlink(path)
            if os.path.exists(report.get("output_file", "")):
                os.unlink(report["output_file"])
# Allow running this test file directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()